"""Evaluate a pretrained ViT-Base/16 classifier on CIFAR-10 and preview predictions.

The script loads the ``google/vit-base-patch16-224`` checkpoint, swaps its
classification head for a 10-way linear layer, optionally fine-tunes it,
reports accuracy and a per-class report on the CIFAR-10 test split, and shows
a 5x5 grid of test images with their true and predicted labels.
"""
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
#import os
import numpy as np
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

MODEL_NAME = 'google/vit-base-patch16-224'
NUM_CLASSES = 10  # CIFAR-10
BATCH_SIZE = 64
PREVIEW_ROWS = 5
PREVIEW_COLS = 5

# Fine-tuning is disabled by default: the original author noted that it takes
# about two days to run. While it is disabled, the freshly created
# classification head below is randomly initialised, so the reported accuracy
# is expected to be close to chance level.
RUN_FINE_TUNING = False
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4


def fine_tune(model, data_loader, device, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE):
    """Fine-tune ``model`` on ``data_loader`` with AdamW and cross-entropy loss.

    Args:
        model: Module whose forward pass returns an object with a ``.logits``
            attribute. It must already live on ``device``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``data_loader``.
        lr: Learning rate passed to AdamW.

    The average loss of every epoch is printed.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = CrossEntropyLoss()

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Report the mean over the epoch rather than the last batch only.
        avg_epoch_loss = epoch_loss / max(len(data_loader), 1)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}')


def evaluate_model(model, data_loader, device, class_names=None):
    """Compute accuracy and a classification report for ``model``.

    Args:
        model: Module whose forward pass returns an object with a ``.logits``
            attribute.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to.
        class_names: Optional sequence of class names, indexed by label id.
            When omitted, ``data_loader.dataset.classes`` is used if present;
            otherwise the report falls back to numeric labels.

    Returns:
        A ``(accuracy, report)`` tuple: the accuracy as returned by
        ``accuracy_score`` and the text produced by ``classification_report``.
    """
    if class_names is None:
        dataset = getattr(data_loader, 'dataset', None)
        class_names = getattr(dataset, 'classes', None)

    # Inputs must match the parameter dtype (float16 when the model was
    # converted with .half(), float32 otherwise).
    model_dtype = next(model.parameters()).dtype

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device=device, dtype=model_dtype)

            outputs = model(images).logits
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())

    accuracy = accuracy_score(all_labels, all_preds)

    # zero_division=0 yields the same numbers as the default but without a
    # warning when a class is never predicted.
    report_kwargs = {'zero_division': 0}
    if class_names is not None:
        # Explicit label ids keep the report aligned with the names even if
        # some class does not occur in the evaluated data.
        report_kwargs['labels'] = list(range(len(class_names)))
        report_kwargs['target_names'] = list(class_names)
    report = classification_report(all_labels, all_preds, **report_kwargs)

    return accuracy, report


# Load the pretrained ViT-16 model
# model = ViTForImageClassification.from_pretrained('WinKawaks/vit-small-patch16-224') #-in21k')
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

# Replace the head so that the model predicts 10 classes (CIFAR-10)
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_CLASSES)

# Move the model to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Data transforms. The normalisation statistics are taken from the checkpoint's
# image processor so that the inputs match what the model was trained with.
image_mean = list(processor.image_mean)
image_std = list(processor.image_std)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT-16 expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std),
])

# Load the data
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Optional fine-tuning; done in float32, before any half-precision conversion
if RUN_FINE_TUNING:
    fine_tune(model, train_loader, device)

# Half precision is only used for inference on the GPU; on the CPU the model
# stays in float32.
if device.type == 'cuda':
    model.half()

# Evaluate the model on the test data
accuracy, report = evaluate_model(
    model, test_loader, device, class_names=train_dataset.classes
)
print(f'Test Accuracy: {accuracy:.4f}')
print('Classification Report:')
print(report)

# Collect at least 25 test images together with their true and predicted labels
num_preview = PREVIEW_ROWS * PREVIEW_COLS
test_images = []
true_labels = []
predicted_labels = []

model_dtype = next(model.parameters()).dtype
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device=device, dtype=model_dtype)).logits
        preds = torch.argmax(outputs, dim=1)

        # Keep the float32 CPU tensors for plotting.
        test_images.extend(images)
        true_labels.extend(labels.tolist())
        predicted_labels.extend(preds.cpu().tolist())

        if len(test_images) >= num_preview:
            break

# Per-channel statistics shaped for broadcasting over (C, H, W) images
mean = torch.tensor(image_mean).view(-1, 1, 1)
std = torch.tensor(image_std).view(-1, 1, 1)
class_names = train_dataset.classes

fig, axes = plt.subplots(PREVIEW_ROWS, PREVIEW_COLS, figsize=(15, 15))
axes = axes.ravel()
for ax in axes:
    ax.axis('off')

# zip() stops at the shorter input, so fewer than 25 images are handled too
for ax, image, true_idx, pred_idx in zip(axes, test_images, true_labels, predicted_labels):
    # Undo the normalisation so that imshow receives values in [0, 1]
    restored = (image * std + mean).clamp(0.0, 1.0)
    ax.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
    ax.set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}')

plt.subplots_adjust(hspace=0.4)
plt.show()