You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.7 KiB
43 lines
1.7 KiB
from io import BytesIO

import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTImageProcessor, ViTForImageClassification


# Single-image CIFAR-10 classification demo on top of a pretrained ViT.
#
# Steps: download one sample image, load `google/vit-base-patch16-224`,
# replace its 1000-way ImageNet head with a 10-way CIFAR-10 head, run one
# forward pass and print the predicted class.

# CIFAR-10 label names, indexed by class id.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
# Second name for the same labels, kept so either name can be used.
class_names = classes

# Image to classify. Alternatives that can be swapped in:
#   'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png'
#   'http://images.cocodataset.org/val2017/000000039769.jpg'
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png'

# Fail fast on network problems instead of hanging or handing an HTTP error
# page to PIL; buffer the body so PIL reads fully decoded bytes.
response = requests.get(url, timeout=30)
response.raise_for_status()
# The ViT processor expects 3-channel input, so normalise palette/RGBA/gray
# images to RGB.
image = Image.open(BytesIO(response.content)).convert('RGB')

# CIFAR-10 test split (downloaded to ./data on first run).
# NOTE: no `transform` is set, so samples are PIL images, which PyTorch's
# default collate function does not batch. Supply a transform or a custom
# `collate_fn` before iterating `test_loader`.
test_dataset = CIFAR10(root='./data', train=False, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Change the number of output classes to 10 (for CIFAR-10).
# NOTE: the new head is randomly initialised, so predictions are not
# meaningful until the model has been fine-tuned on CIFAR-10.
model.classifier = torch.nn.Linear(model.classifier.in_features, len(classes))
# Keep the label metadata in sync with the new head; otherwise
# `model.config.id2label` still maps indices to ImageNet class names.
model.num_labels = len(classes)
model.config.id2label = dict(enumerate(classes))
model.config.label2id = {name: idx for idx, name in enumerate(classes)}

# Move the model to the GPU if available (disabled; when enabling this, the
# tensors in `inputs` must be moved to the same device as well).
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)
# model.half().to(device)  # Load in half precision

model.eval()
inputs = processor(images=image, return_tensors="pt")
# Inference only: skip building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

preds = torch.argmax(logits, dim=1)
print(preds.cpu().numpy())

# The model now predicts one of the 10 CIFAR-10 classes.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])