You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.8 KiB
57 lines
1.8 KiB
# Train a small convolutional network on CIFAR-10 with TensorFlow/Keras:
# load and normalize the data, preview some training images, train a
# three-convolution CNN, plot training/validation accuracy, and report the
# accuracy on the test split.
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# Record the toolchain in the run log.
print("TensorFlow version:", tf.__version__)
# get_build_info() describes how this TensorFlow binary was built; the
# 'cuda_version' key is only present on CUDA-enabled builds, so indexing it
# directly raises KeyError on CPU-only installs.
print("CUDA runtime version:",
      tf.sysconfig.get_build_info().get('cuda_version', 'not available'))

# Load the CIFAR-10 dataset as (images, labels) pairs for train and test.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1.
# Cast to float32 before dividing: dividing integer pixel arrays by a Python
# float would promote them to float64, doubling memory, while the model
# computes in float32 by default anyway. The scaled values are the same.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Verify the data: show the first 25 training images in a 5x5 grid, each
# labelled with its class name.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(train_images[:25], train_labels[:25])):
    plt.subplot(5, 5, i + 1)  # subplot indices are 1-based
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image)
    # Each label row holds a single class index; take that element.
    plt.xlabel(class_names[label[0]])
# Display once, after the whole grid has been drawn.
plt.show()

# Build the CNN model: three 3x3 ReLU convolutions (the first two followed by
# 2x2 max pooling) feeding a small dense classifier head.
model = models.Sequential([
    # Declare the 32x32 RGB input with an Input layer rather than the
    # `input_shape=` argument on the first Conv2D, which Keras 3 deprecates.
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    # One output per CIFAR-10 class. There is no softmax: the layer emits
    # raw logits, which the loss is told to expect via from_logits=True.
    layers.Dense(10),
])

# Compile the model. The sparse cross-entropy variant takes integer class
# labels directly. from_logits=True because the final Dense layer has no
# activation.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Train the model for 10 epochs, recording per-epoch metrics in `history`.
# NOTE(review): the test split is also used as validation data, so the
# validation curve is not independent of the final test score. Consider
# `validation_split=...` if these curves drive any tuning decisions.
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))

# Evaluate the model: plot training vs. validation accuracy per epoch.
# Start a new figure so that, in an interactive session, these curves are
# not drawn into the last subplot of the preview grid.
plt.figure()
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')

# Final loss/accuracy on the test split.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(test_acc)

# Without an explicit show(), a non-interactive run exits without ever
# displaying the accuracy plot. Call it last so the test score is printed
# before the (possibly blocking) plot window opens.
plt.show()