8 changed files with 223 additions and 19 deletions
@ -1,4 +1,9 @@
|
||||
### Dataset source |
||||
## Dataset source |
||||
|
||||
- https://www.cs.toronto.edu/%7Ekriz/cifar.html |
||||
- https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a) |
||||
|
||||
## Результат |
||||
|
||||
Модель на 3х сверточных слоях на тестовых данных: 71% (cifar10.py) |
||||
Модель google/vit-base-patch16-224: 10% (вероятнее всего не разаборался как корректно ее перенастроить на классификацию по 10 категориям) (ViT16.py) |
||||
@ -0,0 +1,59 @@
|
||||
import torch |
||||
from torch.utils.data import DataLoader |
||||
from torchvision import transforms |
||||
from transformers import ViTForImageClassification, ViTImageProcessor |
||||
from datasets import load_dataset |
||||
from tqdm import tqdm |
||||
|
||||
# Load CIFAR-10 dataset |
||||
dataset = load_dataset('cifar10') |
||||
|
||||
# Define transformations for the dataset |
||||
transform = transforms.Compose([ |
||||
transforms.Resize((224, 224)), # Resize images to 224x224 |
||||
transforms.ToTensor(), |
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) |
||||
]) |
||||
|
||||
# Apply transformations to the dataset |
||||
def preprocess_function(examples): |
||||
examples['pixel_values'] = [transform(image.convert("RGB")) for image in examples['img']] |
||||
return examples |
||||
|
||||
encoded_dataset = dataset.with_transform(preprocess_function) |
||||
|
||||
# Create DataLoader for the test set |
||||
def collate_fn(batch): |
||||
pixel_values = torch.stack([item['pixel_values'] for item in batch]) |
||||
labels = torch.tensor([item['label'] for item in batch]) |
||||
return {'pixel_values': pixel_values, 'label': labels} |
||||
|
||||
test_loader = DataLoader(encoded_dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn) |
||||
|
||||
# Load the pre-trained ViT model |
||||
model_name = 'google/vit-base-patch16-224' |
||||
image_processor = ViTImageProcessor.from_pretrained(model_name) |
||||
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True) |
||||
|
||||
# Move model to GPU if available |
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
model.to(device) |
||||
|
||||
# Evaluate the model |
||||
model.eval() |
||||
correct = 0 |
||||
total = 0 |
||||
|
||||
with torch.no_grad(): |
||||
for batch in tqdm(test_loader): |
||||
images = batch['pixel_values'].to(device) |
||||
labels = batch['label'].to(device) |
||||
|
||||
outputs = model(images) |
||||
_, predicted = torch.max(outputs.logits, 1) |
||||
|
||||
total += labels.size(0) |
||||
correct += (predicted == labels).sum().item() |
||||
|
||||
accuracy = 100 * correct / total |
||||
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%') |
||||
@ -0,0 +1,43 @@
|
||||
import torchvision.transforms as transforms |
||||
from transformers import ViTImageProcessor, ViTForImageClassification |
||||
from PIL import Image |
||||
import requests |
||||
import torch |
||||
from torchvision.datasets import CIFAR10 |
||||
from torch.utils.data import DataLoader |
||||
|
||||
classes = [ |
||||
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' |
||||
] |
||||
|
||||
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg' |
||||
# image = Image.open(requests.get(url, stream=True).raw) |
||||
|
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' |
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png' |
||||
image = Image.open(requests.get(url, stream=True).raw) |
||||
|
||||
test_dataset = CIFAR10(root='./data', train=False, download=True) |
||||
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) |
||||
|
||||
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') |
||||
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') |
||||
|
||||
# Изменение количества классов на 10 (для CIFAR-10) |
||||
model.classifier = torch.nn.Linear(model.classifier.in_features, 10) |
||||
|
||||
# Перенос модели на GPU, если доступно |
||||
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
# model.to(device) |
||||
# model.half().to(device) # Load in half precision |
||||
|
||||
inputs = processor(images=image, return_tensors="pt") |
||||
outputs = model(**inputs) |
||||
logits = outputs.logits |
||||
|
||||
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] |
||||
preds = torch.argmax(logits, dim=1) |
||||
print(preds.cpu().numpy()) |
||||
# model predicts one of the 1000 ImageNet classes |
||||
predicted_class_idx = logits.argmax(-1).item() |
||||
print("Predicted class:", model.config.id2label[predicted_class_idx]) |
||||
@ -0,0 +1,26 @@
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification |
||||
from PIL import Image |
||||
import requests |
||||
from tensorflow.keras import datasets, layers, models |
||||
import matplotlib.pyplot as plt |
||||
|
||||
# Load the CIFAR-10 dataset |
||||
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() |
||||
|
||||
# Normalize pixel values to be between 0 and 1 |
||||
train_images, test_images = train_images / 255.0, test_images / 255.0 |
||||
|
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' |
||||
image = Image.open(requests.get(url, stream=True).raw) |
||||
feature_extractor = ViTFeatureExtractor.from_pretrained('nateraw/vit-base-patch16-224-cifar10') |
||||
model = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10') |
||||
inputs = feature_extractor(images=test_images[10], return_tensors="pt") |
||||
outputs = model(**inputs) |
||||
preds = outputs.logits.argmax(dim=1) |
||||
|
||||
classes = [ |
||||
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' |
||||
] |
||||
print(classes[test_labels[10][0]]) |
||||
print(classes[preds[0]]) |
||||
|
||||
Loading…
Reference in new issue