Модель распознавания изображений из набора данных CIFAR-10
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

131 lines
4.4 KiB

import os
import pickle
import numpy as np
import urllib.request
import tarfile
import matplotlib.pyplot as plt
def download_and_extract_cifar10(data_dir='cifar-10-batches-py'):
    """Download and unpack the CIFAR-10 python archive unless ``data_dir`` exists.

    Args:
        data_dir: Path where the extracted batch directory is expected.
            If this path already exists, nothing is downloaded.

    Returns:
        ``data_dir``, unchanged.

    Raises:
        urllib.error.URLError: If the download fails.
        tarfile.TarError: If the archive cannot be read or a member is
            rejected by the extraction filter.
    """
    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    filename = 'cifar-10-python.tar.gz'

    if os.path.exists(data_dir):
        return data_dir

    urllib.request.urlretrieve(url, filename)

    # Unpack next to where ``data_dir`` is expected, so a path such as
    # 'datasets/cifar-10-batches-py' ends up populated.  For a bare directory
    # name this is the current directory.
    # NOTE(review): assumes the archive's top-level folder has the same name
    # as the last component of ``data_dir`` -- confirm for non-default values.
    parent = os.path.dirname(os.path.normpath(data_dir)) or '.'
    os.makedirs(parent, exist_ok=True)

    with tarfile.open(filename, 'r:gz') as tar:
        # The archive comes from the network: use the 'data' extraction filter
        # when this Python provides it, so members cannot escape ``parent``.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=parent, filter='data')
        else:
            tar.extractall(path=parent)
    return data_dir
