# You cannot select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 113 lines
# 4.1 KiB
# 113 lines
# 4.1 KiB
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, classification_report
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor

# import os
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Load the pretrained ViT-B/16 checkpoint together with its matching image processor.
# Smaller alternative backbone, kept for reference:
#   ViTForImageClassification.from_pretrained('WinKawaks/vit-small-patch16-224')  # or the -in21k variant
CHECKPOINT = 'google/vit-base-patch16-224'
NUM_CLASSES = 10  # CIFAR-10

processor = ViTImageProcessor.from_pretrained(CHECKPOINT)

# Replace the 1000-way ImageNet head with a freshly initialised 10-way head.
# Passing num_labels through from_pretrained (instead of assigning a new
# model.classifier by hand) also keeps model.config.num_labels / id2label in
# sync with the new head; ignore_mismatched_sizes makes transformers discard the
# pretrained 1000-class classifier weights instead of failing on the shape mismatch.
model = ViTForImageClassification.from_pretrained(
    CHECKPOINT,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the weights in float32. Calling model.half() here made the training loop
# run AdamW entirely in fp16: AdamW's default eps (1e-8) is below the smallest
# fp16 subnormal (~6e-8) and rounds to 0, and small squared gradients underflow,
# so the update divides by zero and the weights become inf/NaN. If GPU memory is
# tight, use mixed precision (torch.autocast + GradScaler) in the training loop
# instead of converting the parameters themselves.
model.to(device)

# Preprocessing applied to every CIFAR-10 image (both train and test splits).
# Normalise with the statistics the checkpoint was pretrained with, read from its
# image processor, rather than hard-coding ImageNet mean/std (0.485.../0.229...).
# ViTImageProcessor for google/vit-base-patch16-224 uses 0.5 per channel for both,
# so the hard-coded values shifted inputs away from what the backbone expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT-B/16 expects 224x224 inputs; 32x32 CIFAR images are upsampled
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

# CIFAR-10 train/test splits, downloaded into ./data on first use; both
# splits share the same `transform` pipeline.
train_dataset, test_dataset = (
    CIFAR10(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Batch size 32 (an earlier revision used 64, presumably lowered to fit GPU
# memory). Only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=should_shuffle)
    for split, should_shuffle in ((train_dataset, True), (test_dataset, False))
)

# Loss and optimiser: cross-entropy on the raw logits, and AdamW over every
# model parameter (the whole network is fine-tuned, not only the new head).
criterion = CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=1e-4)

# Fine-tune the model on the CIFAR-10 training split.
num_epochs = 10
log_every = 100  # print a progress line every `log_every` batches

for epoch in range(num_epochs):
    # Make sure the model is in training mode at the start of every epoch.
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    seen = 0            # number of samples processed so far this epoch

    for batch_idx, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion uses CrossEntropyLoss's default 'mean' reduction, so re-weight
        # by batch size to keep the epoch average exact even when the final batch
        # is smaller than the rest.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size

        if batch_idx % log_every == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

    # Report the mean over the whole epoch rather than the last mini-batch's loss.
    avg_epoch_loss = running_loss / max(seen, 1)
    print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')

# Model evaluation helper
def evaluate_model(model, data_loader, device, target_names=None):
    """Evaluate a classifier over every batch of ``data_loader``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking; each prediction is the argmax over the output logits.

    Args:
        model: model whose forward output exposes ``.logits``.
        data_loader: iterable yielding ``(images, labels)`` batches.
        device: device the images are moved to before the forward pass.
        target_names: optional display names for the classes, indexed by label
            id. When omitted, ``data_loader.dataset.classes`` is used if the
            loader's dataset provides it (torchvision's CIFAR10 does);
            otherwise the report shows numeric label ids.

    Returns:
        ``(accuracy, report)``: the overall accuracy as a float and the text
        produced by ``sklearn.metrics.classification_report``.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            outputs = model(images.to(device)).logits
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            # Labels are only needed for scoring, so they are not sent to the
            # device; .cpu() is a no-op for tensors already on the CPU.
            all_labels.extend(labels.cpu().numpy())

    if target_names is None:
        # Derive class names from the loader itself instead of relying on a
        # module-level dataset variable.
        target_names = getattr(getattr(data_loader, 'dataset', None), 'classes', None)

    accuracy = accuracy_score(all_labels, all_preds)
    # Pin the label set to the full range of target_names; without it sklearn
    # raises a size-mismatch ValueError whenever some class never appears in
    # either the labels or the predictions.
    label_ids = list(range(len(target_names))) if target_names is not None else None
    report = classification_report(all_labels, all_preds, labels=label_ids, target_names=target_names)

    return accuracy, report


# Score the fine-tuned model on the held-out test split and print the results.
accuracy, report = evaluate_model(model, test_loader, device)

print(*(f'Test Accuracy: {accuracy:.4f}', 'Classification Report:', report), sep='\n')