Sign Language Prediction with MobileNet - Code
• 16 min read
# Mount Google Drive so the Sign-Language-Digits dataset stored there is
# reachable under /content/gdrive (prompts for an OAuth code in Colab).
from google.colab import drive
drive.mount('/content/gdrive')
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code Enter your authorization code: ·········· Mounted at /content/gdrive
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.applications import imagenet_utils
from sklearn.metrics import confusion_matrix
import itertools
import os
import shutil
import random
import matplotlib.pyplot as plt
%matplotlib inline
See the associated blog post for a step-by-step explanation of the dataset preparation and fine-tuning code below.
# One-time dataset reorganisation: split the ten class folders (0-9) into the
# train/valid/test directory trees that flow_from_directory expects.
os.chdir('/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset')
# Idempotence guard: `not os.path.isdir(...)` replaces the non-idiomatic
# `... is False` identity comparison. Skips everything if already split.
if not os.path.isdir('train/0/'):
    os.mkdir('train')
    os.mkdir('valid')
    os.mkdir('test')
    for i in range(10):
        # Move the whole class folder into train/, then carve random
        # validation and test subsets out of it.
        shutil.move(f'{i}', 'train')
        os.mkdir(f'valid/{i}')
        os.mkdir(f'test/{i}')
        valid_samples = random.sample(os.listdir(f'train/{i}'), 30)
        for fname in valid_samples:
            shutil.move(f'train/{i}/{fname}', f'valid/{i}')
        # Sampled AFTER the validation move, so the two subsets are disjoint.
        test_samples = random.sample(os.listdir(f'train/{i}'), 5)
        for fname in test_samples:
            shutil.move(f'train/{i}/{fname}', f'test/{i}')
os.chdir('../..')
# Sanity check: after the split, every digit class must hold exactly
# 30 validation images and 5 test images.
for digit in range(0, 10):
    assert len(os.listdir(f'/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset/valid/{digit}')) == 30
    assert len(os.listdir(f'/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset/test/{digit}')) == 5
# Directory roots for the three splits created above.
train_path = '/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset/train'
valid_path = '/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset/valid'
test_path = '/content/gdrive/My Drive/Sign-Language-Digits-Dataset/Dataset/test'

def _mobilenet_batches(directory, shuffle=True):
    """Build a batch iterator over *directory*, resized to MobileNet's
    224x224 input and run through MobileNet's own preprocessing function."""
    generator = ImageDataGenerator(
        preprocessing_function=tf.keras.applications.mobilenet.preprocess_input)
    return generator.flow_from_directory(
        directory=directory, target_size=(224, 224),
        batch_size=10, shuffle=shuffle)

