In this example, we will apply convolutional neural networks (CNNs) to MNIST, a dataset of images of handwritten digits. It is one of the best-known datasets in machine learning.
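# Seeding NumPy's random number generator before loading Keras makes runs more repeatable.
# It does not make them fully deterministic: the backend (e.g. TensorFlow) has its own
# random number generator, so results can still vary slightly between runs.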
import numpy as np
np.random.seed(0)
import keras
from keras.datasets import mnist
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.models import Sequential
import matplotlib.pyplot as plt
# Set the colormap used to display single-channel (grayscale) images.
plt.rcParams['image.cmap'] = 'Blues'
Keras includes a function to load the MNIST data. It returns two pairs: the images and labels of the training set, and the images and labels of the test set. The training set contains 60,000 images and the test set contains 10,000. All of them are 28x28 grayscale images.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
img_width, img_height = x_train[0].shape
img_width, img_height
Before we can apply the CNN, we need to carry out a few steps of preprocessing.
First, we need to reshape the data into the form the CNN expects. Currently, each image is represented as a 28x28 matrix. The `Conv2D` layer in Keras expects each input to be a three-dimensional tensor (height x width x channels), so even a grayscale image needs an explicit channel dimension, here of size 1. (A color image would have 3 channels: red, green, and blue.)
This means that the whole training set and test set become four-dimensional: number of images x 28 x 28 x 1.
We use NumPy's `reshape` operation to carry out this transformation.
x_train = x_train.reshape(x_train.shape[0], img_width, img_height, 1)
x_test = x_test.reshape(x_test.shape[0], img_width, img_height, 1)
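Here is a minimal sketch of the same reshape on a small synthetic array. Random pixel values stand in for MNIST, so nothing needs to be downloaded. It also shows `np.expand_dims`, an equivalent way to append the size-1 channel axis without spelling out the other sizes.

```python
import numpy as np

# Synthetic stand-in for MNIST (assumption: random data): 5 images of 28x28 uint8 pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Append a channel axis of size 1: (N, 28, 28) -> (N, 28, 28, 1).
with_channel = images.reshape(images.shape[0], 28, 28, 1)

# np.expand_dims adds the same axis without restating the image size.
also_with_channel = np.expand_dims(images, axis=-1)

print(with_channel.shape)  # (5, 28, 28, 1)
assert np.array_equal(with_channel, also_with_channel)
```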
The second transformation rescales the pixel values. Instead of ranging from 0 to 255, they will now lie between 0 and 1.
The reason is that a CNN typically learns more easily when the feature values are not too large. (Alternatively, we could have initialized the CNN weights to smaller values.) If you don't rescale, the CNN will usually still learn, just more slowly.
x_train = x_train / 255
x_test = x_test / 255
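Dividing the `uint8` pixel arrays by 255 also changes their data type. The sketch below runs on a few hand-picked toy pixel values rather than MNIST. It shows that plain division gives `float64`, while casting to `float32` first uses half the memory and matches Keras's default float type.

```python
import numpy as np

# Toy uint8 pixels covering the full 0..255 range (assumption: not real MNIST data).
pixels = np.array([[0, 64, 128, 255]], dtype=np.uint8)

# Dividing uint8 by an integer promotes the result to float64 in [0, 1].
scaled = pixels / 255
print(scaled.dtype)  # float64

# Casting to float32 first halves the memory and matches Keras's default float type.
scaled32 = pixels.astype("float32") / 255
print(scaled32.dtype)  # float32
print(scaled32.min(), scaled32.max())  # 0.0 1.0
```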
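To illustrate, let's use `plt.imshow` to visualize an image. (For a single-channel image like this one, `imshow` maps the values through the colormap after scaling them to the data's own minimum and maximum, so the rescaling above does not change the picture. It would matter for color images stored as floats, which `imshow` expects to lie between 0 and 1.)
The notation `x_train[12,:,:,0]` means: the image at index 12 in the training set (the 13th image, since indexing starts at 0), all rows and columns, and the first (and only) channel.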
plt.imshow(x_train[12,:,:,0]);
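The same indexing can be tried on a toy 4-D array shaped like the reshaped training set. The sketch below uses consecutive integers in place of pixels. It checks that `x[12, :, :, 0]` returns a plain 28x28 image, and that two other spellings select the same slice.

```python
import numpy as np

# Toy array shaped like the reshaped data: (N, height, width, channels) = (20, 28, 28, 1).
x = np.arange(20 * 28 * 28).reshape(20, 28, 28, 1)

# Index 12 is the 13th image, because NumPy indexing starts at 0.
img = x[12, :, :, 0]
print(img.shape)  # (28, 28)

# The ellipsis form selects the same 2-D slice.
assert np.array_equal(img, x[12, ..., 0])
# Equivalently: take image 12 and drop its size-1 channel axis.
assert np.array_equal(img, x[12].squeeze(axis=-1))
```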
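As a final preprocessing step, we convert the output labels (digits) into "one-hot" vectors. For instance, the digit 2 becomes [0,0,1,0,...,0], a vector of length 10 with a 1 at index 2.
This encoding is what the categorical cross-entropy loss expects. That loss compares the softmax output (one probability per class) with a target vector of the same length. (Keras also offers the `sparse_categorical_crossentropy` loss, which accepts the integer labels directly.)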
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# For instance, here is how the first instance in the training set is encoded.
y_train[0]
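For integer labels, the 0/1 pattern produced by `to_categorical` can be reproduced with plain NumPy by indexing rows of an identity matrix. The helper `one_hot` below is a hypothetical name used only for this sketch; it is not part of Keras. Unlike `to_categorical`, which infers the number of classes from the largest label, it requires `num_classes` to be passed explicitly.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float array with a single 1 per row."""
    # Row k of the identity matrix is the one-hot vector for class k,
    # so indexing the identity with the labels builds all rows at once.
    return np.eye(num_classes, dtype="float32")[np.asarray(labels)]

encoded = one_hot([5, 0, 2], num_classes=10)
print(encoded[2])  # [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

# Going back: argmax recovers the integer label from each one-hot row.
print(encoded.argmax(axis=1))  # [5 0 2]
```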
Now we have everything we need to train the CNN. As with the feedforward NN, we start by creating a `Sequential` model, which means that our classifier is a linear stack of layers, each feeding its output to the next.
Then, we alternate convolutional and pooling layers, as is customary in CNNs. We use ReLU units in the hidden layers. Next, a `Flatten` layer converts the feature maps produced by the last pooling step into a single vector per image. The final part of the model looks like a normal feedforward neural network: a standard hidden layer with 128 units, followed by the softmax output layer.
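The layer sizes can be checked by hand. The code below relies on Keras's default `'valid'` padding, which it does not override, and on `MaxPooling2D` defaulting its strides to the pool size. Under those defaults, each 5x5 convolution shrinks each side by 4 and each 2x2 pooling halves it. This sketch traces one image through the model and counts the trainable parameters (weights plus biases). The total should agree with what `model.summary()` reports, although that output is not shown here.

```python
def conv_out(size, kernel, stride=1):
    # 'valid' padding (the Keras default): no zero-padding around the input.
    return (size - kernel) // stride + 1

def pool_out(size, pool):
    # Pooling with strides equal to the pool size and 'valid' padding.
    return (size - pool) // pool + 1

side = 28
side = conv_out(side, 5)   # Conv2D(32, 5x5)  -> 24x24x32
p_conv1 = 5 * 5 * 1 * 32 + 32
side = pool_out(side, 2)   # MaxPooling2D 2x2 -> 12x12x32
side = conv_out(side, 5)   # Conv2D(64, 5x5)  -> 8x8x64
p_conv2 = 5 * 5 * 32 * 64 + 64
side = pool_out(side, 2)   # MaxPooling2D 2x2 -> 4x4x64
flat = side * side * 64    # Flatten -> 1024 values per image
p_dense = flat * 128 + 128 # Dense(128)
p_out = 128 * 10 + 10      # Dense(10), softmax output

total = p_conv1 + p_conv2 + p_dense + p_out
print(flat, total)  # 1024 184586
```

Most of the parameters sit in the first dense layer. This is one reason the pooling steps matter: they shrink the flattened vector before it reaches the fully connected part.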
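We train the model using the Adam optimizer. To keep things simple, we train for just one epoch; for real-world problems, there would be many epochs, probably using early stopping to determine when to terminate training. Here, the test set is passed as `validation_data` only so that its accuracy is reported after the epoch; it plays no role in updating the weights.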
num_classes = 10
model = Sequential()
# First convolution: 32 filters of size 5x5 over the 28x28x1 input.
model.add(Conv2D(32, kernel_size=(5, 5), strides=(1, 1),
                 activation='relu',
                 input_shape=(img_width, img_height, 1)))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
# Second convolution: 64 filters of size 5x5.
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Flatten the feature maps, then apply the feedforward part.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# Train for a single epoch with minibatches of 256 images.
model.fit(x_train, y_train,
          batch_size=256,
          epochs=1,
          verbose=1,
          validation_data=(x_test, y_test));
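The model above trains for a fixed single epoch. To illustrate the early-stopping rule mentioned earlier, here is a pure-Python sketch. It uses a hypothetical helper `early_stopping_epoch` and made-up loss values, scans the per-epoch validation losses, and stops once `patience` epochs pass without improvement. In Keras itself, this behavior is provided by the `keras.callbacks.EarlyStopping` callback, which takes parameters such as `monitor`, `patience`, and `restore_best_weights` and is passed to `fit` via `callbacks=[...]`. Early stopping should monitor a validation set held out from the training data, not the test set, so that the test accuracy remains an unbiased estimate.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return the epoch whose weights early stopping would keep.

    Training stops once `patience` consecutive epochs pass without a new best
    validation loss; the best epoch seen up to that point is returned.
    """
    best_epoch, best_loss = 0, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= patience:
            break  # `patience` epochs without improvement: stop here
    return best_epoch

# Made-up validation losses, one per epoch (not results from this notebook).
losses = [0.30, 0.12, 0.09, 0.10, 0.11, 0.08]
print(early_stopping_epoch(losses, patience=2))  # 2
```

With `patience=2`, training stops after epoch 4 and never sees the lower loss at epoch 5. With `patience=3`, it would continue and keep epoch 5. A larger patience tolerates noisy validation curves at the cost of extra epochs.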
Please note that if you retrain the model, the result might be slightly different, because of randomness in the weight initialization and in the shuffling of the training data into minibatches.
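We could evaluate as we did in the previous notebook, using `predict_classes` and scikit-learn's `accuracy_score`, but we can also take a shortcut by calling `model.evaluate`. (Note that `predict_classes` has been removed from recent versions of Keras; the predicted classes can instead be obtained by taking the argmax of the output of `model.predict`.)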
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
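The manual route to the accuracy now goes through `predict`. Take the argmax of each row of class probabilities, then compare it with the argmax of the one-hot labels. Here is a sketch on small toy arrays. With the real model, `probs` would be `model.predict(x_test)` and `y_true` would be `y_test`, and the result should essentially match the accuracy reported by `evaluate`.

```python
import numpy as np

# Toy stand-ins (assumption): class probabilities for 4 instances over 3 classes,
# and the matching one-hot true labels.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 1, 0],
                   [1, 0, 0]])

# Predicted class = index of the largest probability in each row.
predicted = np.argmax(probs, axis=1)
true_labels = np.argmax(y_true, axis=1)

# Accuracy = fraction of instances where prediction and label agree.
accuracy = np.mean(predicted == true_labels)
print(predicted, accuracy)  # [0 1 2 0] 0.75
```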
Let's exemplify the classifier's output.
We first take a look at an instance from the test set. Let's say that we look at the instance at index 100 (which happens to be an image of the digit 6):
plt.imshow(x_test[100,:,:,0]);
We compute the predicted class for every instance in the test set by taking, for each instance, the class with the highest output probability. As you can see, the model has made a correct prediction for the instance at index 100.
guesses = np.argmax(model.predict(x_test), axis=1)
guesses[100]
Let's also take a look at the probabilistic output. We compute the class probabilities for all instances. This gives us an $N \times 10$ matrix, where $N$ is the number of instances and 10 is the number of possible digits (0 to 9).
probabilities = model.predict(x_test)
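Each row of this matrix comes from the softmax output layer, so its entries are non-negative and sum to 1. Here is a sketch of a row-wise softmax in NumPy, applied to random scores for three instances. It illustrates the function itself and is not Keras's own implementation.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax: exponentiate and normalize so each row sums to 1."""
    # Subtracting the row maximum avoids overflow in exp without changing the result.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Toy scores for N = 3 instances and 10 classes (assumption: random values).
rng = np.random.default_rng(0)
probs = softmax(rng.normal(size=(3, 10)))
print(probs.shape)  # (3, 10)
assert np.allclose(probs.sum(axis=1), 1.0)
```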
We write a small helper function that sorts the probabilities in decreasing order and prints each one next to its digit.
Again considering the instance at index 100, we see that the model assigns the digit 6 a probability very close to 1.
def print_probs(ps):
    # Pair each probability with its class index, then print from most to least probable.
    for p, i in sorted([(p, i) for i, p in enumerate(ps)], reverse=True):
        print(f'{i}: {p:.4f}')
print_probs(probabilities[100])
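When only the few most likely digits matter, `np.argsort` gives the same ranking without building a list of tuples. Here is a sketch with a hypothetical helper `top_k` and made-up probabilities that sum to 1.

```python
import numpy as np

def top_k(ps, k=3):
    """Return the k most probable (class, probability) pairs, highest first."""
    order = np.argsort(ps)[::-1][:k]  # class indices by decreasing probability
    return [(int(i), float(ps[i])) for i in order]

# Made-up probabilities for the ten digits (not model output).
ps = np.array([0.01, 0.02, 0.00, 0.05, 0.10, 0.00, 0.80, 0.00, 0.02, 0.00])
print(top_k(ps))  # [(6, 0.8), (4, 0.1), (3, 0.05)]
```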
Finally, let's inspect some of the errors.
We gather all misclassified instances in a list. For each one, we store the output probability of the guessed class, the probabilities of all classes, the true digit, and the image.
We sort the list by the probability of the guessed class, so that the first instances in the list are those where the classifier was most confident in its wrong prediction.
As you can see, 187 out of the 10,000 test instances were misclassified. [Again, this number may vary if you retrain the model, because of randomness in the training process.]
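errors = []
for x, y, g, p in zip(x_test, y_test, guesses, probabilities):
    # y is one-hot, so y[g] is 0 exactly when the guess g is wrong.
    if not y[g]:
        errors.append((p[g], p, y.argmax(), x[:, :, 0]))
# Sort on the confidence only; on ties, comparing the tuples' NumPy arrays would raise an error.
errors.sort(key=lambda e: e[0], reverse=True)
len(errors)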
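The same ordering can be computed without an explicit loop. This NumPy sketch uses a hypothetical helper name and toy probabilities. It returns the indices of the misclassified instances, ordered by the probability the model gave to its wrong guess. With the notebook's variables, `probs` would be `probabilities` and `true_labels` would be `y_test.argmax(axis=1)`. Sorting indices instead of tuples also avoids comparing probability arrays when two confidences tie.

```python
import numpy as np

def misclassified_by_confidence(probs, true_labels):
    """Indices of misclassified instances, most confident mistake first."""
    guesses = probs.argmax(axis=1)
    wrong = np.flatnonzero(guesses != true_labels)
    # Probability the model assigned to its (wrong) guess, for each error.
    confidence = probs[wrong, guesses[wrong]]
    return wrong[np.argsort(-confidence)]

probs = np.array([[0.9, 0.1],   # guess 0, true 1: confident error
                  [0.2, 0.8],   # guess 1, true 1: correct
                  [0.6, 0.4]])  # guess 0, true 1: less confident error
true_labels = np.array([1, 1, 1])
print(misclassified_by_confidence(probs, true_labels))  # [0 2]
```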
Here is the instance where the classifier was most confident in its incorrect prediction. It is an image of the digit 6 that was misclassified as a 4, and it is hard to classify even for a human.
[Again, you may see a different result if you retrain.]
def show_error(err):
    # Unpack a stored error: (confidence, class probabilities, true label, image).
    _, p, label, img = err
    print('Correct label:', label)
    print_probs(p)
    plt.imshow(img)
show_error(errors[0])