From 7c63dde1b90ab84c57bf87fa03f9c660b3cafed5 Mon Sep 17 00:00:00 2001 From: bogdan Date: Sun, 19 Jan 2025 20:18:09 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A0=D0=B0=D0=B7=D0=BD=D1=8B=D0=B5=20=D0=BF?= =?UTF-8?q?=D0=BE=D0=BF=D1=8B=D1=82=D0=BA=D0=B8=20=D1=80=D0=B0=D0=B7=D0=BE?= =?UTF-8?q?=D0=B1=D1=80=D0=B0=D1=82=D1=8C=D1=81=D1=8F=20=D1=81=20ViT16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 7 +++++- ViT16.py | 47 ++++++++++++++++++++++++++++++++++----- ViT16_second.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ ViT16_single.py | 43 +++++++++++++++++++++++++++++++++++ cifar10.py | 30 +++++++++++++++++++------ cifral10_2.py | 1 + nateraw.py | 26 ++++++++++++++++++++++ readme.ipynb | 29 +++++++++++++++++++----- 8 files changed, 223 insertions(+), 19 deletions(-) create mode 100644 ViT16_second.py create mode 100644 ViT16_single.py create mode 100644 nateraw.py diff --git a/README.md b/README.md index 0df39f2..254b203 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,9 @@ -### Dataset source +## Dataset source - https://www.cs.toronto.edu/%7Ekriz/cifar.html - https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a) + +## Результат + +Модель на 3-х сверточных слоях на тестовых данных: 71% (cifar10.py) +Модель google/vit-base-patch16-224: 10% (вероятнее всего, не разобрался, как корректно ее перенастроить на классификацию по 10 категориям) (ViT16.py) \ No newline at end of file diff --git a/ViT16.py b/ViT16.py index fbd6cbd..078b4b5 100644 --- a/ViT16.py +++ b/ViT16.py @@ -3,11 +3,13 @@ import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 from torch.utils.data import DataLoader from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor +import matplotlib.pyplot as plt import torch.nn.functional as F from sklearn.metrics import accuracy_score, classification_report from torch.optim import AdamW from torch.nn 
import CrossEntropyLoss #import os +import numpy as np #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' @@ -35,16 +37,18 @@ transform = transforms.Compose([ train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform) -# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) -# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) -train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) -test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) +train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) +test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) +# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) +# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) +""" +# Дообучение модели +# Выполняется 2 дня # Оптимизатор и функция потерь optimizer = AdamW(model.parameters(), lr=1e-4) criterion = CrossEntropyLoss() -# Дообучение модели model.train() num_epochs = 10 @@ -82,6 +86,7 @@ for epoch in range(num_epochs): # # Печать средней потери за эпоху # avg_epoch_loss = epoch_loss / len(train_loader) # print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}') +""" # Функция для оценки модели def evaluate_model(model, data_loader, device): @@ -111,3 +116,35 @@ accuracy, report = evaluate_model(model, test_loader, device) print(f'Test Accuracy: {accuracy:.4f}') print('Classification Report:') print(report) + +# Загрузка 25 тестовых изображений +test_images = [] +true_labels = [] +predicted_labels = [] + +model.eval() +with torch.no_grad(): + for images, labels in test_loader: + images = images.to(device) + labels = labels.to(device) + + outputs = model(images).logits + preds = torch.argmax(outputs, dim=1) + + test_images.extend(images.cpu()) + true_labels.extend(labels.cpu()) + 
predicted_labels.extend(preds.cpu()) + + if len(test_images) >= 25: + break + +fig, axes = plt.subplots(5, 5, figsize=(15, 15)) +axes = axes.ravel() + +for i in range(5 * 5): + axes[i].imshow(np.transpose(test_images[:25][i].cpu().numpy(), (1, 2, 0))) + axes[i].set_title(f'True: {train_dataset.classes[true_labels[i]]}\nPred: {train_dataset.classes[predicted_labels[i]]}') + axes[i].axis('off') + +plt.subplots_adjust(hspace=0.4) +plt.show() diff --git a/ViT16_second.py b/ViT16_second.py new file mode 100644 index 0000000..da0cc7f --- /dev/null +++ b/ViT16_second.py @@ -0,0 +1,59 @@ +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from transformers import ViTForImageClassification, ViTImageProcessor +from datasets import load_dataset +from tqdm import tqdm + +# Load CIFAR-10 dataset +dataset = load_dataset('cifar10') + +# Define transformations for the dataset +transform = transforms.Compose([ + transforms.Resize((224, 224)), # Resize images to 224x224 + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) +]) + +# Apply transformations to the dataset +def preprocess_function(examples): + examples['pixel_values'] = [transform(image.convert("RGB")) for image in examples['img']] + return examples + +encoded_dataset = dataset.with_transform(preprocess_function) + +# Create DataLoader for the test set +def collate_fn(batch): + pixel_values = torch.stack([item['pixel_values'] for item in batch]) + labels = torch.tensor([item['label'] for item in batch]) + return {'pixel_values': pixel_values, 'label': labels} + +test_loader = DataLoader(encoded_dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn) + +# Load the pre-trained ViT model +model_name = 'google/vit-base-patch16-224' +image_processor = ViTImageProcessor.from_pretrained(model_name) +model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True) + +# Move model 
to GPU if available +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model.to(device) + +# Evaluate the model +model.eval() +correct = 0 +total = 0 + +with torch.no_grad(): + for batch in tqdm(test_loader): + images = batch['pixel_values'].to(device) + labels = batch['label'].to(device) + + outputs = model(images) + _, predicted = torch.max(outputs.logits, 1) + + total += labels.size(0) + correct += (predicted == labels).sum().item() + +accuracy = 100 * correct / total +print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%') diff --git a/ViT16_single.py b/ViT16_single.py new file mode 100644 index 0000000..9a1f2bf --- /dev/null +++ b/ViT16_single.py @@ -0,0 +1,43 @@ +import torchvision.transforms as transforms +from transformers import ViTImageProcessor, ViTForImageClassification +from PIL import Image +import requests +import torch +from torchvision.datasets import CIFAR10 +from torch.utils.data import DataLoader + +classes = [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' +] + +# url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +# image = Image.open(requests.get(url, stream=True).raw) + +url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' +url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png' +image = Image.open(requests.get(url, stream=True).raw) + +test_dataset = CIFAR10(root='./data', train=False, download=True) +test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) + +processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + +# Изменение количества классов на 10 (для CIFAR-10) +model.classifier = torch.nn.Linear(model.classifier.in_features, 10) + +# Перенос модели на GPU, если доступно +# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +# model.to(device) +# model.half().to(device) # Load 
in half precision + +inputs = processor(images=image, return_tensors="pt") +outputs = model(**inputs) +logits = outputs.logits + +class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] +preds = torch.argmax(logits, dim=1) +print(preds.cpu().numpy()) +# model predicts one of the 1000 ImageNet classes +predicted_class_idx = logits.argmax(-1).item() +print("Predicted class:", model.config.id2label[predicted_class_idx]) diff --git a/cifar10.py b/cifar10.py index 55266b6..6b85d27 100644 --- a/cifar10.py +++ b/cifar10.py @@ -42,16 +42,32 @@ model.compile(optimizer='adam', metrics=['accuracy']) # Train the model -history = model.fit(train_images, train_labels, epochs=10, +history = model.fit(train_images, train_labels, epochs=20, validation_data=(test_images, test_labels)) -# Evaluate the model -plt.plot(history.history['accuracy'], label='accuracy') -plt.plot(history.history['val_accuracy'], label = 'val_accuracy') +# Оценка модели +test_loss, test_acc = model.evaluate(train_images, train_labels, verbose=2) +print(f"Test accuracy: {test_acc:.4f}") + +# Построение графиков +plt.figure(figsize=(12, 4)) + +# График точности +plt.subplot(1, 2, 1) +plt.plot(history.history['accuracy'], label='train accuracy') +plt.plot(history.history['val_accuracy'], label='val accuracy') plt.xlabel('Epoch') plt.ylabel('Accuracy') -plt.ylim([0, 1]) plt.legend(loc='lower right') +plt.title('Training and Validation Accuracy') -test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2) -print(test_acc) +# График потерь +plt.subplot(1, 2, 2) +plt.plot(history.history['loss'], label='train loss') +plt.plot(history.history['val_loss'], label='val loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.legend(loc='upper right') +plt.title('Training and Validation Loss') + +plt.show() diff --git a/cifral10_2.py b/cifral10_2.py index b9c3f54..3ed5f24 100644 --- a/cifral10_2.py +++ b/cifral10_2.py @@ -27,6 +27,7 @@ model = 
models.Sequential([ layers.Conv2D(64, (3, 3), activation='relu'), layers.MaxPooling2D((2, 2)), layers.Conv2D(64, (3, 3), activation='relu'), + layers.Flatten(), layers.Dense(64, activation='relu'), layers.Dense(10) diff --git a/nateraw.py b/nateraw.py new file mode 100644 index 0000000..3572d7f --- /dev/null +++ b/nateraw.py @@ -0,0 +1,26 @@ +from transformers import ViTFeatureExtractor, ViTForImageClassification +from PIL import Image +import requests +from tensorflow.keras import datasets, layers, models +import matplotlib.pyplot as plt + +# Load the CIFAR-10 dataset +(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() + +# Normalize pixel values to be between 0 and 1 +train_images, test_images = train_images / 255.0, test_images / 255.0 + +url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' +image = Image.open(requests.get(url, stream=True).raw) +feature_extractor = ViTFeatureExtractor.from_pretrained('nateraw/vit-base-patch16-224-cifar10') +model = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10') +inputs = feature_extractor(images=test_images[10], return_tensors="pt") +outputs = model(**inputs) +preds = outputs.logits.argmax(dim=1) + +classes = [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' +] +print(classes[test_labels[10][0]]) +print(classes[preds[0]]) + diff --git a/readme.ipynb b/readme.ipynb index fdbf189..055b251 100644 --- a/readme.ipynb +++ b/readme.ipynb @@ -8,11 +8,11 @@ "\n", "## Введение\n", "\n", - "Задача: Написать нейронную сеть для распознования набора данных Cifral10\n", + "Задача: Написать нейронную сеть для распознования набора данных Cifar10\n", "\n", "## Загрузка и предобработка данных\n", "\n", - "В качестве набора данных для обучения используется набор данных (Cifral10)[https://www.cs.toronto.edu/%7Ekriz/cifar.html]" + "В качестве набора данных для обучения используется набор данных 
[Cifar10](https://www.cs.toronto.edu/%7Ekriz/cifar.html)" ] }, { @@ -171,6 +171,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "## Результаты\n", + "\n", + "При такой модели на тестовых данных получено в районе 70% (максимум 71% на 11-й эпохе). При этом тренировочные данные показывали 90% на 20-й эпохе" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## google/vit-base-patch16-224\n", + "\n", + "При попытке использовать готовую модель google/vit-base-patch16-224 точность модели составила 10%. При попытке дообучить модель точность практически не изменилась. Полагаю, не разобрался, как проверить данную модель на данных CIFAR10." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Модель: google/vit-base-patch16-224\n", "Test Accuracy: 0.1000\n", "Classification Report:\n", " precision recall f1-score support\n", @@ -192,6 +211,7 @@ "\n", "---\n", "\n", + "Дообученная модель google/vit-base-patch16-224 на данных Cifar10\n", "Test Accuracy: 0.1018\n", "Classification Report:\n", " precision recall f1-score support\n", @@ -209,10 +229,7 @@ "\n", " accuracy 0.10 10000\n", " macro avg 0.11 0.10 0.09 10000\n", - "weighted avg 0.11 0.10 0.09 10000\n", - "\n", - "\n", - "Process finished with exit code 0" + "weighted avg 0.11 0.10 0.09 10000" ] } ],