You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
3.1 KiB
78 lines
3.1 KiB
# is a library for the Python programming language, adding support for large, |
|
# multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays. |
|
import numpy as np |
|
# Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible. |
|
import matplotlib.pyplot as plt |
|
import networks |
|
from utils import create_digit_image, add_noise |
|
|
|
# Создание датасета |
|
def create_dataset(num_samples=1000, noise_level=0.1): |
|
images = [] |
|
labels = [] |
|
for _ in range(num_samples): |
|
digit = np.random.randint(0, 10) |
|
image = create_digit_image(digit) |
|
noisy_image = add_noise(image, noise_level) |
|
images.append(noisy_image.flatten()) |
|
labels.append(digit) |
|
return np.array(images), np.array(labels) |
|
|
|
# Создание тренировочного и тестового наборов данных |
|
train_images, train_labels = create_dataset(num_samples=1000, noise_level=0.01) |
|
test_images, test_labels = create_dataset(num_samples=200, noise_level=0.01) |
|
|
|
# # Визуализация нескольких примеров (цифры) |
|
# fig, axes = plt.subplots(1, 4, figsize=(10, 3)) |
|
# for i, ax in enumerate(axes): |
|
# digit = np.random.randint(0, 10) |
|
# image = create_digit_image(digit) |
|
# ax.imshow(image, cmap='gray') |
|
# ax.set_title(f'Label: {digit}') |
|
# ax.axis('off') |
|
# plt.show() |
|
|
|
# # Визуализация нескольких примеров (данные обучения) |
|
# fig, axes = plt.subplots(1, 5, figsize=(10, 3)) |
|
# for i, ax in enumerate(axes): |
|
# ax.imshow(train_images[i].reshape(5, 3), cmap='gray') |
|
# ax.set_title(f'Label: {train_labels[i]}') |
|
# ax.axis('off') |
|
# plt.show() |
|
# |
|
# # Визуализация нескольких примеров (данные проверки) |
|
# fig, axes = plt.subplots(1, 5, figsize=(10, 3)) |
|
# for i, ax in enumerate(axes): |
|
# ax.imshow(test_images[i].reshape(5, 3), cmap='gray') |
|
# ax.set_title(f'Label: {test_labels[i]}') |
|
# ax.axis('off') |
|
# plt.show() |
|
|
|
def run(): |
|
# Инициализация модели |
|
input_size = 5 * 3 |
|
hidden_size = 64 |
|
output_size = 10 |
|
model = networks.SimpleNeuralNetwork(input_size, hidden_size, output_size) # SimpleNeuralNetwork |
|
|
|
# Обучение модели |
|
model.train(train_images, train_labels, learning_rate=0.01, epochs=5000) |
|
|
|
# Оценка модели |
|
model.evaluate(train_images[0], train_labels[0]) |
|
|
|
# Построение графиков |
|
model.plot_metrics() |
|
|
|
# Предсказание на тестовых данных |
|
predictions = model.predict(test_images) |
|
accuracy = np.mean(predictions == test_labels) |
|
print(f'Test accuracy: {accuracy}') |
|
|
|
# Визуализация нескольких примеров |
|
fig, axes = plt.subplots(1, 5, figsize=(10, 3)) |
|
for i, ax in enumerate(axes): |
|
ax.imshow(test_images[i].reshape(5, 3), cmap='gray') |
|
ax.set_title(f'True: {test_labels[i]}, Pred: {predictions[i]}') |
|
ax.axis('off') |
|
plt.show() |