8 changed files with 223 additions and 19 deletions
@ -1,4 +1,9 @@ |
|||||||
### Dataset source |
## Dataset source |
||||||
|
|
||||||
- https://www.cs.toronto.edu/%7Ekriz/cifar.html |
- https://www.cs.toronto.edu/%7Ekriz/cifar.html |
||||||
- https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a) |
- https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz (c58f30108f718f92721af3b95e74349a) |
||||||
|
|
||||||
|
## Результат |
||||||
|
|
||||||
|
Модель на 3х сверточных слоях на тестовых данных: 71% (cifar10.py) |
||||||
|
Модель google/vit-base-patch16-224: 10% (вероятнее всего не разаборался как корректно ее перенастроить на классификацию по 10 категориям) (ViT16.py) |
||||||
@ -0,0 +1,59 @@ |
|||||||
|
import torch |
||||||
|
from torch.utils.data import DataLoader |
||||||
|
from torchvision import transforms |
||||||
|
from transformers import ViTForImageClassification, ViTImageProcessor |
||||||
|
from datasets import load_dataset |
||||||
|
from tqdm import tqdm |
||||||
|
|
||||||
|
# Load CIFAR-10 dataset |
||||||
|
dataset = load_dataset('cifar10') |
||||||
|
|
||||||
|
# Define transformations for the dataset |
||||||
|
transform = transforms.Compose([ |
||||||
|
transforms.Resize((224, 224)), # Resize images to 224x224 |
||||||
|
transforms.ToTensor(), |
||||||
|
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) |
||||||
|
]) |
||||||
|
|
||||||
|
# Apply transformations to the dataset |
||||||
|
def preprocess_function(examples): |
||||||
|
examples['pixel_values'] = [transform(image.convert("RGB")) for image in examples['img']] |
||||||
|
return examples |
||||||
|
|
||||||
|
encoded_dataset = dataset.with_transform(preprocess_function) |
||||||
|
|
||||||
|
# Create DataLoader for the test set |
||||||
|
def collate_fn(batch): |
||||||
|
pixel_values = torch.stack([item['pixel_values'] for item in batch]) |
||||||
|
labels = torch.tensor([item['label'] for item in batch]) |
||||||
|
return {'pixel_values': pixel_values, 'label': labels} |
||||||
|
|
||||||
|
test_loader = DataLoader(encoded_dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn) |
||||||
|
|
||||||
|
# Load the pre-trained ViT model |
||||||
|
model_name = 'google/vit-base-patch16-224' |
||||||
|
image_processor = ViTImageProcessor.from_pretrained(model_name) |
||||||
|
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True) |
||||||
|
|
||||||
|
# Move model to GPU if available |
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||||
|
model.to(device) |
||||||
|
|
||||||
|
# Evaluate the model |
||||||
|
model.eval() |
||||||
|
correct = 0 |
||||||
|
total = 0 |
||||||
|
|
||||||
|
with torch.no_grad(): |
||||||
|
for batch in tqdm(test_loader): |
||||||
|
images = batch['pixel_values'].to(device) |
||||||
|
labels = batch['label'].to(device) |
||||||
|
|
||||||
|
outputs = model(images) |
||||||
|
_, predicted = torch.max(outputs.logits, 1) |
||||||
|
|
||||||
|
total += labels.size(0) |
||||||
|
correct += (predicted == labels).sum().item() |
||||||
|
|
||||||
|
accuracy = 100 * correct / total |
||||||
|
print(f'Accuracy of the model on the CIFAR-10 test images: {accuracy:.2f}%') |
||||||
@ -0,0 +1,43 @@ |
|||||||
|
import torchvision.transforms as transforms |
||||||
|
from transformers import ViTImageProcessor, ViTForImageClassification |
||||||
|
from PIL import Image |
||||||
|
import requests |
||||||
|
import torch |
||||||
|
from torchvision.datasets import CIFAR10 |
||||||
|
from torch.utils.data import DataLoader |
||||||
|
|
||||||
|
classes = [ |
||||||
|
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' |
||||||
|
] |
||||||
|
|
||||||
|
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg' |
||||||
|
# image = Image.open(requests.get(url, stream=True).raw) |
||||||
|
|
||||||
|
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' |
||||||
|
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png' |
||||||
|
image = Image.open(requests.get(url, stream=True).raw) |
||||||
|
|
||||||
|
test_dataset = CIFAR10(root='./data', train=False, download=True) |
||||||
|
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) |
||||||
|
|
||||||
|
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') |
||||||
|
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') |
||||||
|
|
||||||
|
# Изменение количества классов на 10 (для CIFAR-10) |
||||||
|
model.classifier = torch.nn.Linear(model.classifier.in_features, 10) |
||||||
|
|
||||||
|
# Перенос модели на GPU, если доступно |
||||||
|
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||||
|
# model.to(device) |
||||||
|
# model.half().to(device) # Load in half precision |
||||||
|
|
||||||
|
inputs = processor(images=image, return_tensors="pt") |
||||||
|
outputs = model(**inputs) |
||||||
|
logits = outputs.logits |
||||||
|
|
||||||
|
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] |
||||||
|
preds = torch.argmax(logits, dim=1) |
||||||
|
print(preds.cpu().numpy()) |
||||||
|
# model predicts one of the 1000 ImageNet classes |
||||||
|
predicted_class_idx = logits.argmax(-1).item() |
||||||
|
print("Predicted class:", model.config.id2label[predicted_class_idx]) |
||||||
@ -0,0 +1,26 @@ |
|||||||
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
||||||
|
from PIL import Image |
||||||
|
import requests |
||||||
|
from tensorflow.keras import datasets, layers, models |
||||||
|
import matplotlib.pyplot as plt |
||||||
|
|
||||||
|
# Load the CIFAR-10 dataset |
||||||
|
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() |
||||||
|
|
||||||
|
# Normalize pixel values to be between 0 and 1 |
||||||
|
train_images, test_images = train_images / 255.0, test_images / 255.0 |
||||||
|
|
||||||
|
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png' |
||||||
|
image = Image.open(requests.get(url, stream=True).raw) |
||||||
|
feature_extractor = ViTFeatureExtractor.from_pretrained('nateraw/vit-base-patch16-224-cifar10') |
||||||
|
model = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10') |
||||||
|
inputs = feature_extractor(images=test_images[10], return_tensors="pt") |
||||||
|
outputs = model(**inputs) |
||||||
|
preds = outputs.logits.argmax(dim=1) |
||||||
|
|
||||||
|
classes = [ |
||||||
|
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' |
||||||
|
] |
||||||
|
print(classes[test_labels[10][0]]) |
||||||
|
print(classes[preds[0]]) |
||||||
|
|
||||||
Loading…
Reference in new issue