Browse Source

Разные попытки разобраться с ViT16

master
Bogdan Zuy 11 months ago
parent
commit
7c63dde1b9
  1. 7
      README.md
  2. 47
      ViT16.py
  3. 59
      ViT16_second.py
  4. 43
      ViT16_single.py
  5. 30
      cifar10.py
  6. 1
      cifral10_2.py
  7. 26
      nateraw.py
  8. 29
      readme.ipynb

7
README.md

@ -1,4 +1,9 @@
### Dataset source ## Dataset source
- https://www.cs.toronto.edu/%7Ekriz/cifar.html - https://www.cs.toronto.edu/%7Ekriz/cifar.html
- https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a) - https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a)
## Результат
Модель на 3х сверточных слоях на тестовых данных: 71% (cifar10.py)
Модель google/vit-base-patch16-224: 10% (вероятнее всего не разобрался как корректно ее перенастроить на классификацию по 10 категориям) (ViT16.py)

47
ViT16.py

@ -3,11 +3,13 @@ import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt
import torch.nn.functional as F import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW from torch.optim import AdamW
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
#import os #import os
import numpy as np
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
@ -35,16 +37,18 @@ transform = transforms.Compose([
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform) train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform) test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
"""
# Дообучение модели
# Выполняется 2 дня
# Оптимизатор и функция потерь # Оптимизатор и функция потерь
optimizer = AdamW(model.parameters(), lr=1e-4) optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss() criterion = CrossEntropyLoss()
# Дообучение модели
model.train() model.train()
num_epochs = 10 num_epochs = 10
@ -82,6 +86,7 @@ for epoch in range(num_epochs):
# # Печать средней потери за эпоху # # Печать средней потери за эпоху
# avg_epoch_loss = epoch_loss / len(train_loader) # avg_epoch_loss = epoch_loss / len(train_loader)
# print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}') # print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}')
"""
# Функция для оценки модели # Функция для оценки модели
def evaluate_model(model, data_loader, device): def evaluate_model(model, data_loader, device):
@ -111,3 +116,35 @@ accuracy, report = evaluate_model(model, test_loader, device)
print(f'Test Accuracy: {accuracy:.4f}') print(f'Test Accuracy: {accuracy:.4f}')
print('Classification Report:') print('Classification Report:')
print(report) print(report)
# Collect the first 25 test images together with their true and predicted labels.
test_images = []
true_labels = []
predicted_labels = []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images).logits
        preds = torch.argmax(outputs, dim=1)
        # .cpu() here already moves everything back to host memory once.
        test_images.extend(images.cpu())
        true_labels.extend(labels.cpu())
        predicted_labels.extend(preds.cpu())
        # One batch may overshoot 25; the plotting loop below only uses [0..24].
        if len(test_images) >= 25:
            break

# Plot a 5x5 grid of images with "True" vs "Pred" class names in each title.
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
axes = axes.ravel()
for i in range(25):
    # BUG fix: the original indexed `test_images[:25][i]` (a fresh slice per
    # iteration) and called .cpu() on tensors that are already on the CPU.
    img = test_images[i].numpy().transpose(1, 2, 0)
    # NOTE(review): if the dataset transform normalizes the tensors, values
    # fall outside [0, 1]; clip so imshow renders without warnings (colors
    # are still shifted by the normalization) — confirm against the transform.
    axes[i].imshow(np.clip(img, 0, 1))
    axes[i].set_title(f'True: {train_dataset.classes[true_labels[i]]}\nPred: {train_dataset.classes[predicted_labels[i]]}')
    axes[i].axis('off')
plt.subplots_adjust(hspace=0.4)
plt.show()

59
ViT16_second.py

