"""Classify a CIFAR-10 test image with a ViT checkpoint fine-tuned on CIFAR-10.

Loads CIFAR-10 through Keras, runs one test image through the
``nateraw/vit-base-patch16-224-cifar10`` Hugging Face model, and prints the
ground-truth class name followed by the predicted class name.
"""
import io

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from tensorflow.keras import datasets, layers, models
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Index of the CIFAR-10 test image to classify.
SAMPLE_INDEX = 10

# Hugging Face checkpoint fine-tuned on CIFAR-10.
MODEL_NAME = "nateraw/vit-base-patch16-224-cifar10"

# Load the CIFAR-10 dataset.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Keep the un-normalized test image for the ViT preprocessor. It applies its
# own 1/255 rescaling and mean/std normalization, so feeding it the [0, 1]
# floats computed below would rescale the pixels twice.
raw_test_image = test_images[SAMPLE_INDEX]

# Normalize pixel values to be between 0 and 1.
train_images, test_images = train_images / 255.0, test_images / 255.0

# Example CIFAR-10 image from the dataset's web page (not used for the
# prediction below). Bound the wait and fail loudly on HTTP errors instead of
# handing an error page to PIL.
url = "https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png"
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

inputs = feature_extractor(images=raw_test_image, return_tensors="pt")
# Inference only: skip autograd graph construction.
with torch.no_grad():
    outputs = model(**inputs)
preds = outputs.logits.argmax(dim=1)

# CIFAR-10 class names, listed in label-index order.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

print(classes[int(test_labels[SAMPLE_INDEX][0])])  # ground-truth label
print(classes[int(preds[0])])  # predicted label