Модель распознавания изображений из набора данных CIFAR10
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

150 lines
5.1 KiB

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
#import os
import numpy as np
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Load the image processor and the pretrained ViT-B/16 checkpoint.
# ('WinKawaks/vit-small-patch16-224' is a smaller alternative checkpoint.)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Request a 10-class head (CIFAR-10) at load time so that the model config and
# model.num_labels stay consistent with the classifier layer. The checkpoint's
# original head has a different shape, so its weights are skipped and the new
# head is randomly initialised.
# NOTE: until the model is fine-tuned, predictions from this new head are
# expected to be close to chance level.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=10,
    ignore_mismatched_sizes=True,
)
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Half precision is only enabled on CUDA: on CPU fp16 is slow and, depending on
# the PyTorch build, not implemented for every operation the model needs.
if device.type == 'cuda':
    model.half()
model.to(device)
# Input preprocessing. The normalisation statistics are taken from the
# checkpoint's own image processor so that they match what the pretrained
# weights expect, instead of being hard-coded.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # the checkpoint expects 224x224 inputs
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
# CIFAR-10 train/test splits (downloaded to ./data on first use) and loaders.
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Alternative loaders with a smaller batch size (disabled):
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Fine-tuning loop, currently DISABLED: everything between the triple quotes
# below is a bare string literal, so none of it is executed. The note inside
# the string says that a run takes 2 days.
# The string holds two variants of the loop: a plain one (AdamW, lr=1e-4,
# cross-entropy loss, 10 epochs) that prints an "Epoch i/N, Loss" line, and a
# '#'-commented one that also accumulates the loss, prints progress every
# 10 batches and prints the average loss per epoch.
# NOTE(review): the loop bodies carry no indentation here; it has to be
# restored before this code can be re-enabled.
# NOTE(review): the model is cast to half precision earlier in this script;
# confirm that pure fp16 weights train stably with AdamW before re-enabling.
"""
# Дообучение модели
# Выполняется 2 дня
# Оптимизатор и функция потерь
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()
model.train()
num_epochs = 10
for epoch in range(num_epochs):
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images).logits
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
# for epoch in range(num_epochs):
# epoch_loss = 0.0
# for batch_idx, (images, labels) in enumerate(train_loader):
# images = images.to(device)
# labels = labels.to(device)
#
# optimizer.zero_grad()
# outputs = model(images).logits
# loss = criterion(outputs, labels)
# loss.backward()
# optimizer.step()
#
# epoch_loss += loss.item()
#
# # Печать прогресса каждые 10 батчей
# if (batch_idx + 1) % 10 == 0:
# print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item()}')
#
# # Печать средней потери за эпоху
# avg_epoch_loss = epoch_loss / len(train_loader)
# print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}')
"""
# Model evaluation helper.
def evaluate_model(model, data_loader, device, class_names=None):
    """Compute accuracy and a per-class report for ``model`` on ``data_loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device each image batch is moved to before the forward pass.
        class_names: Optional display names for the report, ordered by class
            index. When omitted they are taken from
            ``data_loader.dataset.classes`` if that attribute exists, and
            otherwise from the module-level ``train_dataset`` (the previous
            behaviour of this function).

    Returns:
        A ``(accuracy, report)`` tuple: the value returned by
        ``accuracy_score`` and the text produced by ``classification_report``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in data_loader:
            logits = model(images.to(device)).logits
            all_preds.extend(torch.argmax(logits, dim=1).tolist())
            # Labels are only compared on the host, so they never need to
            # travel to the device.
            all_labels.extend(labels.tolist())

    if class_names is None:
        dataset = getattr(data_loader, 'dataset', None)
        class_names = getattr(dataset, 'classes', None)
    if class_names is None:
        # Legacy fallback: keeps existing callers working unchanged.
        class_names = train_dataset.classes
    class_names = list(class_names)

    accuracy = accuracy_score(all_labels, all_preds)
    # Passing ``labels`` explicitly keeps the report aligned with
    # ``target_names`` even when some class is absent from the evaluated data;
    # without it classification_report raises on the size mismatch.
    report = classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
    )
    return accuracy, report
# Score the model on the test split and print the metrics.
accuracy, report = evaluate_model(model=model, data_loader=test_loader, device=device)
print(f'Test Accuracy: {accuracy:.4f}')
print('Classification Report:', report, sep='\n')
# Collect test images together with their true and predicted labels until at
# least 25 are available. Whole batches are appended, so the lists may hold
# more than 25 entries; only the first 25 are plotted.
test_images = []
true_labels = []
predicted_labels = []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device)).logits
        preds = torch.argmax(logits, dim=1)
        test_images.extend(images.cpu())
        true_labels.extend(labels.cpu())
        predicted_labels.extend(preds.cpu())
        if len(test_images) >= 25:
            break

# The dataset transform normalises the images, which pushes pixel values
# outside [0, 1]; imshow would clip them and show distorted colours. Look up
# the Normalize step that was actually applied so it can be inverted.
_normalize = next(
    (t for t in transform.transforms if isinstance(t, transforms.Normalize)),
    None,
)


def _to_displayable(image):
    """Turn a (possibly normalised) CHW tensor into an HWC array in [0, 1]."""
    image = image.detach().cpu().float()
    if _normalize is not None:
        mean = torch.as_tensor(_normalize.mean, dtype=image.dtype).view(-1, 1, 1)
        std = torch.as_tensor(_normalize.std, dtype=image.dtype).view(-1, 1, 1)
        image = image * std + mean  # inverse of (x - mean) / std
    return image.clamp(0.0, 1.0).permute(1, 2, 0).numpy()


# 5x5 grid of images titled with the true and the predicted class.
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
axes = axes.ravel()
# zip stops at the shorter input: at most 25 images are drawn, and having fewer
# than 25 collected images no longer raises an IndexError.
for ax, image, true_label, pred_label in zip(axes, test_images, true_labels, predicted_labels):
    ax.imshow(_to_displayable(image))
    ax.set_title(f'True: {train_dataset.classes[int(true_label)]}\nPred: {train_dataset.classes[int(pred_label)]}')
for ax in axes:
    ax.axis('off')
plt.subplots_adjust(hspace=0.4)
plt.show()