import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
from tqdm import tqdm

# Load CIFAR-10 dataset
dataset = load_dataset('cifar10')

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
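
# Note: the mean/std above are the commonly used CIFAR-10 channel statistics. The
# google/vit-base-patch16-224 checkpoint was trained with its processor's own statistics
# (typically mean = std = [0.5, 0.5, 0.5]); to match them exactly, one could instead
# normalize with image_processor.image_mean / image_processor.image_std (loaded below).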

# Apply transformations to the dataset
def preprocess_function(examples):
    examples['pixel_values'] = [transform(image.convert("RGB")) for image in examples['img']]
    return examples

encoded_dataset = dataset.with_transform(preprocess_function)
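
# with_transform applies preprocess_function lazily: the transform runs on batches of
# examples as they are accessed (e.g. by the DataLoader below); nothing is precomputed.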

# Create DataLoader for the test set
def collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    return {'pixel_values': pixel_values, 'label': labels}

test_loader = DataLoader(encoded_dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)
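
# Each batch yielded by test_loader is a dict with 'pixel_values' of shape
# (batch_size, 3, 224, 224) and 'label' of shape (batch_size,), which is what the
# evaluation loop below consumes.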

# Load the pre-trained ViT model
model_name = 'google/vit-base-patch16-224'
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
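
# Note: passing num_labels=10 with ignore_mismatched_sizes=True discards the checkpoint's
# original 1000-class ImageNet head and attaches a freshly initialized 10-class head, so
# the accuracy reported below reflects an untrained classifier unless the model is
# fine-tuned on CIFAR-10 (or a CIFAR-10 fine-tuned checkpoint is used) first.
# image_processor is loaded here but is not used by the manual torchvision transforms above.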

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Evaluate the model
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(test_loader):
        images = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.logits, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%')
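
# A minimal alternative preprocessing sketch (illustrative, not used above): let the
# already-loaded image_processor handle resizing and normalization instead of the manual
# torchvision transforms. It assumes the same 'img' column and the batched access that
# with_transform provides; the helper name is just a placeholder.
def preprocess_with_image_processor(examples):
    encoded = image_processor([image.convert("RGB") for image in examples['img']],
                              return_tensors='pt')
    examples['pixel_values'] = encoded['pixel_values']
    return examples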