@ -0,0 +1,59 @@
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
from tqdm import tqdm
# Load the CIFAR-10 dataset (train + test splits) from the HuggingFace hub.
dataset = load_dataset('cifar10')
# Preprocessing pipeline for every image.
# ViT-Base/16 ('google/vit-base-patch16-224') expects 224x224 inputs, so the
# 32x32 CIFAR images are upscaled; mean/std are the commonly used CIFAR-10
# per-channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
def preprocess_function(examples):
    """Attach a 'pixel_values' entry (transformed tensor) for each image in the batch."""
    images = examples['img']
    examples['pixel_values'] = [transform(img.convert("RGB")) for img in images]
    return examples
# with_transform applies preprocessing lazily, on each access — no upfront pass.
encoded_dataset = dataset.with_transform(preprocess_function)
# Create DataLoader for the test set
def collate_fn(batch):
    """Collate a list of samples into a single batch dict for the model.

    Stacks each sample's 'pixel_values' tensor into one (B, C, H, W) tensor
    and gathers the integer labels into a 1-D tensor.
    """
    stacked = torch.stack([sample['pixel_values'] for sample in batch])
    targets = torch.tensor([sample['label'] for sample in batch])
    return {'pixel_values': stacked, 'label': targets}
# DataLoader over the (lazily transformed) test split.
test_loader = DataLoader(encoded_dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)

# Load a ViT checkpoint that was already fine-tuned on CIFAR-10.
# BUG fix: the previous setup loaded 'google/vit-base-patch16-224' with
# num_labels=10 + ignore_mismatched_sizes=True, which discards the pretrained
# 1000-class head and installs a *randomly initialized* 10-class head. Since
# this script never trains, evaluation came out at chance level (~10%).
# A CIFAR-10 fine-tuned checkpoint makes the evaluation meaningful.
# NOTE(review): the transform above uses CIFAR-10 mean/std — confirm they
# match this checkpoint's processor stats (image_processor.image_mean/std).
model_name = 'nateraw/vit-base-patch16-224-cifar10'
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Move model to GPU if available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Run the whole test split through the model and report top-1 accuracy.
model.eval()
num_correct = 0
num_seen = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        pixel_values = batch['pixel_values'].to(device)
        targets = batch['label'].to(device)
        logits = model(pixel_values).logits
        predictions = logits.argmax(dim=1)
        num_seen += targets.size(0)
        num_correct += (predictions == targets).sum().item()
accuracy = 100 * num_correct / num_seen
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%')

43
ViT16_single.py

@ -0,0 +1,43 @@
import torchvision.transforms as transforms
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# CIFAR-10 class names, index-aligned with the 10-way classifier head below.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
]

# Fetch a single CIFAR-10 sample image to classify.
# (Alternative samples: .../dog10.png, or the COCO cats image at
# http://images.cocodataset.org/val2017/000000039769.jpg)
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Replace the 1000-class ImageNet head with a 10-class head for CIFAR-10.
# NOTE(review): this new Linear layer is randomly initialized and never
# trained in this script, so the prediction below is effectively random.
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

inputs = processor(images=image, return_tensors="pt")
logits = model(**inputs).logits

preds = torch.argmax(logits, dim=1)
print(preds.cpu().numpy())

# BUG fix: model.config.id2label still holds the ImageNet-1k label mapping,
# which no longer matches the replaced 10-way head; look the index up in the
# CIFAR-10 class list instead. (Also removed: an unused CIFAR10 download and
# DataLoader, a dead first `url` assignment, and a duplicate class-name list.)
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", classes[predicted_class_idx])

30
cifar10.py

@ -42,16 +42,32 @@ model.compile(optimizer='adam',
metrics=['accuracy']) metrics=['accuracy'])
# Train the model on CIFAR-10, validating against the test split each epoch.
history = model.fit(train_images, train_labels, epochs=20,
                    validation_data=(test_images, test_labels))

# Evaluate the model on the held-out test set.
# BUG fix: the commit evaluated on (train_images, train_labels) while printing
# "Test accuracy" — that reports training accuracy, not generalization.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")

# Plot the learning curves.
plt.figure(figsize=(12, 4))

# Accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Loss curves.
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

1
cifral10_2.py

@ -27,6 +27,7 @@ model = models.Sequential([
layers.Conv2D(64, (3, 3), activation='relu'), layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)), layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'), layers.Conv2D(64, (3, 3), activation='relu'),
layers.Flatten(), layers.Flatten(),
layers.Dense(64, activation='relu'), layers.Dense(64, activation='relu'),
layers.Dense(10) layers.Dense(10)

26
nateraw.py

@ -0,0 +1,26 @@
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
# Load the CIFAR-10 dataset (Keras builtin; downloads on first use).
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1.
# NOTE(review): ViTFeatureExtractor applies its own 1/255 rescale by default,
# so feeding already-normalized [0,1] floats may rescale twice — verify, or
# pass the raw uint8 images instead.
train_images, test_images = train_images / 255.0, test_images / 255.0

# (Removed: an unused download of .../cifar-10-sample/dog10.png — the fetched
# image was never referenced; the classified sample is test_images[10] below.)

# ViT checkpoint fine-tuned on CIFAR-10, with its matching feature extractor.
feature_extractor = ViTFeatureExtractor.from_pretrained('nateraw/vit-base-patch16-224-cifar10')
model = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10')

# Classify one test image and compare the prediction with the ground truth.
inputs = feature_extractor(images=test_images[10], return_tensors="pt")
outputs = model(**inputs)
preds = outputs.logits.argmax(dim=1)

classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
]
print(classes[test_labels[10][0]])
print(classes[preds[0]])

29
readme.ipynb

@ -8,11 +8,11 @@
"\n", "\n",
"## Введение\n", "## Введение\n",
"\n", "\n",
"Задача: Написать нейронную сеть для распознования набора данных Cifral10\n", "Задача: Написать нейронную сеть для распознования набора данных Cifar10\n",
"\n", "\n",
"## Загрузка и предобработка данных\n", "## Загрузка и предобработка данных\n",
"\n", "\n",
"В качестве набора данных для обучения используется набор данных (Cifral10)[https://www.cs.toronto.edu/%7Ekriz/cifar.html]" "В качестве набора данных для обучения используется набор данных (Cifar10)[https://www.cs.toronto.edu/%7Ekriz/cifar.html]"
] ]
}, },
{ {
@ -171,6 +171,25 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Результаты\n",
"\n",
"При такой модели на тестовых данных получено в районе 70% (максимум 71% на 11й эпохе). При этом тренировочные данные показывали 90% на 20й эпохе"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## google/vit-base-patch16-224\n",
"\n",
"При попытке использовать готовую модель google/vit-base-patch16-224 точность модели составила 10%. При попытке дообучить модель точность практически не изменилась. Полагаю, не разобрался как проверить данную модель на данных CIFAR10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Модель: google/vit-base-patch16-224\n",
"Test Accuracy: 0.1000\n", "Test Accuracy: 0.1000\n",
"Classification Report:\n", "Classification Report:\n",
" precision recall f1-score support\n", " precision recall f1-score support\n",
@ -192,6 +211,7 @@
"\n", "\n",
"---\n", "---\n",
"\n", "\n",
"Дообученная модель google/vit-base-patch16-224 на данных Cifar10\n",
"Test Accuracy: 0.1018\n", "Test Accuracy: 0.1018\n",
"Classification Report:\n", "Classification Report:\n",
" precision recall f1-score support\n", " precision recall f1-score support\n",
@ -209,10 +229,7 @@
"\n", "\n",
" accuracy 0.10 10000\n", " accuracy 0.10 10000\n",
" macro avg 0.11 0.10 0.09 10000\n", " macro avg 0.11 0.10 0.09 10000\n",
"weighted avg 0.11 0.10 0.09 10000\n", "weighted avg 0.11 0.10 0.09 10000"
"\n",
"\n",
"Process finished with exit code 0"
] ]
} }
], ],

Loading…
Cancel
Save