# Runs at module level: makes sure the dataset directory is present (the helper
# only downloads when the directory does not exist yet) and remembers its path.
data_dir = download_and_extract_cifar10()
def load_cifar10(data_dir):
    """Load the pickled CIFAR-10 batches found in ``data_dir``.

    Reads ``data_batch_1`` .. ``data_batch_5`` for training and ``test_batch``
    for testing.  Each batch's ``b'data'`` array is reshaped to
    ``(-1, 3, 32, 32)`` and transposed so the channel axis comes last.

    Args:
        data_dir: Directory containing the batch files.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels)`` of NumPy
        arrays; the data arrays have shape ``(N, 32, 32, 3)``.

    Raises:
        OSError: If a batch file is missing or unreadable.
        KeyError: If a batch lacks the ``b'data'`` or ``b'labels'`` entry.
    """
    def unpickle(path):
        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at batch files obtained from a trusted source.
        with open(path, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    def to_images(flat):
        # Rows are stored channel-first; move channels to the last axis.
        return flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    batches = [
        unpickle(os.path.join(data_dir, 'data_batch_' + str(i)))
        for i in range(1, 6)
    ]
    train_data = to_images(np.concatenate([batch[b'data'] for batch in batches]))
    train_labels = np.concatenate([batch[b'labels'] for batch in batches])

    test_batch = unpickle(os.path.join(data_dir, 'test_batch'))
    test_data = to_images(test_batch[b'data'])
    test_labels = np.array(test_batch[b'labels'])

    return train_data, train_labels, test_data, test_labels
# Read every split into memory, then scale the image arrays by 1/255
# (labels are left untouched).
train_data, train_labels, test_data, test_labels = load_cifar10(data_dir)
train_data = np.divide(train_data, 255.0)
test_data = np.divide(test_data, 255.0)
class SimpleNN:
    """Two-layer fully connected classifier implemented with NumPy.

    Architecture: linear -> ReLU -> linear -> softmax.  ``forward`` caches the
    intermediate activations, ``backward`` stores gradients as attributes
    (``dW1``, ``db1``, ``dW2``, ``db2``), and ``update_parameters`` applies a
    plain gradient-descent step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers with small random weights and zero biases.

        Args:
            input_size: Number of features per sample.
            hidden_size: Number of hidden units.
            output_size: Number of classes.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Initialize weights and biases
        self.W1 = np.random.randn(input_size, hidden_size) * 0.01
        self.b1 = np.zeros((1, hidden_size))
        self.W2 = np.random.randn(hidden_size, output_size) * 0.01
        self.b2 = np.zeros((1, output_size))

    def forward(self, X):
        """Return class probabilities for ``X`` (one row per sample).

        The pre-activations and activations are kept on ``self`` because
        ``backward`` reads them.
        """
        self.Z1 = np.dot(X, self.W1) + self.b1
        self.A1 = np.maximum(0, self.Z1)  # ReLU activation
        self.Z2 = np.dot(self.A1, self.W2) + self.b2
        # Softmax activation.  Subtracting the row maximum leaves the result
        # mathematically unchanged but keeps exp() from overflowing to inf
        # (which would turn the probabilities into NaN).
        shifted = self.Z2 - self.Z2.max(axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        self.A2 = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
        return self.A2

    def backward(self, X, y, output):
        """Compute cross-entropy gradients for the last forward pass.

        Args:
            X: The inputs that were passed to ``forward``.
            y: Integer class index per sample.
            output: Probabilities returned by ``forward``; not modified.
        """
        m = y.shape[0]
        # Work on a copy: ``output`` is normally the very array stored in
        # ``self.A2``, and editing it in place would corrupt the cached
        # probabilities and the caller's array.
        self.dZ2 = np.array(output, copy=True)
        self.dZ2[np.arange(m), y] -= 1
        self.dW2 = (1 / m) * np.dot(self.A1.T, self.dZ2)
        self.db2 = (1 / m) * np.sum(self.dZ2, axis=0, keepdims=True)
        self.dA1 = np.dot(self.dZ2, self.W2.T)
        self.dZ1 = np.array(self.dA1, copy=True)
        self.dZ1[self.Z1 <= 0] = 0  # ReLU passes no gradient where it was inactive
        self.dW1 = (1 / m) * np.dot(X.T, self.dZ1)
        self.db1 = (1 / m) * np.sum(self.dZ1, axis=0, keepdims=True)

    def update_parameters(self, learning_rate):
        """Apply one gradient-descent step using the stored gradients."""
        self.W1 -= learning_rate * self.dW1
        self.b1 -= learning_rate * self.db1
        self.W2 -= learning_rate * self.dW2
        self.b2 -= learning_rate * self.db2

    def compute_loss(self, y, output):
        """Return the mean cross-entropy of ``output`` against labels ``y``.

        Probabilities are floored at 1e-12 before the logarithm so that a
        true-class probability of exactly 0 gives a large finite loss rather
        than ``inf``.
        """
        m = y.shape[0]
        true_class_probs = output[np.arange(m), y]
        log_likelihood = -np.log(np.maximum(true_class_probs, 1e-12))
        loss = np.sum(log_likelihood) / m
        return loss
def train(model, X, y, learning_rate=0.01, epochs=10, log_every=1):
    """Run full-batch gradient descent on ``model``.

    Every epoch performs one forward pass over all of ``X``, computes the
    loss, back-propagates, and applies a single parameter update.

    Args:
        model: Object providing ``forward``, ``compute_loss``, ``backward``
            and ``update_parameters``.
        X: Training inputs, passed to the model unchanged.
        y: Training labels, passed to the model unchanged.
        learning_rate: Step size handed to ``update_parameters``.
        epochs: Number of passes to run.
        log_every: Print the loss on every ``log_every``-th epoch, starting
            with the first; ``0`` or ``None`` disables printing.

    Returns:
        List with the loss recorded at each epoch (measured before that
        epoch's parameter update).
    """
    losses = []
    for epoch in range(epochs):
        output = model.forward(X)
        loss = model.compute_loss(y, output)
        model.backward(X, y, output)
        model.update_parameters(learning_rate)
        losses.append(loss)
        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch + 1}, Loss: {loss}')
    return losses
# Collapse each 32x32x3 image into a single feature row.
train_data_flat = train_data.reshape((-1, 32 * 32 * 3))
test_data_flat = test_data.reshape((-1, 32 * 32 * 3))

# Build the network: one input per pixel value, 128 hidden units, 10 outputs.
model = SimpleNN(
    input_size=32 * 32 * 3,
    hidden_size=128,
    output_size=10,
)

# Fit on the whole training set.
train(
    model,
    train_data_flat,
    train_labels,
    learning_rate=0.01,
    epochs=10,
)
def evaluate(model, X, y):
    """Return the fraction of samples in ``X`` that ``model`` classifies as ``y``.

    Args:
        model: Object whose ``forward(X)`` returns one row of class scores
            per sample.
        X: Inputs, passed to ``model.forward`` unchanged.
        y: True class index per sample.

    Returns:
        Mean of ``predictions == y``, where each prediction is the index of
        the highest score in its row.
    """
    output = model.forward(X)
    predictions = np.argmax(output, axis=1)
    accuracy = np.mean(predictions == y)
    return accuracy
# Evaluate the model on the test data
accuracy = evaluate(model, test_data_flat, test_labels)

# Restore the image layout so the samples can be displayed.
test_data_reshaped = test_data_flat.reshape(-1, 32, 32, 3)
# Position in this list is the integer label used to index it below.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Draw the first 25 test images in a 5x5 grid, each captioned with the class
# name looked up from its label in ``test_labels``.
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_data_reshaped[i])
    plt.xlabel(class_names[test_labels[i]])
# Show the finished figure once, after all 25 panels have been drawn.
plt.show()
print(f'Test Accuracy: {accuracy * 100:.2f}%')