You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
2.8 KiB
87 lines
2.8 KiB
import tensorflow as tf |
|
from tensorflow.keras import layers, models |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import time |
|
|
|
# Загрузка данных CIFAR-10 |
|
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() |
|
|
|
# Нормализация |
|
x_train, x_test = x_train / 255.0, x_test / 255.0 |
|
|
|
# Аугментация |
|
datagen = ImageDataGenerator( |
|
rotation_range=15, |
|
width_shift_range=0.1, |
|
height_shift_range=0.1, |
|
horizontal_flip=True |
|
) |
|
datagen.fit(x_train) |
|
|
|
# Создание модели |
|
model = models.Sequential([ |
|
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), |
|
layers.MaxPooling2D((2, 2)), |
|
layers.Conv2D(64, (3, 3), activation='relu'), |
|
layers.MaxPooling2D((2, 2)), |
|
layers.Conv2D(64, (3, 3), activation='relu'), |
|
|
|
layers.Flatten(), |
|
layers.Dense(64, activation='relu'), |
|
layers.Dense(10) |
|
]) |
|
|
|
# Компиляция модели |
|
model.compile(optimizer='adam', |
|
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), |
|
metrics=['accuracy']) |
|
|
|
# Обучение модели |
|
start_time = time.time() |
|
history = model.fit(datagen.flow(x_train, y_train, batch_size=64), |
|
epochs=20, |
|
validation_data=(x_test, y_test)) |
|
end_time = time.time() |
|
|
|
# Время обучения |
|
training_time = end_time - start_time |
|
print(f"Training time: {training_time:.2f} seconds") |
|
|
|
# Оценка модели |
|
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2) |
|
print(f"Test accuracy: {test_acc:.4f}") |
|
|
|
# Построение графиков |
|
plt.figure(figsize=(12, 4)) |
|
|
|
# График точности |
|
plt.subplot(1, 2, 1) |
|
plt.plot(history.history['accuracy'], label='train accuracy') |
|
plt.plot(history.history['val_accuracy'], label='val accuracy') |
|
plt.xlabel('Epoch') |
|
plt.ylabel('Accuracy') |
|
plt.legend(loc='lower right') |
|
plt.title('Training and Validation Accuracy') |
|
|
|
# График потерь |
|
plt.subplot(1, 2, 2) |
|
plt.plot(history.history['loss'], label='train loss') |
|
plt.plot(history.history['val_loss'], label='val loss') |
|
plt.xlabel('Epoch') |
|
plt.ylabel('Loss') |
|
plt.legend(loc='upper right') |
|
plt.title('Training and Validation Loss') |
|
|
|
plt.show() |
|
|
|
# Документация и комментарии |
|
""" |
|
Этот код загружает данные CIFAR-10, нормализует их, создает и обучает сверточную нейронную сеть. |
|
Используется аугментация данных для улучшения обучения модели. |
|
Модель оценивается на тестовых данных, и строятся графики точности и потерь для анализа обучения. |
|
""" |
|
|
|
# Сохранение модели |
|
model.save('cifar10_model.h5')
|
|
|