"""Evaluate a pre-trained Vision Transformer on the CIFAR-10 test split.

The script loads CIFAR-10 through ``datasets``, resizes and normalises each
image on the fly, runs the ``google/vit-base-patch16-224`` checkpoint over the
test set in batches of 32 and prints the top-1 accuracy as a percentage.
"""

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

# Load CIFAR-10 dataset
dataset = load_dataset('cifar10')

# Load the checkpoint's image processor *before* building the transform so the
# normalisation statistics can be taken from it.
model_name = 'google/vit-base-patch16-224'
image_processor = ViTImageProcessor.from_pretrained(model_name)

# Define transformations for the dataset.
# The normalisation must match what the checkpoint was trained with, not the
# CIFAR-10 channel statistics; using the processor's own mean/std keeps the
# inputs in the distribution the pretrained backbone expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to the 224x224 model input
    transforms.ToTensor(),
    transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    ),
])


def preprocess_function(examples):
    """Add a ``pixel_values`` column of transformed image tensors to a batch.

    Args:
        examples: Batch dict containing an ``'img'`` list of PIL images.

    Returns:
        The same dict, with ``'pixel_values'`` holding one transformed tensor
        per input image. Images are converted to RGB first so every tensor has
        three channels.
    """
    examples['pixel_values'] = [
        transform(image.convert("RGB")) for image in examples['img']
    ]
    return examples


# Apply transformations lazily: they run when items are accessed.
encoded_dataset = dataset.with_transform(preprocess_function)


def collate_fn(batch):
    """Stack per-example tensors and labels into a single batch dict.

    Args:
        batch: List of examples, each with ``'pixel_values'`` and ``'label'``.

    Returns:
        Dict with a stacked ``'pixel_values'`` tensor and a ``'label'`` tensor.
    """
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    return {'pixel_values': pixel_values, 'label': labels}


# Create DataLoader for the test set
test_loader = DataLoader(
    encoded_dataset['test'],
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)

# Load the pre-trained ViT model.
# NOTE(review): requesting num_labels=10 together with
# ignore_mismatched_sizes=True discards the checkpoint's original classifier
# and creates a new, randomly initialised 10-way head. Without fine-tuning
# that head first (or loading a checkpoint already fine-tuned on CIFAR-10),
# the accuracy printed below is expected to be near chance level.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        images = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)
        logits = model(pixel_values=images).logits
        predicted = logits.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%')