Let’s walk through how to build a CNN from scratch using TensorFlow and Keras to classify images from the CIFAR-10 dataset, which contains 60,000 32x32-pixel color images across 10 classes (6,000 per class).

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

Load the data

CIFAR-10 ships with Keras, so loading it takes one line:

(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

The images are stored as NumPy arrays, and the labels range from 0 to 9, corresponding to different categories like airplanes, cats, and trucks.
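
Before any preprocessing, it’s worth looking at a few samples. Here’s a quick sketch (the class names below follow CIFAR-10’s standard label order):

import matplotlib.pyplot as plt

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Show the first 8 training images with their class names
plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1, 8, i + 1)
    plt.imshow(train_images[i])
    plt.title(class_names[train_labels[i][0]])
    plt.axis('off')
plt.show()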

Preprocess

Neural networks train more stably when inputs are on a small, consistent scale. Since pixel values range from 0 to 255, we scale them down to [0,1]. We also one-hot encode the labels, which is the format categorical_crossentropy expects:

# Normalize the images to the range [0, 1]
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

# One-hot encode the labels
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)
print(f"Training samples: {train_images.shape[0]}")
print(f"Test samples: {test_images.shape[0]}")

print(f"Image shape: {train_images[0].shape}")  # Should be (32, 32, 3) (32x32 pixels, 3 channels)
print(f"Label: {train_labels[0]}")
Training samples: 50000
Test samples: 10000
Image shape: (32, 32, 3)
Label: [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
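
As a sanity check, np.argmax recovers the class index from the one-hot vector; here it returns 6, which is 'frog' in CIFAR-10's label order:

import numpy as np

# The position of the 1 in the one-hot vector is the class index
print(np.argmax(train_labels[0]))  # 6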

Build the CNN model

Now, let’s design our CNN. The idea is simple:

  1. Convolutional layers to extract features
  2. Pooling layers to downsample
  3. Dense layers to classify

from tensorflow.keras import layers, models
from tensorflow.keras.layers import Input

model = models.Sequential([
    Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.2),

    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.3),

    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.4),

    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

Breaking it Down:

  • Input Layer: declares the (32, 32, 3) input shape (32x32 pixels, 3 color channels)
  • Conv Layers: pairs of 3x3 convolutions with ReLU activation; the filter count grows (32 → 64 → 128) as the spatial resolution shrinks
  • BatchNormalization: normalizes activations after each convolution, stabilizing and speeding up training
  • MaxPooling Layers: halve the spatial dimensions, reducing computation and adding some translation invariance
  • Dropout: randomly zeroes activations during training (rates rising from 0.2 to 0.5) to curb overfitting
  • Flatten: turns the final feature maps into a 1D vector
  • Dense Layers: fully connected layers for classification
  • Softmax Output: gives a probability distribution over the 10 classes
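
With the architecture defined, model.summary() is a quick way to verify the output shapes and parameter counts before training:

# Inspect layer output shapes and parameter counts
model.summary()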

Data Augmentation

To improve generalization, let’s apply some random transformations to the training images:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    width_shift_range=0.1,    # shift horizontally by up to 10%
    height_shift_range=0.1,   # shift vertically by up to 10%
    horizontal_flip=True,     # mirror images left-right
    rotation_range=15,        # rotate by up to 15 degrees
    zoom_range=0.1,           # zoom in/out by up to 10%
    fill_mode='nearest'       # fill newly exposed pixels with the nearest value
)

# fit() is only required when featurewise statistics (e.g. featurewise_center)
# are enabled; for the transformations above it is a no-op, but harmless
datagen.fit(train_images)
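
It’s worth eyeballing what the generator actually produces. A minimal sketch:

import matplotlib.pyplot as plt

# Pull one augmented batch and display the first 8 images
aug_images, _ = next(datagen.flow(train_images, train_labels, batch_size=8))

plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1, 8, i + 1)
    plt.imshow(aug_images[i])
    plt.axis('off')
plt.show()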

Compile and train

from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Callbacks: stop early when validation loss stalls (keeping the best weights),
# and cut the learning rate to 20% when it plateaus
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-04)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

history_2 = model.fit(
    datagen.flow(train_images, train_labels, batch_size=128),
    epochs=100,
    validation_data=(test_images, test_labels),
    callbacks=[early_stopping, reduce_lr]
)
Epoch 1/100


391/391 ━━━━━━━━━━━━━━━━━━━━ 57s 110ms/step - accuracy: 0.3175 - loss: 2.1918 - val_accuracy: 0.2486 - val_loss: 2.5132 - learning_rate: 0.0010
Epoch 2/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.5112 - loss: 1.3688 - val_accuracy: 0.6072 - val_loss: 1.0945 - learning_rate: 0.0010
Epoch 3/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.5946 - loss: 1.1477 - val_accuracy: 0.5630 - val_loss: 1.4816 - learning_rate: 0.0010
Epoch 4/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 40s 72ms/step - accuracy: 0.6409 - loss: 1.0067 - val_accuracy: 0.6870 - val_loss: 0.9069 - learning_rate: 0.0010
Epoch 5/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.6765 - loss: 0.9170 - val_accuracy: 0.7003 - val_loss: 0.8689 - learning_rate: 0.0010
Epoch 6/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 79ms/step - accuracy: 0.7021 - loss: 0.8467 - val_accuracy: 0.7506 - val_loss: 0.7191 - learning_rate: 0.0010
Epoch 7/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 73ms/step - accuracy: 0.7209 - loss: 0.8030 - val_accuracy: 0.7133 - val_loss: 0.8441 - learning_rate: 0.0010
Epoch 8/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.7284 - loss: 0.7695 - val_accuracy: 0.7692 - val_loss: 0.6790 - learning_rate: 0.0010
Epoch 9/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.7452 - loss: 0.7343 - val_accuracy: 0.7563 - val_loss: 0.7184 - learning_rate: 0.0010
Epoch 10/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 73ms/step - accuracy: 0.7586 - loss: 0.7030 - val_accuracy: 0.7608 - val_loss: 0.7153 - learning_rate: 0.0010
Epoch 11/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 42s 75ms/step - accuracy: 0.7663 - loss: 0.6796 - val_accuracy: 0.7632 - val_loss: 0.7433 - learning_rate: 0.0010
Epoch 12/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 73ms/step - accuracy: 0.7734 - loss: 0.6537 - val_accuracy: 0.7750 - val_loss: 0.6597 - learning_rate: 0.0010
Epoch 13/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.7802 - loss: 0.6382 - val_accuracy: 0.7962 - val_loss: 0.6104 - learning_rate: 0.0010
Epoch 14/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 41s 72ms/step - accuracy: 0.7904 - loss: 0.6160 - val_accuracy: 0.7867 - val_loss: 0.6579 - learning_rate: 0.0010
Epoch 15/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 73ms/step - accuracy: 0.7939 - loss: 0.6026 - val_accuracy: 0.8176 - val_loss: 0.5443 - learning_rate: 0.0010
Epoch 16/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.7986 - loss: 0.5862 - val_accuracy: 0.8143 - val_loss: 0.5561 - learning_rate: 0.0010
Epoch 17/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.8044 - loss: 0.5669 - val_accuracy: 0.7965 - val_loss: 0.6193 - learning_rate: 0.0010
Epoch 18/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.8104 - loss: 0.5542 - val_accuracy: 0.8137 - val_loss: 0.5530 - learning_rate: 0.0010
Epoch 19/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8077 - loss: 0.5544 - val_accuracy: 0.8201 - val_loss: 0.5278 - learning_rate: 0.0010
Epoch 20/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8082 - loss: 0.5553 - val_accuracy: 0.7874 - val_loss: 0.6586 - learning_rate: 0.0010
Epoch 21/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.8144 - loss: 0.5402 - val_accuracy: 0.8222 - val_loss: 0.5407 - learning_rate: 0.0010
Epoch 22/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 41s 73ms/step - accuracy: 0.8178 - loss: 0.5192 - val_accuracy: 0.8068 - val_loss: 0.5975 - learning_rate: 0.0010
Epoch 23/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 77ms/step - accuracy: 0.8191 - loss: 0.5222 - val_accuracy: 0.8403 - val_loss: 0.4870 - learning_rate: 0.0010
Epoch 24/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 78ms/step - accuracy: 0.8234 - loss: 0.5109 - val_accuracy: 0.8436 - val_loss: 0.4622 - learning_rate: 0.0010
Epoch 25/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8249 - loss: 0.5052 - val_accuracy: 0.8352 - val_loss: 0.4972 - learning_rate: 0.0010
Epoch 26/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 41s 75ms/step - accuracy: 0.8270 - loss: 0.4999 - val_accuracy: 0.8437 - val_loss: 0.4592 - learning_rate: 0.0010
Epoch 27/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8307 - loss: 0.4923 - val_accuracy: 0.8508 - val_loss: 0.4435 - learning_rate: 0.0010
Epoch 28/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 77ms/step - accuracy: 0.8351 - loss: 0.4802 - val_accuracy: 0.8163 - val_loss: 0.5592 - learning_rate: 0.0010
Epoch 29/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8325 - loss: 0.4802 - val_accuracy: 0.8377 - val_loss: 0.4978 - learning_rate: 0.0010
Epoch 30/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8359 - loss: 0.4730 - val_accuracy: 0.8449 - val_loss: 0.4665 - learning_rate: 0.0010
Epoch 31/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8378 - loss: 0.4694 - val_accuracy: 0.8378 - val_loss: 0.4860 - learning_rate: 0.0010
Epoch 32/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 73ms/step - accuracy: 0.8424 - loss: 0.4558 - val_accuracy: 0.8451 - val_loss: 0.4616 - learning_rate: 0.0010
Epoch 33/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 72ms/step - accuracy: 0.8508 - loss: 0.4329 - val_accuracy: 0.8677 - val_loss: 0.3939 - learning_rate: 2.0000e-04
Epoch 34/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8550 - loss: 0.4145 - val_accuracy: 0.8676 - val_loss: 0.4048 - learning_rate: 2.0000e-04
Epoch 35/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 78ms/step - accuracy: 0.8600 - loss: 0.4079 - val_accuracy: 0.8739 - val_loss: 0.3824 - learning_rate: 2.0000e-04
Epoch 36/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 33s 84ms/step - accuracy: 0.8654 - loss: 0.3936 - val_accuracy: 0.8750 - val_loss: 0.3796 - learning_rate: 2.0000e-04
Epoch 37/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 80ms/step - accuracy: 0.8649 - loss: 0.3967 - val_accuracy: 0.8711 - val_loss: 0.3834 - learning_rate: 2.0000e-04
Epoch 38/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 32s 83ms/step - accuracy: 0.8656 - loss: 0.3891 - val_accuracy: 0.8717 - val_loss: 0.3969 - learning_rate: 2.0000e-04
Epoch 39/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 32s 81ms/step - accuracy: 0.8656 - loss: 0.3887 - val_accuracy: 0.8764 - val_loss: 0.3759 - learning_rate: 2.0000e-04
Epoch 40/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 77ms/step - accuracy: 0.8680 - loss: 0.3839 - val_accuracy: 0.8770 - val_loss: 0.3748 - learning_rate: 2.0000e-04
Epoch 41/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 80ms/step - accuracy: 0.8660 - loss: 0.3862 - val_accuracy: 0.8744 - val_loss: 0.3792 - learning_rate: 2.0000e-04
Epoch 42/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 76ms/step - accuracy: 0.8717 - loss: 0.3765 - val_accuracy: 0.8770 - val_loss: 0.3684 - learning_rate: 2.0000e-04
Epoch 43/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 80ms/step - accuracy: 0.8678 - loss: 0.3841 - val_accuracy: 0.8774 - val_loss: 0.3687 - learning_rate: 2.0000e-04
Epoch 44/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 76ms/step - accuracy: 0.8696 - loss: 0.3745 - val_accuracy: 0.8770 - val_loss: 0.3686 - learning_rate: 2.0000e-04
Epoch 45/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8687 - loss: 0.3792 - val_accuracy: 0.8724 - val_loss: 0.3866 - learning_rate: 2.0000e-04
Epoch 46/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 75ms/step - accuracy: 0.8690 - loss: 0.3730 - val_accuracy: 0.8729 - val_loss: 0.3892 - learning_rate: 2.0000e-04
Epoch 47/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8707 - loss: 0.3763 - val_accuracy: 0.8837 - val_loss: 0.3567 - learning_rate: 2.0000e-04
Epoch 48/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 76ms/step - accuracy: 0.8703 - loss: 0.3798 - val_accuracy: 0.8774 - val_loss: 0.3767 - learning_rate: 2.0000e-04
Epoch 49/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8730 - loss: 0.3633 - val_accuracy: 0.8797 - val_loss: 0.3708 - learning_rate: 2.0000e-04
Epoch 50/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8713 - loss: 0.3687 - val_accuracy: 0.8732 - val_loss: 0.3910 - learning_rate: 2.0000e-04
Epoch 51/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 33s 84ms/step - accuracy: 0.8692 - loss: 0.3750 - val_accuracy: 0.8772 - val_loss: 0.3709 - learning_rate: 2.0000e-04
Epoch 52/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 31s 78ms/step - accuracy: 0.8725 - loss: 0.3645 - val_accuracy: 0.8752 - val_loss: 0.3774 - learning_rate: 2.0000e-04
Epoch 53/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 28s 73ms/step - accuracy: 0.8794 - loss: 0.3517 - val_accuracy: 0.8867 - val_loss: 0.3440 - learning_rate: 1.0000e-04
Epoch 54/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 73ms/step - accuracy: 0.8740 - loss: 0.3640 - val_accuracy: 0.8806 - val_loss: 0.3628 - learning_rate: 1.0000e-04
Epoch 55/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 30s 76ms/step - accuracy: 0.8778 - loss: 0.3567 - val_accuracy: 0.8825 - val_loss: 0.3566 - learning_rate: 1.0000e-04
Epoch 56/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 73ms/step - accuracy: 0.8769 - loss: 0.3597 - val_accuracy: 0.8810 - val_loss: 0.3614 - learning_rate: 1.0000e-04
Epoch 57/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8741 - loss: 0.3616 - val_accuracy: 0.8805 - val_loss: 0.3596 - learning_rate: 1.0000e-04
Epoch 58/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 42s 75ms/step - accuracy: 0.8779 - loss: 0.3498 - val_accuracy: 0.8795 - val_loss: 0.3647 - learning_rate: 1.0000e-04
Epoch 59/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8795 - loss: 0.3512 - val_accuracy: 0.8799 - val_loss: 0.3647 - learning_rate: 1.0000e-04
Epoch 60/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 42s 77ms/step - accuracy: 0.8792 - loss: 0.3554 - val_accuracy: 0.8840 - val_loss: 0.3538 - learning_rate: 1.0000e-04
Epoch 61/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8782 - loss: 0.3521 - val_accuracy: 0.8799 - val_loss: 0.3659 - learning_rate: 1.0000e-04
Epoch 62/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 75ms/step - accuracy: 0.8830 - loss: 0.3409 - val_accuracy: 0.8852 - val_loss: 0.3462 - learning_rate: 1.0000e-04
Epoch 63/100
391/391 ━━━━━━━━━━━━━━━━━━━━ 29s 74ms/step - accuracy: 0.8780 - loss: 0.3476 - val_accuracy: 0.8830 - val_loss: 0.3530 - learning_rate: 1.0000e-04

Evaluation

test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.8892 - loss: 0.3440
Test accuracy: 0.8866999745368958
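
For a per-class view, scikit-learn’s classification_report breaks the results down by category (a quick sketch, assuming scikit-learn is available, as it is on Colab):

from sklearn.metrics import classification_report
import numpy as np

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Convert one-hot labels and predicted probabilities back to class indices
y_true = np.argmax(test_labels, axis=1)
y_pred = np.argmax(model.predict(test_images), axis=1)
print(classification_report(y_true, y_pred, target_names=class_names))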

Not bad for a simple CNN! We can improve performance later by tweaking hyperparameters or using a more complex architecture.

import matplotlib.pyplot as plt

# summarize history for accuracy
plt.figure(figsize=(12, 4))
plt.plot(history_2.history['accuracy'])
plt.plot(history_2.history['val_accuracy'])
plt.title('model 2 accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

# summarize history for loss
plt.figure(figsize=(12, 4))
plt.plot(history_2.history['loss'])
plt.plot(history_2.history['val_loss'])
plt.title('model 2 loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

[Figure: training vs. validation accuracy over epochs]

[Figure: training vs. validation loss over epochs]
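
Before moving on, let’s save the trained model to Drive so we can reload it later. The path below is the one the prediction code expects; adjust it to your own setup:

from google.colab import drive
drive.mount('/content/drive')

# Save architecture + weights in the native Keras format
model.save('/content/drive/MyDrive/my_cifar10_model.keras')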

Prediction

Since we saved the model, let’s try loading it back and predicting on a new image.

from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
from google.colab import drive

drive.mount('/content/drive')

# Load the saved model
loaded_model = tf.keras.models.load_model('/content/drive/MyDrive/my_cifar10_model.keras')

# Load and preprocess the single image
image_path = '/content/drive/MyDrive/bird_sample1.jpg'  
img = image.load_img(image_path, target_size=(32, 32))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array = img_array / 255.0

prediction = loaded_model.predict(img_array)
predicted_class = np.argmax(prediction, axis=1)[0]

# CIFAR-10 class labels (in order)
class_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Get the predicted class label
predicted_label = class_labels[predicted_class]

# Visualize the image and prediction
plt.imshow(image.load_img(image_path))
plt.title(f"Predicted: {predicted_label}")
plt.axis('off')  # Turn off axis labels
plt.show()

print(f"Predicted class index: {predicted_class}")
print(f"Predicted class label: {predicted_label}")
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step

[Figure: bird_sample1.jpg shown with the predicted label 'bird']

Predicted class index: 2
Predicted class label: bird

Now let’s download a couple of pictures from Google and test out the model.

import os

image_dir = '/content/drive/MyDrive/test_samples_cifar/'  

# Function to load, preprocess, and predict a single image
def predict_image(image_path):
    img = image.load_img(image_path, target_size=(32, 32))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = img_array / 255.0
    prediction = loaded_model.predict(img_array)
    predicted_class = np.argmax(prediction, axis=1)[0]
    return predicted_class

# Process all images in a folder
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]

plt.figure(figsize=(15, 5 * (len(image_files) // 3 + 1)))  # Adjust figure size

for i, image_file in enumerate(image_files):
    image_path = os.path.join(image_dir, image_file)
    predicted_class = predict_image(image_path)
    predicted_label = class_labels[predicted_class]

    plt.subplot(len(image_files) // 3 + 1, 3, i + 1)
    plt.imshow(image.load_img(image_path))
    plt.title(f"Predicted: {predicted_label}")
    plt.axis('off')

plt.tight_layout()  # Improve spacing
plt.show()

# Print predictions for each image
for image_file in image_files:
    image_path = os.path.join(image_dir, image_file)
    predicted_class = predict_image(image_path)
    predicted_label = class_labels[predicted_class]
    print(f"Image: {image_file}, Predicted: {predicted_label} (Class Index: {predicted_class})")

[Figure: grid of downloaded test images with their predicted labels]

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step
Image: bird_sample1.jpg, Predicted: bird (Class Index: 2)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step
Image: bird_sample2.jpg, Predicted: truck (Class Index: 9)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step
Image: airplane_sample2.jpg, Predicted: airplane (Class Index: 0)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step
Image: bird_sample3.webp, Predicted: bird (Class Index: 2)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step
Image: airplane_sample1.webp, Predicted: airplane (Class Index: 0)

Oops. It classifies the second picture as a truck… Poor parrots. But you get the idea.
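
When a prediction goes wrong like this, it helps to look past the argmax at the full probability distribution, to see whether the model was confidently wrong or merely unsure. A quick sketch for the misclassified bird, reusing the helpers defined above:

# Top-3 class probabilities for the misclassified image
img = image.load_img(os.path.join(image_dir, 'bird_sample2.jpg'), target_size=(32, 32))
probs = loaded_model.predict(np.expand_dims(image.img_to_array(img) / 255.0, axis=0))[0]

for idx in np.argsort(probs)[::-1][:3]:
    print(f"{class_labels[idx]}: {probs[idx]:.2%}")

A 32x32 crop throws away a lot of detail, so some confusion is expected. We’ll take another crack at improving the model next time.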