Модель распознования изображений из набора данных CIFAR10
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

86 lines
2.8 KiB

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
import time
# Загрузка данных CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Нормализация
x_train, x_test = x_train / 255.0, x_test / 255.0
# Аугментация
datagen = ImageDataGenerator(
rotation_range=15,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=True
)
datagen.fit(x_train)
# Создание модели
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10)
])
# Компиляция модели
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# Обучение модели
start_time = time.time()
history = model.fit(datagen.flow(x_train, y_train, batch_size=64),
epochs=20,
validation_data=(x_test, y_test))
end_time = time.time()
# Время обучения
training_time = end_time - start_time
print(f"Training time: {training_time:.2f} seconds")
# Оценка модели
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")
# Построение графиков
plt.figure(figsize=(12, 4))
# График точности
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# График потерь
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
# Документация и комментарии
"""
Этот код загружает данные CIFAR-10, нормализует их, создает и обучает сверточную нейронную сеть.
Используется аугментация данных для улучшения обучения модели.
Модель оценивается на тестовых данных, и строятся графики точности и потерь для анализа обучения.
"""
# Сохранение модели
model.save('cifar10_model.h5')