train_batches = _mobilenet_batches(train_path)
valid_batches = _mobilenet_batches(valid_path)
# Not shuffled, so predictions stay aligned with test_batches.classes.
test_batches = _mobilenet_batches(test_path, shuffle=False)
Found 1712 images belonging to 10 classes. Found 300 images belonging to 10 classes. Found 50 images belonging to 10 classes.
# Load the full pretrained MobileNet; weights are downloaded on first call.
mobile = tf.keras.applications.mobilenet.MobileNet()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf.h5 17227776/17225924 [==============================] - 0s 0us/step
# Inspect the pretrained architecture to choose the cut point for transfer learning.
mobile.summary()
Model: "mobilenet_1.00_224" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 225, 225, 3) 0 _________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 32) 864 _________________________________________________________________ conv1_bn (BatchNormalization (None, 112, 112, 32) 128 _________________________________________________________________ conv1_relu (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ conv_dw_1 (DepthwiseConv2D) (None, 112, 112, 32) 288 _________________________________________________________________ conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32) 128 _________________________________________________________________ conv_dw_1_relu (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ conv_pw_1 (Conv2D) (None, 112, 112, 64) 2048 _________________________________________________________________ conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64) 256 _________________________________________________________________ conv_pw_1_relu (ReLU) (None, 112, 112, 64) 0 _________________________________________________________________ conv_pad_2 (ZeroPadding2D) (None, 113, 113, 64) 0 _________________________________________________________________ conv_dw_2 (DepthwiseConv2D) (None, 56, 56, 64) 576 _________________________________________________________________ conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64) 256 _________________________________________________________________ conv_dw_2_relu (ReLU) (None, 56, 56, 64) 0 _________________________________________________________________ conv_pw_2 (Conv2D) (None, 56, 56, 128) 8192 
_________________________________________________________________ conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_2_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_dw_3 (DepthwiseConv2D) (None, 56, 56, 128) 1152 _________________________________________________________________ conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_dw_3_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pw_3 (Conv2D) (None, 56, 56, 128) 16384 _________________________________________________________________ conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_3_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pad_4 (ZeroPadding2D) (None, 57, 57, 128) 0 _________________________________________________________________ conv_dw_4 (DepthwiseConv2D) (None, 28, 28, 128) 1152 _________________________________________________________________ conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128) 512 _________________________________________________________________ conv_dw_4_relu (ReLU) (None, 28, 28, 128) 0 _________________________________________________________________ conv_pw_4 (Conv2D) (None, 28, 28, 256) 32768 _________________________________________________________________ conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_4_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_dw_5 (DepthwiseConv2D) (None, 28, 28, 256) 2304 _________________________________________________________________ conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 
_________________________________________________________________ conv_dw_5_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pw_5 (Conv2D) (None, 28, 28, 256) 65536 _________________________________________________________________ conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_5_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pad_6 (ZeroPadding2D) (None, 29, 29, 256) 0 _________________________________________________________________ conv_dw_6 (DepthwiseConv2D) (None, 14, 14, 256) 2304 _________________________________________________________________ conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256) 1024 _________________________________________________________________ conv_dw_6_relu (ReLU) (None, 14, 14, 256) 0 _________________________________________________________________ conv_pw_6 (Conv2D) (None, 14, 14, 512) 131072 _________________________________________________________________ conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_6_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_7 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_7_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_7 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_7_relu (ReLU) (None, 14, 14, 512) 0 
_________________________________________________________________ conv_dw_8 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_8_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_8 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_8_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_9 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_9_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_9 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_9_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_10_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_10 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 
_________________________________________________________________ conv_pw_10_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_11_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_11 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_11_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_12 (ZeroPadding2D) (None, 15, 15, 512) 0 _________________________________________________________________ conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512) 4608 _________________________________________________________________ conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512) 2048 _________________________________________________________________ conv_dw_12_relu (ReLU) (None, 7, 7, 512) 0 _________________________________________________________________ conv_pw_12 (Conv2D) (None, 7, 7, 1024) 524288 _________________________________________________________________ conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_12_relu (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024) 9216 _________________________________________________________________ conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_dw_13_relu (ReLU) (None, 7, 7, 1024) 0 
_________________________________________________________________ conv_pw_13 (Conv2D) (None, 7, 7, 1024) 1048576 _________________________________________________________________ conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_13_relu (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 1024) 0 _________________________________________________________________ reshape_1 (Reshape) (None, 1, 1, 1024) 0 _________________________________________________________________ dropout (Dropout) (None, 1, 1, 1024) 0 _________________________________________________________________ conv_preds (Conv2D) (None, 1, 1, 1000) 1025000 _________________________________________________________________ reshape_2 (Reshape) (None, 1000) 0 _________________________________________________________________ predictions (Activation) (None, 1000) 0 ================================================================= Total params: 4,253,864 Trainable params: 4,231,976 Non-trainable params: 21,888 _________________________________________________________________
# Transfer learning: reuse MobileNet up to its sixth-from-last layer
# (the 1024-d globally pooled features, per the summary above) and attach
# a fresh 10-way softmax head for the sign-language digit classes.
features = mobile.layers[-6].output
predictions = Dense(10, activation='softmax')(features)
model = Model(inputs=mobile.input, outputs=predictions)

