"""Single-image smoke test: run a ViT with a 10-way head on a CIFAR-10 sample.

The script downloads one sample image, loads the ImageNet-pretrained
``google/vit-base-patch16-224`` checkpoint, swaps its classifier for a
10-output linear layer and prints the predicted CIFAR-10 class.

NOTE: the replacement head is randomly initialised and is not trained
anywhere in this script, so the printed prediction is not meaningful until
the model has been fine-tuned on CIFAR-10.
"""
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTForImageClassification, ViTImageProcessor

# CIFAR-10 label names, indexed by class id.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Alternative sample images:
#   http://images.cocodataset.org/val2017/000000039769.jpg
#   https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png'

# Fail fast on a hung connection or an HTTP error instead of handing an
# error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force three channels so palette/RGBA PNGs do not break the processor.
image = Image.open(response.raw).convert('RGB')

# CIFAR-10 test split; not consumed by the single-image inference below.
test_dataset = CIFAR10(root='./data', train=False, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Change the number of output classes to 10 (for CIFAR-10).
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# Optionally move the model to the GPU when one is available:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)
# model.half().to(device)  # load in half precision

inputs = processor(images=image, return_tensors="pt")

# Inference only: skip building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

class_names = list(classes)

preds = torch.argmax(logits, dim=1)
print(preds.cpu().numpy())

# The head now has 10 outputs, so the index must be looked up in the
# CIFAR-10 names; model.config.id2label still holds the ImageNet labels.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", class_names[predicted_class_idx])