Первая модель нейронной сети для распознования печатных цифр
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

78 lines
3.1 KiB

# is a library for the Python programming language, adding support for large,
# multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
import numpy as np
# Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible.
import matplotlib.pyplot as plt
import networks
from utils import create_digit_image, add_noise
# Создание датасета
def create_dataset(num_samples=1000, noise_level=0.1):
images = []
labels = []
for _ in range(num_samples):
digit = np.random.randint(0, 10)
image = create_digit_image(digit)
noisy_image = add_noise(image, noise_level)
images.append(noisy_image.flatten())
labels.append(digit)
return np.array(images), np.array(labels)
# Создание тренировочного и тестового наборов данных
train_images, train_labels = create_dataset(num_samples=1000, noise_level=0.01)
test_images, test_labels = create_dataset(num_samples=200, noise_level=0.01)
# # Визуализация нескольких примеров (цифры)
# fig, axes = plt.subplots(1, 4, figsize=(10, 3))
# for i, ax in enumerate(axes):
# digit = np.random.randint(0, 10)
# image = create_digit_image(digit)
# ax.imshow(image, cmap='gray')
# ax.set_title(f'Label: {digit}')
# ax.axis('off')
# plt.show()
# # Визуализация нескольких примеров (данные обучения)
# fig, axes = plt.subplots(1, 5, figsize=(10, 3))
# for i, ax in enumerate(axes):
# ax.imshow(train_images[i].reshape(5, 3), cmap='gray')
# ax.set_title(f'Label: {train_labels[i]}')
# ax.axis('off')
# plt.show()
#
# # Визуализация нескольких примеров (данные проверки)
# fig, axes = plt.subplots(1, 5, figsize=(10, 3))
# for i, ax in enumerate(axes):
# ax.imshow(test_images[i].reshape(5, 3), cmap='gray')
# ax.set_title(f'Label: {test_labels[i]}')
# ax.axis('off')
# plt.show()
def run():
# Инициализация модели
input_size = 5 * 3
hidden_size = 64
output_size = 10
model = networks.SimpleNeuralNetwork(input_size, hidden_size, output_size) # SimpleNeuralNetwork
# Обучение модели
model.train(train_images, train_labels, learning_rate=0.01, epochs=5000)
# Оценка модели
model.evaluate(train_images[0], train_labels[0])
# Построение графиков
model.plot_metrics()
# Предсказание на тестовых данных
predictions = model.predict(test_images)
accuracy = np.mean(predictions == test_labels)
print(f'Test accuracy: {accuracy}')
# Визуализация нескольких примеров
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
for i, ax in enumerate(axes):
ax.imshow(test_images[i].reshape(5, 3), cmap='gray')
ax.set_title(f'True: {test_labels[i]}, Pred: {predictions[i]}')
ax.axis('off')
plt.show()