You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
1.0 KiB
26 lines
1.0 KiB
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from tensorflow.keras import datasets, layers, models
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Classify one CIFAR-10 test image with a ViT checkpoint fine-tuned on
# CIFAR-10, then print the ground-truth class name followed by the
# predicted class name.

# Load the CIFAR-10 dataset. Per the Keras cifar10 docs, images are uint8
# arrays of shape (N, 32, 32, 3) and labels are integer arrays of shape
# (N, 1) -- hence the trailing [0] when a label is looked up below.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# NOTE: the pixels are deliberately NOT divided by 255 here. The ViT feature
# extractor resizes, rescales and normalizes its input itself, so it should
# receive the raw pixel values. Pre-scaled floats can be rescaled a second
# time, depending on the transformers version. Scaling the training split
# would also build a large float64 copy that nothing uses.

# Which test image to classify.
sample_index = 10

checkpoint = 'nateraw/vit-base-patch16-224-cifar10'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
model = ViTForImageClassification.from_pretrained(checkpoint)

inputs = feature_extractor(images=test_images[sample_index], return_tensors="pt")

# Inference only: skip autograd bookkeeping to save time and memory.
with torch.no_grad():
    outputs = model(**inputs)

# Highest-scoring class index for each image in the (single-image) batch.
preds = outputs.logits.argmax(dim=1)

# CIFAR-10 class names, assumed to be in label-index order (0-9), shared by
# the dataset labels and the checkpoint's output logits.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Ground truth first, then the model's prediction.
print(classes[test_labels[sample_index][0]])
# .item() turns the 0-d tensor into a plain Python int for list indexing.
print(classes[preds[0].item()])