Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, then compiles the same model using XLA to compare performance.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.config.list_physical_devices('GPU')

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
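`to_categorical` turns each integer label into a one-hot row. A minimal NumPy equivalent (illustrative only, not the library call itself) makes the transformation concrete:

```python
import numpy as np

def to_one_hot(labels, num_classes):
    # Index rows of an identity matrix by label:
    # label 3 -> [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
    return np.eye(num_classes, dtype='float32')[labels]

print(to_one_hot(np.array([0, 3, 9]), 10).shape)  # (3, 10)
```

Each row sums to 1, with the single 1.0 sitting at the index of the original class label.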

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
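As a quick sanity check on the architecture, the parameter count can be worked out by hand, assuming 32×32×3 inputs and the Keras defaults used above ('valid' padding where none is given, stride 1, 2×2 pooling):

```python
def conv_params(cin, cout, k=3):
    # Each filter has cin * k * k weights, plus one bias per output channel.
    return cout * (cin * k * k) + cout

def dense_params(nin, nout):
    return nin * nout + nout

# Spatial sizes: 32 -> (same) 32 -> (valid) 30 -> pool 15 -> (same) 15 -> (valid) 13 -> pool 6
flat = 6 * 6 * 64
total = (conv_params(3, 32) + conv_params(32, 32)
         + conv_params(32, 64) + conv_params(64, 64)
         + dense_params(flat, 512) + dense_params(512, 10))
print(total)  # 1250858
```

Around 1.25M parameters, which `model.summary()` should confirm; the two Dense layers account for nearly all of them.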

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
Train on 50000 samples, validate on 10000 samples
50000/50000 [==============================] - 4s 90us/sample - loss: 2.0808 - accuracy: 0.2302 - val_loss: 1.8674 - val_accuracy: 0.3461
Train on 50000 samples, validate on 10000 samples
Epoch 1/25
50000/50000 [==============================] - 2s 35us/sample - loss: 2.1491 - accuracy: 0.1997 - val_loss: 2.0613 - val_accuracy: 0.2572
Epoch 2/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.8468 - accuracy: 0.3308 - val_loss: 1.7397 - val_accuracy: 0.3791
Epoch 3/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.6979 - accuracy: 0.3894 - val_loss: 1.5846 - val_accuracy: 0.4401
Epoch 4/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.6075 - accuracy: 0.4176 - val_loss: 1.5340 - val_accuracy: 0.4474
Epoch 5/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.5409 - accuracy: 0.4439 - val_loss: 1.4419 - val_accuracy: 0.4886
Epoch 6/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.4855 - accuracy: 0.4625 - val_loss: 1.4144 - val_accuracy: 0.4994
Epoch 7/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.4455 - accuracy: 0.4806 - val_loss: 1.4269 - val_accuracy: 0.4913
Epoch 8/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.4085 - accuracy: 0.4947 - val_loss: 1.3319 - val_accuracy: 0.5294
Epoch 9/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.3755 - accuracy: 0.5095 - val_loss: 1.2937 - val_accuracy: 0.5382
Epoch 10/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.3413 - accuracy: 0.5219 - val_loss: 1.2743 - val_accuracy: 0.5508
Epoch 11/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.3138 - accuracy: 0.5301 - val_loss: 1.2206 - val_accuracy: 0.5699
Epoch 12/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.2889 - accuracy: 0.5430 - val_loss: 1.1959 - val_accuracy: 0.5805
Epoch 13/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.2631 - accuracy: 0.5509 - val_loss: 1.1805 - val_accuracy: 0.5860
Epoch 14/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.2386 - accuracy: 0.5640 - val_loss: 1.2180 - val_accuracy: 0.5807
Epoch 15/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.2164 - accuracy: 0.5705 - val_loss: 1.1592 - val_accuracy: 0.5911
Epoch 16/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.1875 - accuracy: 0.5802 - val_loss: 1.1598 - val_accuracy: 0.5912
Epoch 17/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.1698 - accuracy: 0.5890 - val_loss: 1.1413 - val_accuracy: 0.5988
Epoch 18/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.1513 - accuracy: 0.5935 - val_loss: 1.1149 - val_accuracy: 0.6136
Epoch 19/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.1283 - accuracy: 0.6018 - val_loss: 1.0692 - val_accuracy: 0.6293
Epoch 20/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.1106 - accuracy: 0.6102 - val_loss: 1.0452 - val_accuracy: 0.6369
Epoch 21/25
50000/50000 [==============================] - 2s 34us/sample - loss: 1.0906 - accuracy: 0.6168 - val_loss: 1.0645 - val_accuracy: 0.6312
Epoch 22/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.0736 - accuracy: 0.6242 - val_loss: 1.0260 - val_accuracy: 0.6409
Epoch 23/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.0563 - accuracy: 0.6292 - val_loss: 1.0159 - val_accuracy: 0.6446
Epoch 24/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.0440 - accuracy: 0.6367 - val_loss: 1.0641 - val_accuracy: 0.6294
Epoch 25/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.0264 - accuracy: 0.6390 - val_loss: 0.9824 - val_accuracy: 0.6595
CPU times: user 50.2 s, sys: 10.9 s, total: 1min 1s
Wall time: 43.8 s
10000/10000 [==============================] - 1s 73us/sample - loss: 0.9824 - accuracy: 0.6595
Test loss: 0.9823672799110412
Test accuracy: 0.6595
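The `%time` magic above is IPython-only. Outside a notebook, a plain-Python stand-in (a hypothetical helper, not part of this tutorial) performs the same wall-clock measurement:

```python
import time

def timed(fn, *args, **kwargs):
    # Wall-clock timing of a single call, analogous to %time above.
    start = time.perf_counter()
    fn(*args, **kwargs)
    return time.perf_counter() - start

# e.g. elapsed = timed(train_model, model, x_train, y_train, x_test, y_test)
```

As with the warmup step, only time a model that has already run once, so that one-off compilation cost is excluded from the measurement.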

Now let's train the model again, this time with the XLA compiler enabled. To turn on the compiler partway through the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
Train on 50000 samples, validate on 10000 samples
50000/50000 [==============================] - 6s 118us/sample - loss: 2.0425 - accuracy: 0.2453 - val_loss: 1.8272 - val_accuracy: 0.3654
Train on 50000 samples, validate on 10000 samples
Epoch 1/25
50000/50000 [==============================] - 4s 79us/sample - loss: 2.1062 - accuracy: 0.2158 - val_loss: 1.8913 - val_accuracy: 0.3381
Epoch 2/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.8195 - accuracy: 0.3406 - val_loss: 1.6869 - val_accuracy: 0.3991
Epoch 3/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.6950 - accuracy: 0.3860 - val_loss: 1.5977 - val_accuracy: 0.4331
Epoch 4/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.6188 - accuracy: 0.4150 - val_loss: 1.5398 - val_accuracy: 0.4523
Epoch 5/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.5594 - accuracy: 0.4339 - val_loss: 1.4975 - val_accuracy: 0.4646
Epoch 6/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.5061 - accuracy: 0.4594 - val_loss: 1.4205 - val_accuracy: 0.4857
Epoch 7/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.4575 - accuracy: 0.4757 - val_loss: 1.4506 - val_accuracy: 0.4906
Epoch 8/25
50000/50000 [==============================] - 2s 31us/sample - loss: 1.4251 - accuracy: 0.4882 - val_loss: 1.3481 - val_accuracy: 0.5209
Epoch 9/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.3895 - accuracy: 0.5013 - val_loss: 1.3603 - val_accuracy: 0.5171
Epoch 10/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.3567 - accuracy: 0.5141 - val_loss: 1.2936 - val_accuracy: 0.5453
Epoch 11/25
50000/50000 [==============================] - 2s 33us/sample - loss: 1.3300 - accuracy: 0.5258 - val_loss: 1.2902 - val_accuracy: 0.5393
Epoch 12/25
50000/50000 [==============================] - 2s 33us/sample - loss: 1.3018 - accuracy: 0.5365 - val_loss: 1.2281 - val_accuracy: 0.5647
Epoch 13/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.2738 - accuracy: 0.5472 - val_loss: 1.2972 - val_accuracy: 0.5343
Epoch 14/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.2517 - accuracy: 0.5554 - val_loss: 1.2117 - val_accuracy: 0.5723
Epoch 15/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.2307 - accuracy: 0.5646 - val_loss: 1.2038 - val_accuracy: 0.5715
Epoch 16/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.2073 - accuracy: 0.5726 - val_loss: 1.2325 - val_accuracy: 0.5620
Epoch 17/25
50000/50000 [==============================] - 2s 31us/sample - loss: 1.1868 - accuracy: 0.5815 - val_loss: 1.1011 - val_accuracy: 0.6153
Epoch 18/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.1652 - accuracy: 0.5886 - val_loss: 1.1278 - val_accuracy: 0.6063
Epoch 19/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.1409 - accuracy: 0.5963 - val_loss: 1.1364 - val_accuracy: 0.6007
Epoch 20/25
50000/50000 [==============================] - 2s 32us/sample - loss: 1.1261 - accuracy: 0.6050 - val_loss: 1.0738 - val_accuracy: 0.6224
Epoch 21/25
50000/50000 [==============================] - 2s 33us/sample - loss: 1.1066 - accuracy: 0.6092 - val_loss: 1.0841 - val_accuracy: 0.6199
Epoch 22/25
50000/50000 [==============================] - 2s 33us/sample - loss: 1.0875 - accuracy: 0.6185 - val_loss: 1.0398 - val_accuracy: 0.6349
Epoch 23/25
50000/50000 [==============================] - 2s 35us/sample - loss: 1.0689 - accuracy: 0.6241 - val_loss: 1.0206 - val_accuracy: 0.6413
Epoch 24/25
50000/50000 [==============================] - 2s 33us/sample - loss: 1.0524 - accuracy: 0.6331 - val_loss: 1.0460 - val_accuracy: 0.6339
Epoch 25/25
50000/50000 [==============================] - 2s 31us/sample - loss: 1.0371 - accuracy: 0.6372 - val_loss: 1.0547 - val_accuracy: 0.6283
CPU times: user 49 s, sys: 7.26 s, total: 56.3 s
Wall time: 43.7 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
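The speedup is just the ratio of wall-clock times. As a rough cross-check against the per-sample step times in the logs above (~35 µs/sample without XLA vs. ~32 µs/sample with it, on this Colab GPU; the exact figures vary run to run):

```python
def speedup(baseline_s, optimized_s):
    # > 1.0 means the optimized run is faster.
    return baseline_s / optimized_s

print(round(speedup(35e-6, 32e-6), 2))  # 1.09
```

The per-step gain is smaller here than the ~1.17x quoted for the Titan V machine; the benefit of XLA depends heavily on the hardware and on how much fusion the model admits.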