Модель распознавания изображений из набора данных CIFAR10
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

59 lines
2.0 KiB

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
from tqdm import tqdm
# Load the CIFAR-10 dataset (no split is requested, so every split is loaded).
dataset = load_dataset('cifar10')

# Preprocessing pipeline applied to every image before it reaches the model.
# The 'google/vit-base-patch16-224' checkpoint loaded later in this script was
# trained on 224x224 inputs normalized with mean = std = 0.5 per channel (see
# its model card), so those statistics are used here instead of the CIFAR-10
# per-channel statistics, which would shift the inputs away from what the
# pre-trained backbone expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # match the checkpoint's 224x224 input size
    transforms.ToTensor(),  # image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Apply transformations to the dataset
def preprocess_function(examples):
    """Attach model-ready tensors to a batch of dataset examples.

    Args:
        examples: Batch dict whose 'img' entry is an iterable of images
            exposing ``.convert("RGB")`` (presumably PIL images).

    Returns:
        The same dict, mutated in place, with an added 'pixel_values' list
        holding the output of the module-level ``transform`` for each image.
    """
    examples['pixel_values'] = [
        transform(image.convert("RGB")) for image in examples['img']
    ]
    return examples


# Register the preprocessing on the dataset; per the `datasets` documentation
# with_transform applies the function on the fly when examples are accessed.
encoded_dataset = dataset.with_transform(preprocess_function)
# Create DataLoader for the test set
def collate_fn(batch):
    """Merge a list of per-example dicts into a single batched dict.

    Args:
        batch: Sequence of examples, each a dict with a 'pixel_values'
            tensor and an integer 'label'. All 'pixel_values' tensors must
            share one shape, as required by ``torch.stack``.

    Returns:
        Dict with 'pixel_values' (the tensors stacked along a new leading
        dimension) and 'label' (a 1-D int64 tensor, one entry per example).
    """
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    # Explicit dtype keeps the labels int64 regardless of which integer-like
    # Python type the dataset yields (e.g. bool would otherwise infer bool).
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    return {'pixel_values': pixel_values, 'label': labels}
# Batch the test split in its stored order (no shuffling); collate_fn turns
# each list of examples into one batched dict.
test_loader = DataLoader(
    encoded_dataset['test'],
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)

# Pre-trained ViT checkpoint and the image processor published with it.
model_name = 'google/vit-base-patch16-224'
image_processor = ViTImageProcessor.from_pretrained(model_name)
# NOTE(review): num_labels=10 combined with ignore_mismatched_sizes=True
# presumably discards the checkpoint's original classification head and
# leaves a freshly initialized 10-way head; if so, accuracy measured without
# fine-tuning will be near chance. Confirm before trusting the reported number.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
# Evaluate the model: count how many test images receive the correct label.
model.eval()  # switch the model to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in tqdm(test_loader):
        images = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit along the class axis.
        predicted = outputs.logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty loader so the report cannot raise ZeroDivisionError.
accuracy = 100 * correct / total if total else 0.0
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%')