import os
import pickle
import tarfile
import urllib.request

import numpy as np
import matplotlib.pyplot as plt


def download_and_extract_cifar10(data_dir='cifar-10-batches-py'):
    """Download the CIFAR-10 python archive and extract it if not already present."""
    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    filename = 'cifar-10-python.tar.gz'
    if not os.path.exists(data_dir):
        urllib.request.urlretrieve(url, filename)
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall()
    return data_dir


data_dir = download_and_extract_cifar10()


def load_cifar10(data_dir):
    """Load the five CIFAR-10 training batches and the test batch into NumPy arrays."""
    def unpickle(file):
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    train_data = []
    train_labels = []
    for i in range(1, 6):
        batch = unpickle(os.path.join(data_dir, 'data_batch_' + str(i)))
        train_data.append(batch[b'data'])
        train_labels.append(batch[b'labels'])

    # Each row is a flat 3072-vector (3 channels x 32 x 32); reshape to NHWC images
    train_data = np.concatenate(train_data).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    train_labels = np.concatenate(train_labels)

    test_batch = unpickle(os.path.join(data_dir, 'test_batch'))
    test_data = test_batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    test_labels = np.array(test_batch[b'labels'])

    return train_data, train_labels, test_data, test_labels


train_data, train_labels, test_data, test_labels = load_cifar10(data_dir)

# Normalize pixel values from [0, 255] to [0, 1]
train_data = train_data / 255.0
test_data = test_data / 255.0

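# Quick sanity check (not in the original script): CIFAR-10 ships 50,000 training
# and 10,000 test images, so the shapes printed below should be (50000, 32, 32, 3),
# (50000,), (10000, 32, 32, 3) and (10000,).
print('train:', train_data.shape, train_labels.shape, 'test:', test_data.shape, test_labels.shape)
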
class SimpleNN:
    """A two-layer fully connected network: ReLU hidden layer, softmax output."""

    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Initialize weights with small random values and biases with zeros
        self.W1 = np.random.randn(input_size, hidden_size) * 0.01
        self.b1 = np.zeros((1, hidden_size))
        self.W2 = np.random.randn(hidden_size, output_size) * 0.01
        self.b2 = np.zeros((1, output_size))

    def forward(self, X):
        self.Z1 = np.dot(X, self.W1) + self.b1
        self.A1 = np.maximum(0, self.Z1)  # ReLU activation
        self.Z2 = np.dot(self.A1, self.W2) + self.b2
        # Softmax activation; subtract the row-wise max for numerical stability
        exp_Z2 = np.exp(self.Z2 - np.max(self.Z2, axis=1, keepdims=True))
        self.A2 = exp_Z2 / np.sum(exp_Z2, axis=1, keepdims=True)
        return self.A2

    def backward(self, X, y, output):
        m = y.shape[0]
        # Gradient of cross-entropy w.r.t. the logits is (softmax - one_hot);
        # the 1/m average is applied when forming the weight gradients below.
        # Copy so the caller's output array is not modified in place.
        self.dZ2 = output.copy()
        self.dZ2[range(m), y] -= 1
        self.dW2 = (1 / m) * np.dot(self.A1.T, self.dZ2)
        self.db2 = (1 / m) * np.sum(self.dZ2, axis=0, keepdims=True)
        self.dA1 = np.dot(self.dZ2, self.W2.T)
        # Backpropagate through ReLU: zero the gradient where the pre-activation was <= 0
        self.dZ1 = np.array(self.dA1, copy=True)
        self.dZ1[self.Z1 <= 0] = 0
        self.dW1 = (1 / m) * np.dot(X.T, self.dZ1)
        self.db1 = (1 / m) * np.sum(self.dZ1, axis=0, keepdims=True)

    def update_parameters(self, learning_rate):
        self.W1 -= learning_rate * self.dW1
        self.b1 -= learning_rate * self.db1
        self.W2 -= learning_rate * self.dW2
        self.b2 -= learning_rate * self.db2

    def compute_loss(self, y, output):
        m = y.shape[0]
        # Cross-entropy loss; the small constant guards against log(0)
        log_likelihood = -np.log(output[range(m), y] + 1e-12)
        loss = np.sum(log_likelihood) / m
        return loss

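# Optional sanity check, not part of the original script: compare the analytic
# gradient from backward() with a centered finite difference on one entry of W1.
# The helper name, sample count and epsilon value are arbitrary choices here.
def gradient_check(model, num_samples=5, epsilon=1e-5):
    X_small = np.random.randn(num_samples, model.input_size)
    y_small = np.random.randint(0, model.output_size, size=num_samples)
    output = model.forward(X_small)
    model.backward(X_small, y_small, output)
    analytic = model.dW1[0, 0]
    # Perturb W1[0, 0] in both directions and difference the losses
    original = model.W1[0, 0]
    model.W1[0, 0] = original + epsilon
    loss_plus = model.compute_loss(y_small, model.forward(X_small))
    model.W1[0, 0] = original - epsilon
    loss_minus = model.compute_loss(y_small, model.forward(X_small))
    model.W1[0, 0] = original
    numeric = (loss_plus - loss_minus) / (2 * epsilon)
    print(f'dW1[0, 0] analytic: {analytic:.6e}, numeric: {numeric:.6e}')

# Example usage (uncomment to run on a fresh, untrained model):
# gradient_check(SimpleNN(input_size=32*32*3, hidden_size=128, output_size=10))
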
def train(model, X, y, learning_rate=0.01, epochs=10):
    """Full-batch gradient descent: one forward/backward pass over all of X per epoch."""
    for epoch in range(epochs):
        output = model.forward(X)
        loss = model.compute_loss(y, output)
        model.backward(X, y, output)
        model.update_parameters(learning_rate)
        print(f'Epoch {epoch + 1}, Loss: {loss}')

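# The train() function above takes one gradient step per epoch over all 50,000
# flattened images, which converges slowly. A mini-batch variant along these lines
# is a common alternative; it is not part of the original script, and the
# batch_size default is an arbitrary choice.
def train_minibatch(model, X, y, learning_rate=0.01, epochs=10, batch_size=128):
    n = X.shape[0]
    for epoch in range(epochs):
        perm = np.random.permutation(n)  # reshuffle the data each epoch
        total_loss = 0.0
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            output = model.forward(X[idx])
            total_loss += model.compute_loss(y[idx], output) * len(idx)
            model.backward(X[idx], y[idx], output)
            model.update_parameters(learning_rate)
        print(f'Epoch {epoch + 1}, Loss: {total_loss / n}')

# Example usage (instead of the train() call further below):
# train_minibatch(model, train_data_flat, train_labels, learning_rate=0.01, epochs=10)
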
# Flatten the 32x32x3 images into 3072-dimensional vectors
train_data_flat = train_data.reshape(-1, 32*32*3)
test_data_flat = test_data.reshape(-1, 32*32*3)

# Initialize the model
model = SimpleNN(input_size=32*32*3, hidden_size=128, output_size=10)

# Train the model
train(model, train_data_flat, train_labels, learning_rate=0.01, epochs=10)


def evaluate(model, X, y):
    """Return the classification accuracy of the model on (X, y)."""
    output = model.forward(X)
    predictions = np.argmax(output, axis=1)
    accuracy = np.mean(predictions == y)
    return accuracy

# Evaluate the model on the test data
accuracy = evaluate(model, test_data_flat, test_labels)
test_data_reshaped = test_data_flat.reshape(-1, 32, 32, 3)

# CIFAR-10 class names, indexed by label
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

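# Optional breakdown, not part of the original script: per-class test accuracy,
# reusing the same argmax-over-softmax prediction rule as evaluate() and the
# class_names list above. The helper name is an arbitrary choice.
def per_class_accuracy(model, X, y, names):
    predictions = np.argmax(model.forward(X), axis=1)
    for class_index, class_name in enumerate(names):
        mask = y == class_index
        print(f'{class_name}: {np.mean(predictions[mask] == class_index) * 100:.2f}%')

# Example usage:
# per_class_accuracy(model, test_data_flat, test_labels, class_names)
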
# Show the first 25 test images with their ground-truth labels
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_data_reshaped[i])
    plt.xlabel(class_names[test_labels[i]])
plt.show()

print(f'Test Accuracy: {accuracy * 100:.2f}%')
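# Optional, not part of the original script: save the learned parameters so a later
# run could reload them with np.load instead of retraining. The filename is an
# arbitrary choice.
np.savez('simple_nn_cifar10.npz', W1=model.W1, b1=model.b1, W2=model.W2, b2=model.b2)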