You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.3 KiB
73 lines
2.3 KiB
import tensorflow as tf |
|
from tensorflow.keras import datasets, layers, models |
|
import matplotlib.pyplot as plt |
|
|
|
# Report the TensorFlow version and the CUDA version it was built against.
print("TensorFlow version:", tf.__version__)

# CPU-only TensorFlow builds do not include a 'cuda_version' entry in the
# build-info dict, so indexing it directly raises KeyError; fall back to a
# descriptive placeholder instead of crashing the script at startup.
build_info = tf.sysconfig.get_build_info()
print("CUDA runtime version:", build_info.get('cuda_version', 'not available (non-CUDA build)'))
|
|
|
# Fetch CIFAR-10 as ((train images, train labels), (test images, test labels)).
(train_images, train_labels), (test_images, test_labels) = (
    datasets.cifar10.load_data()
)

# Scale pixel intensities into [0, 1] by dividing by the maximum value 255.
# Presumably the raw pixels are 8-bit in [0, 255] — confirm against the dataset docs.
train_images = train_images / 255.0
test_images = test_images / 255.0
|
|
|
# Human-readable names for CIFAR-10 class indices 0..9.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Sanity-check the data: draw the first 25 training images in a 5x5 grid,
# each captioned with its class name.
plt.figure(figsize=(10, 10))
for idx in range(25):
    cell = idx + 1  # subplot positions are 1-based
    plt.subplot(5, 5, cell)
    # Strip ticks and grid lines so only the image itself is shown.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[idx])
    # Each label entry is indexed with [0] to pull out the scalar class index.
    label_index = train_labels[idx][0]
    plt.xlabel(class_names[label_index])
plt.show()
|
|
|
# Build the CNN as an ordered stack of layers: three 3x3 convolution stages
# (with 2x2 max-pooling after the first two) feeding a small dense classifier.
model = models.Sequential([
    # Feature extractor; inputs are 32x32 images with 3 channels.
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),

    # Classifier head. The final Dense(10) has no activation, so it emits
    # raw logits, one per class.
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),
])
|
|
|
# Configure training: Adam optimizer, sparse categorical cross-entropy on
# integer class labels, and accuracy reported as a metric.
model.compile(
    optimizer='adam',
    # The output layer has no softmax, so the loss must interpret its
    # outputs as logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
|
|
|
# Train for 20 epochs. The returned History object stores per-epoch metrics
# in history.history, which are plotted later.
# NOTE(review): the test split doubles as validation data here, so the val_*
# metrics are not an independent hold-out; consider validation_split instead.
history = model.fit(
    train_images,
    train_labels,
    epochs=20,
    validation_data=(test_images, test_labels),
)
|
|
|
# Evaluate the model on the held-out test split. The printed figure is
# labelled "Test accuracy", so it must come from data the model never
# trained on; evaluating on train_images would overstate generalization.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")
|
|
|
# Plot the per-epoch learning curves recorded by model.fit, side by side:
# accuracy on the left, loss on the right.
plt.figure(figsize=(12, 4))

# One entry per panel:
# (train key, val key, train legend, val legend, y label, legend loc, title)
curve_panels = [
    ('accuracy', 'val_accuracy', 'train accuracy', 'val accuracy',
     'Accuracy', 'lower right', 'Training and Validation Accuracy'),
    ('loss', 'val_loss', 'train loss', 'val loss',
     'Loss', 'upper right', 'Training and Validation Loss'),
]

for position, panel in enumerate(curve_panels, start=1):
    train_key, val_key, train_label, val_label, y_label, legend_loc, title = panel
    plt.subplot(1, 2, position)
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend(loc=legend_loc)
    plt.title(title)

plt.show()
|
|
|