# Freeze all but the last 23 layers, so only the deepest blocks and the
# new head are updated during fine-tuning.
for layer in model.layers[:-23]:
    layer.trainable = False

model.summary()
Model: "functional_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 225, 225, 3) 0 _________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 32) 864 _________________________________________________________________ conv1_bn (BatchNormalization (None, 112, 112, 32) 128 _________________________________________________________________ conv1_relu (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ conv_dw_1 (DepthwiseConv2D) (None, 112, 112, 32) 288 _________________________________________________________________ conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32) 128 _________________________________________________________________ conv_dw_1_relu (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ conv_pw_1 (Conv2D) (None, 112, 112, 64) 2048 _________________________________________________________________ conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64) 256 _________________________________________________________________ conv_pw_1_relu (ReLU) (None, 112, 112, 64) 0 _________________________________________________________________ conv_pad_2 (ZeroPadding2D) (None, 113, 113, 64) 0 _________________________________________________________________ conv_dw_2 (DepthwiseConv2D) (None, 56, 56, 64) 576 _________________________________________________________________ conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64) 256 _________________________________________________________________ conv_dw_2_relu (ReLU) (None, 56, 56, 64) 0 _________________________________________________________________ conv_pw_2 (Conv2D) (None, 56, 56, 128) 8192 
_________________________________________________________________ conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_2_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_dw_3 (DepthwiseConv2D) (None, 56, 56, 128) 1152 _________________________________________________________________ conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_dw_3_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pw_3 (Conv2D) (None, 56, 56, 128) 16384 _________________________________________________________________ conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_3_relu (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pad_4 (ZeroPadding2D) (None, 57, 57, 128) 0 _________________________________________________________________ conv_dw_4 (DepthwiseConv2D) (None, 28, 28, 128) 1152 _________________________________________________________________ conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128) 512 _________________________________________________________________ conv_dw_4_relu (ReLU) (None, 28, 28, 128) 0 _________________________________________________________________ conv_pw_4 (Conv2D) (None, 28, 28, 256) 32768 _________________________________________________________________ conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_4_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_dw_5 (DepthwiseConv2D) (None, 28, 28, 256) 2304 _________________________________________________________________ conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 
_________________________________________________________________ conv_dw_5_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pw_5 (Conv2D) (None, 28, 28, 256) 65536 _________________________________________________________________ conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_5_relu (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pad_6 (ZeroPadding2D) (None, 29, 29, 256) 0 _________________________________________________________________ conv_dw_6 (DepthwiseConv2D) (None, 14, 14, 256) 2304 _________________________________________________________________ conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256) 1024 _________________________________________________________________ conv_dw_6_relu (ReLU) (None, 14, 14, 256) 0 _________________________________________________________________ conv_pw_6 (Conv2D) (None, 14, 14, 512) 131072 _________________________________________________________________ conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_6_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_7 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_7_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_7 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_7_relu (ReLU) (None, 14, 14, 512) 0 
_________________________________________________________________ conv_dw_8 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_8_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_8 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_8_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_9 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_9_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_9 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_9_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_10_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_10 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 
_________________________________________________________________ conv_pw_10_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_11_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_11 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_11_relu (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_12 (ZeroPadding2D) (None, 15, 15, 512) 0 _________________________________________________________________ conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512) 4608 _________________________________________________________________ conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512) 2048 _________________________________________________________________ conv_dw_12_relu (ReLU) (None, 7, 7, 512) 0 _________________________________________________________________ conv_pw_12 (Conv2D) (None, 7, 7, 1024) 524288 _________________________________________________________________ conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_12_relu (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024) 9216 _________________________________________________________________ conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_dw_13_relu (ReLU) (None, 7, 7, 1024) 0 
_________________________________________________________________ conv_pw_13 (Conv2D) (None, 7, 7, 1024) 1048576 _________________________________________________________________ conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_13_relu (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 1024) 0 _________________________________________________________________ dense (Dense) (None, 10) 10250 ================================================================= Total params: 3,239,114 Trainable params: 1,873,930 Non-trainable params: 1,365,184 _________________________________________________________________
# Adam with a small rate suited to fine-tuning pretrained weights.
# `learning_rate` replaces the deprecated `lr` alias (removed in Keras 3).
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
Training
# Fine-tune for 30 epochs on the generator batches.
# NOTE(review): steps_per_epoch=18 at batch_size 10 shows only ~180 of the
# 1712 training images per epoch (and validation_steps=3 only 30 of 300) —
# presumably a deliberate speed trade-off; confirm before reuse.
model.fit(x=train_batches, steps_per_epoch=18, validation_data=valid_batches, validation_steps=3, epochs=30, verbose=2)
Epoch 1/30 18/18 - 31s - loss: 0.3288 - accuracy: 0.9302 - val_loss: 0.8580 - val_accuracy: 0.7000 Epoch 2/30 18/18 - 29s - loss: 0.2942 - accuracy: 0.9333 - val_loss: 0.6755 - val_accuracy: 0.7000 Epoch 3/30 18/18 - 21s - loss: 0.2280 - accuracy: 0.9477 - val_loss: 0.4676 - val_accuracy: 0.8000 Epoch 4/30 18/18 - 21s - loss: 0.2050 - accuracy: 0.9722 - val_loss: 0.4056 - val_accuracy: 0.7667 Epoch 5/30 18/18 - 19s - loss: 0.2673 - accuracy: 0.9611 - val_loss: 0.3198 - val_accuracy: 0.9667 Epoch 6/30 18/18 - 17s - loss: 0.1734 - accuracy: 0.9667 - val_loss: 0.0908 - val_accuracy: 1.0000 Epoch 7/30 18/18 - 13s - loss: 0.1649 - accuracy: 0.9500 - val_loss: 0.2169 - val_accuracy: 0.9333 Epoch 8/30 18/18 - 13s - loss: 0.1673 - accuracy: 0.9722 - val_loss: 0.1466 - val_accuracy: 0.9667 Epoch 9/30 18/18 - 9s - loss: 0.1392 - accuracy: 0.9667 - val_loss: 0.1384 - val_accuracy: 0.9333 Epoch 10/30 18/18 - 10s - loss: 0.1067 - accuracy: 0.9833 - val_loss: 0.2402 - val_accuracy: 0.9000 Epoch 11/30 18/18 - 11s - loss: 0.0944 - accuracy: 0.9889 - val_loss: 0.1494 - val_accuracy: 0.9667 Epoch 12/30 18/18 - 8s - loss: 0.0845 - accuracy: 1.0000 - val_loss: 0.0851 - val_accuracy: 1.0000 Epoch 13/30 18/18 - 8s - loss: 0.1113 - accuracy: 0.9778 - val_loss: 0.1685 - val_accuracy: 0.9333 Epoch 14/30 18/18 - 7s - loss: 0.0861 - accuracy: 0.9944 - val_loss: 0.0873 - val_accuracy: 1.0000 Epoch 15/30 18/18 - 7s - loss: 0.0628 - accuracy: 0.9942 - val_loss: 0.1392 - val_accuracy: 0.9667 Epoch 16/30 18/18 - 5s - loss: 0.0990 - accuracy: 0.9826 - val_loss: 0.0796 - val_accuracy: 1.0000 Epoch 17/30 18/18 - 4s - loss: 0.0914 - accuracy: 0.9778 - val_loss: 0.0590 - val_accuracy: 1.0000 Epoch 18/30 18/18 - 6s - loss: 0.0534 - accuracy: 0.9944 - val_loss: 0.1309 - val_accuracy: 0.9667 Epoch 19/30 18/18 - 4s - loss: 0.0457 - accuracy: 1.0000 - val_loss: 0.0316 - val_accuracy: 1.0000 Epoch 20/30 18/18 - 4s - loss: 0.0521 - accuracy: 0.9944 - val_loss: 0.0472 - val_accuracy: 1.0000 Epoch 21/30 18/18 
- 3s - loss: 0.0863 - accuracy: 0.9778 - val_loss: 0.0154 - val_accuracy: 1.0000 Epoch 22/30 18/18 - 4s - loss: 0.0452 - accuracy: 1.0000 - val_loss: 0.1097 - val_accuracy: 0.9667 Epoch 23/30 18/18 - 4s - loss: 0.0508 - accuracy: 0.9944 - val_loss: 0.0392 - val_accuracy: 1.0000 Epoch 24/30 18/18 - 7s - loss: 0.0607 - accuracy: 0.9944 - val_loss: 0.0256 - val_accuracy: 1.0000 Epoch 25/30 18/18 - 1s - loss: 0.0449 - accuracy: 0.9944 - val_loss: 0.0444 - val_accuracy: 1.0000 Epoch 26/30 18/18 - 2s - loss: 0.0510 - accuracy: 0.9944 - val_loss: 0.0346 - val_accuracy: 1.0000 Epoch 27/30 18/18 - 1s - loss: 0.0329 - accuracy: 1.0000 - val_loss: 0.0564 - val_accuracy: 1.0000 Epoch 28/30 18/18 - 2s - loss: 0.0312 - accuracy: 1.0000 - val_loss: 0.0276 - val_accuracy: 1.0000 Epoch 29/30 18/18 - 2s - loss: 0.0427 - accuracy: 0.9944 - val_loss: 0.0664 - val_accuracy: 0.9667 Epoch 30/30 18/18 - 1s - loss: 0.0344 - accuracy: 0.9889 - val_loss: 0.0609 - val_accuracy: 1.0000
<tensorflow.python.keras.callbacks.History at 0x7fd299fcbb00>
# Ground-truth class indices in file order — valid only because test_batches
# was built with shuffle=False, so prediction rows align with these labels.
test_labels = test_batches.classes
# 50 test images at batch_size 10 -> steps=5 covers the entire test set.
predictions = model.predict(x=test_batches, steps=5, verbose=0)
print(predictions)
[[9.99907970e-01 1.81986852e-05 1.14185386e-05 8.31906982e-06 6.76295031e-06 5.26625763e-06 1.50188625e-05 1.29211230e-06 2.20869988e-05 3.54457507e-06] [9.99764740e-01 4.97330475e-05 1.24770275e-04 1.25518391e-05 2.54034444e-06 1.78213008e-06 2.27403270e-05 5.41358190e-07 1.18345215e-05 8.72721012e-06] [9.99943495e-01 1.02443319e-05 4.07669768e-06 2.59417743e-06 2.22588847e-06 1.97174950e-06 4.29789634e-06 2.17091639e-07 3.33743219e-06 2.75765979e-05] [9.97976959e-01 1.30962511e-03 4.67506397e-05 4.97402689e-05 1.01239813e-04 3.53960495e-05 3.83031263e-04 4.16261264e-06 1.39273197e-05 7.91871498e-05] [9.99794781e-01 6.55109470e-05 5.60145490e-05 9.19256217e-06 6.74312616e-07 2.55389045e-06 1.12290845e-05 3.24747077e-07 6.57366490e-06 5.31644619e-05] [5.96624841e-06 9.98959184e-01 9.82636935e-04 4.61979107e-06 1.41901154e-07 3.83149796e-08 2.53536768e-06 1.55028567e-06 3.95551178e-05 3.84179430e-06] [5.77348619e-05 9.97837484e-01 1.93230971e-03 5.29086901e-05 5.60771980e-07 1.32886186e-07 1.02190063e-06 7.83198743e-07 9.52457194e-05 2.19350532e-05] [3.19325909e-06 9.96406376e-01 3.52200959e-03 2.73286496e-05 1.38922232e-07 8.74872441e-08 1.08019390e-06 7.19391437e-07 3.41323175e-05 4.93078778e-06] [8.91397633e-07 9.99843597e-01 6.41833540e-05 9.20376624e-06 5.28325529e-07 2.35107791e-07 4.17812544e-06 7.52144786e-07 6.69909277e-05 9.46024920e-06] [2.11145539e-06 9.99277651e-01 6.38888392e-04 1.61246521e-06 8.37773797e-08 1.33889078e-08 1.29828607e-07 2.36027589e-07 6.06195317e-05 1.86159541e-05] [1.40882321e-06 3.04088953e-05 9.99190271e-01 7.30123953e-04 3.74597857e-06 7.04620806e-09 1.89693892e-05 5.52513166e-06 3.02043100e-06 1.65276870e-05] [4.09996355e-05 1.62411539e-03 9.77828205e-01 8.13836232e-04 3.26748923e-05 2.13361272e-06 1.89798046e-02 5.88272000e-04 2.49287514e-05 6.49938665e-05] [8.36342861e-06 3.36779805e-04 9.93821859e-01 6.53320982e-04 2.57492866e-05 3.32229831e-07 5.12548909e-03 1.22495271e-06 9.05711147e-07 2.60038087e-05] [7.99389818e-06 
2.90314310e-05 9.94603097e-01 1.44842779e-03 7.65812001e-05 4.65628233e-07 3.05270427e-03 5.67832671e-04 7.49530227e-05 1.38985488e-04] [3.46522438e-06 4.86433157e-04 9.99253333e-01 1.86141515e-05 8.96118217e-06 4.99642780e-08 2.13554318e-04 4.10330898e-07 1.25275583e-05 2.64599180e-06] [5.68577816e-05 2.32892362e-05 1.61046904e-04 9.97760653e-01 1.23008670e-06 1.79811893e-03 1.02986924e-05 1.30777389e-05 1.74625966e-04 8.17895113e-07] [3.42618478e-05 9.89843393e-05 1.23947440e-02 9.87302780e-01 4.45760946e-07 1.29584689e-04 1.45311124e-05 1.34755164e-05 3.32886202e-06 7.85242901e-06] [2.61722424e-04 2.25970231e-04 5.14747517e-04 9.92549181e-01 9.46297405e-06 5.28588658e-03 5.42270891e-06 1.30143080e-05 1.05099462e-03 8.35767059e-05] [1.31443335e-06 5.72117733e-06 5.44280512e-04 9.99418259e-01 9.58634132e-08 2.24492705e-05 6.64082506e-07 4.97399469e-07 5.72592035e-06 1.02367096e-06] [7.33219276e-05 3.56858254e-05 2.46638752e-04 7.77239680e-01 2.80055178e-06 2.20447138e-01 1.29028433e-06 1.87667436e-04 5.07000514e-05 1.71502493e-03] [8.84616838e-06 1.72287150e-06 5.27130305e-06 6.36111281e-07 9.97579277e-01 1.48030219e-03 8.55612045e-04 6.84547058e-06 2.29201742e-05 3.85024141e-05] [2.87977709e-05 5.12899396e-06 2.86878530e-05 8.23870141e-06 9.94071543e-01 2.25713290e-03 3.38535896e-03 8.99201550e-05 4.08533342e-05 8.42990703e-05] [6.81434831e-05 1.86656780e-06 5.61141642e-04 5.19602145e-05 9.81923401e-01 3.11230286e-03 1.35045694e-02 2.41562775e-05 2.61299167e-04 4.91183600e-04] [6.54945097e-06 3.00924683e-07 2.14229894e-05 2.13994663e-05 9.87928689e-01 9.72454902e-04 1.10056009e-02 2.72963043e-05 6.87222837e-06 9.38202265e-06] [5.10706229e-07 1.95524965e-08 1.11063594e-06 1.30337980e-07 9.99499679e-01 2.78539665e-04 2.16483604e-04 3.37662101e-07 1.16396768e-06 1.94427730e-06] [1.10659057e-05 1.02498973e-07 2.80285690e-07 2.31634567e-06 1.32311597e-01 8.67163122e-01 4.84139746e-05 4.81770039e-05 8.51955556e-05 3.29826580e-04] [3.36566040e-06 2.71774951e-07 
1.56815133e-07 1.89455866e-04 1.66468462e-03 9.98027265e-01 4.27184023e-06 1.02549639e-05 1.09216107e-05 8.93709075e-05] [7.34167543e-06 1.22745541e-06 1.68236136e-06 6.73042436e-04 7.15185364e-04 9.98450279e-01 2.56239873e-05 1.11202444e-05 7.43143682e-05 4.00367535e-05] [3.26074020e-04 1.73902168e-04 3.35095356e-05 2.35643201e-02 4.65085953e-02 9.28241014e-01 1.87676211e-04 1.17491370e-04 6.12921838e-04 2.34515188e-04] [1.53196670e-05 2.26287568e-07 6.14817793e-07 3.94643546e-04 7.02123623e-04 9.98710632e-01 1.38914984e-05 2.12150462e-06 1.51932225e-04 8.39336371e-06] [1.36404287e-03 1.38173311e-03 1.00178637e-01 7.36371102e-03 3.09866350e-02 3.71698756e-03 7.34847426e-01 7.62851462e-02 3.38587798e-02 1.00168865e-02] [1.55438215e-06 7.56569671e-06 1.93705852e-03 1.57460931e-06 1.83093990e-03 1.72299588e-06 9.96193767e-01 2.21732480e-05 1.18845605e-06 2.45226647e-06] [1.73982175e-03 2.90282245e-04 2.19090790e-01 3.85022792e-03 9.29833874e-02 3.65315960e-03 6.71241343e-01 5.50376344e-03 9.16578050e-04 7.30645843e-04] [1.66806840e-05 4.53790699e-06 5.32073714e-03 9.55705673e-06 1.46365270e-03 6.23514643e-05 9.92924690e-01 1.48642837e-04 4.59050279e-05 3.24408575e-06] [4.00441844e-04 8.47149786e-05 8.38202331e-03 3.48406640e-04 3.25773731e-02 9.81657766e-04 9.54741716e-01 1.99468597e-03 7.89099722e-05 4.10044391e-04] [1.42006917e-04 6.20908104e-04 7.04250718e-03 5.62605681e-03 1.17695192e-03 2.23223586e-04 4.37225390e-04 5.33069670e-01 4.49340671e-01 2.32081162e-03] [1.46449963e-06 2.21280516e-06 1.03685644e-03 6.59345533e-05 9.58273071e-04 1.10456494e-05 2.16003522e-04 9.95881200e-01 1.73346093e-03 9.36524666e-05] [1.57869945e-04 4.90538543e-04 7.90488794e-02 1.33991949e-02 8.82933266e-04 4.67371312e-04 2.89049931e-03 9.00509000e-01 1.39442889e-03 7.59232207e-04] [2.76261676e-06 3.58992293e-06 4.14676004e-04 4.83516169e-06 6.32872479e-06 1.87327737e-06 6.94003984e-06 9.98579502e-01 9.59948462e-04 1.96528272e-05] [1.15150295e-03 4.81685682e-04 2.58751353e-03 
8.08459241e-04 5.00243269e-02 8.77166225e-04 2.17872718e-03 3.35247427e-01 6.00740671e-01 5.90256089e-03] [3.27213456e-06 1.54394394e-04 7.79893016e-05 4.01407488e-05 2.11231927e-05 5.19559671e-06 4.20569876e-07 1.38790475e-03 9.97655511e-01 6.54020463e-04] [2.57532171e-04 8.40583307e-05 1.26030098e-03 5.38015622e-04 1.50475139e-03 1.51385684e-04 1.82284115e-04 8.22080001e-02 9.12585676e-01 1.22801040e-03] [8.28076227e-05 1.24441576e-04 2.39340728e-03 1.02791074e-03 1.78400415e-03 1.56045702e-04 6.67873828e-05 6.77036028e-03 9.84457195e-01 3.13709350e-03] [3.07454411e-06 6.68912617e-06 1.13914311e-05 9.11148436e-06 1.44990317e-05 7.82387178e-06 3.69191838e-07 3.12103220e-04 9.99610364e-01 2.45070169e-05] [1.30470420e-04 1.23008965e-02 2.51298794e-03 1.70998927e-03 2.84471782e-04 4.16153198e-05 7.09045635e-05 1.25928685e-01 8.49215686e-01 7.80425733e-03] [4.43604549e-05 3.70138223e-05 4.58232216e-06 4.78399634e-05 7.25927763e-04 4.10610024e-04 4.00512499e-06 4.73862128e-05 8.23822047e-04 9.97854412e-01] [8.90586991e-04 2.65139650e-04 4.86360630e-03 4.50516725e-03 2.76033883e-03 4.42838715e-03 2.65056326e-04 1.55687588e-03 3.97594035e-01 5.82870781e-01] [3.13964701e-04 2.22514194e-04 2.69785814e-04 3.42328494e-05 9.21714026e-03 8.76264356e-04 5.47206611e-04 1.68818122e-04 4.44504339e-03 9.83905017e-01] [7.43059732e-04 1.87608050e-04 3.47952708e-04 8.38607200e-04 1.96068548e-03 2.17008434e-04 1.08727392e-04 8.90582232e-05 1.37780723e-03 9.94129539e-01] [1.50213436e-05 2.07497487e-05 2.58430646e-05 5.99290024e-06 4.50582767e-04 1.80798364e-04 1.50254700e-05 2.42412762e-06 1.96153924e-04 9.99087453e-01]]
# Confusion matrix for the held-out test set: rows are true digit classes,
# columns are predicted classes (argmax over the model's 10 softmax outputs).
# NOTE(review): `test_labels` and `predictions` are defined in earlier cells.
cm = confusion_matrix(y_true=test_labels, y_pred=predictions.argmax(axis=1))
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.

    Parameters
    ----------
    cm : ndarray of shape (n_classes, n_classes)
        Confusion matrix as produced by ``sklearn.metrics.confusion_matrix``.
    classes : sequence of str
        Tick labels for both axes, in class-index order.
    normalize : bool, default False
        If True, each row is divided by its sum so cells show per-class
        recall fractions instead of raw counts.
    title : str
        Plot title.
    cmap : matplotlib colormap
        Colormap for the heatmap.
    """
    # BUG FIX: normalize BEFORE plotting. The original called plt.imshow(cm)
    # first and normalized afterwards, so with normalize=True the heatmap
    # colors (and cell annotations' threshold) were computed from raw counts
    # while only the printed matrix was normalized.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell; flip text color at half the max value for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# Axis labels for the ten sign-language digit classes, then render the matrix.
cm_plot_labels = [str(digit) for digit in range(10)]
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')
Confusion matrix, without normalization [[5 0 0 0 0 0 0 0 0 0] [0 5 0 0 0 0 0 0 0 0] [0 0 5 0 0 0 0 0 0 0] [0 0 0 5 0 0 0 0 0 0] [0 0 0 0 5 0 0 0 0 0] [0 0 0 0 0 5 0 0 0 0] [0 0 0 0 0 0 5 0 0 0] [0 0 0 0 0 0 0 4 1 0] [0 0 0 0 0 0 0 0 5 0] [0 0 0 0 0 0 0 0 0 5]]