You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
150 lines
5.1 KiB
150 lines
5.1 KiB
import torch |
|
import torchvision.transforms as transforms |
|
from torchvision.datasets import CIFAR10 |
|
from torch.utils.data import DataLoader |
|
from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor |
|
import matplotlib.pyplot as plt |
|
import torch.nn.functional as F |
|
from sklearn.metrics import accuracy_score, classification_report |
|
from torch.optim import AdamW |
|
from torch.nn import CrossEntropyLoss |
|
#import os |
|
import numpy as np |
|
|
|
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' |
|
|
|
# Load the pretrained ViT-B/16 checkpoint and its image processor.
# (A smaller alternative checkpoint that was tried: 'WinKawaks/vit-small-patch16-224'.)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Replace the pretrained classification head with a 10-way head for CIFAR-10.
# NOTE(review): this head is freshly (randomly) initialised. Unless the model is
# fine-tuned before evaluation, expect test accuracy near chance (~10%).
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# Run on the GPU when available. Half precision is used only on CUDA: float16 on
# the CPU is slow, and some ops may lack Half kernels depending on the PyTorch version.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32
model.to(device=device, dtype=model_dtype)
|
|
|
# Preprocessing: resize to the ViT input resolution and normalise with the
# statistics the checkpoint's own image processor is configured with. This
# replaces hard-coded ImageNet mean/std, which need not match what the
# pretrained weights were trained on.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT-16 expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

# CIFAR-10 train/test splits, downloaded into ./data if not already present.
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# A batch_size=32 variant was also used; smaller batches need less GPU memory.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
|
|
|
""" |
|
# Дообучение модели |
|
# Выполняется 2 дня |
|
# Оптимизатор и функция потерь |
|
optimizer = AdamW(model.parameters(), lr=1e-4) |
|
criterion = CrossEntropyLoss() |
|
|
|
model.train() |
|
num_epochs = 10 |
|
|
|
for epoch in range(num_epochs): |
|
for images, labels in train_loader: |
|
images = images.to(device) |
|
labels = labels.to(device) |
|
|
|
optimizer.zero_grad() |
|
outputs = model(images).logits |
|
loss = criterion(outputs, labels) |
|
loss.backward() |
|
optimizer.step() |
|
|
|
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}') |
|
|
|
# for epoch in range(num_epochs): |
|
# epoch_loss = 0.0 |
|
# for batch_idx, (images, labels) in enumerate(train_loader): |
|
# images = images.to(device) |
|
# labels = labels.to(device) |
|
# |
|
# optimizer.zero_grad() |
|
# outputs = model(images).logits |
|
# loss = criterion(outputs, labels) |
|
# loss.backward() |
|
# optimizer.step() |
|
# |
|
# epoch_loss += loss.item() |
|
# |
|
# # Печать прогресса каждые 10 батчей |
|
# if (batch_idx + 1) % 10 == 0: |
|
# print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item()}') |
|
# |
|
# # Печать средней потери за эпоху |
|
# avg_epoch_loss = epoch_loss / len(train_loader) |
|
# print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}') |
|
""" |
|
|
|
def evaluate_model(model, data_loader, device, class_names=None):
    """Run ``model`` over ``data_loader`` and score its predictions.

    Args:
        model: classifier whose forward returns an object with a ``.logits``
            attribute (as ``ViTForImageClassification`` does).
        data_loader: iterable of ``(images, labels)`` batches, with labels given
            as integer class-index tensors.
        device: device the images are moved to for the forward pass.
        class_names: optional display names indexed by class id. Defaults to
            ``data_loader.dataset.classes`` when that attribute exists.
            Otherwise the report uses numeric labels.

    Returns:
        ``(accuracy, report)``: the accuracy from ``accuracy_score`` and the
        text produced by ``classification_report``.
    """
    if class_names is None:
        class_names = getattr(getattr(data_loader, 'dataset', None), 'classes', None)

    model.eval()
    # Match the model's dtype so a half-precision model accepts float32 batches.
    param_dtype = next(model.parameters()).dtype
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device=device, dtype=param_dtype)
            logits = model(images).logits
            all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            # Labels are only compared on the CPU, so they never need to visit the device.
            all_labels.extend(labels.tolist())

    accuracy = accuracy_score(all_labels, all_preds)
    # Pin the label set to the full class list. Otherwise sklearn raises when some
    # class never occurs in the evaluated data, because target_names would be longer
    # than the observed labels.
    label_ids = list(range(len(class_names))) if class_names is not None else None
    report = classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names)

    return accuracy, report
|
|
|
# Score the model on the held-out CIFAR-10 test split and print the metrics.
accuracy, report = evaluate_model(model, test_loader, device)
print(f'Test Accuracy: {accuracy:.4f}')
print('Classification Report:', report, sep='\n')
|
|
|
# Visual check: show the first grid_rows x grid_cols test images with their true
# and predicted class names.
grid_rows, grid_cols = 5, 5
n_show = grid_rows * grid_cols

test_images = []
true_labels = []
predicted_labels = []

model.eval()
grid_dtype = next(model.parameters()).dtype
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device=device, dtype=grid_dtype)).logits
        preds = torch.argmax(logits, dim=1)

        # Keep the loader's CPU tensors for plotting; labels become plain ints.
        test_images.extend(images)
        true_labels.extend(labels.tolist())
        predicted_labels.extend(preds.cpu().tolist())

        if len(test_images) >= n_show:
            break

# The images are still normalised by `transform`. Undo Normalize so imshow gets
# RGB values in [0, 1] instead of clipping them.
normalize = next((t for t in transform.transforms if isinstance(t, transforms.Normalize)), None)
if normalize is not None:
    img_mean = torch.tensor(normalize.mean).view(-1, 1, 1)
    img_std = torch.tensor(normalize.std).view(-1, 1, 1)
else:
    img_mean, img_std = torch.zeros(1, 1, 1), torch.ones(1, 1, 1)

class_names = test_dataset.classes
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 15))
axes = axes.ravel()

# zip stops at the shorter input, so a test set smaller than the grid leaves the
# remaining axes empty instead of raising IndexError.
for ax, image, true_idx, pred_idx in zip(axes, test_images, true_labels, predicted_labels):
    rgb = (image * img_std + img_mean).clamp(0, 1).permute(1, 2, 0).numpy()
    ax.imshow(rgb)
    ax.set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}')
for ax in axes:
    ax.axis('off')

plt.subplots_adjust(hspace=0.4)
plt.show()
|
|
|