
Classifying CIFAR-10 with XLA


In this Colab we train a TensorFlow model to classify the CIFAR-10 dataset, and we compile the model using the XLA compiler.

We start by loading and normalizing the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.is_gpu_available()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
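`to_categorical` above one-hot encodes the integer class labels. As a minimal sketch of what that conversion does (a NumPy equivalent using identity-matrix indexing, not Keras's actual implementation), consider:

```python
import numpy as np

def to_one_hot(labels, num_classes):
    # Each integer label indexes a row of the identity matrix,
    # yielding one one-hot vector per sample.
    labels = np.asarray(labels).reshape(-1)
    return np.eye(num_classes, dtype='float32')[labels]

# CIFAR-10 labels 0..9 become 10-dimensional one-hot rows.
y = to_one_hot([3, 0, 9], num_classes=10)
```

Each row sums to 1, with the single 1 sitting at the label's index.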

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
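To see why `Flatten` feeds 2304 features into `Dense(512)`, we can trace the spatial sizes through the stack by hand: a `'same'` convolution preserves the size, a `'valid'` (default) 3x3 convolution shrinks it by 2, and each 2x2 max-pool halves it. A small sketch of that arithmetic (plain Python, just for illustration):

```python
def conv_out(size, kernel=3, padding='valid'):
    # 'same' keeps the spatial size; 'valid' shrinks it by kernel - 1.
    return size if padding == 'same' else size - (kernel - 1)

size = 32                                # CIFAR-10 images are 32x32
size = conv_out(size, padding='same')    # 32
size = conv_out(size)                    # 30
size //= 2                               # 2x2 max-pool -> 15
size = conv_out(size, padding='same')    # 15
size = conv_out(size)                    # 13
size //= 2                               # 2x2 max-pool -> 6
flat = size * size * 64                  # 6*6*64 = 2304 features into Dense(512)
```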

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)
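RMSprop divides each update by the root of a running average of squared gradients, so parameters with consistently large gradients take smaller steps. A minimal scalar sketch of the update rule (illustrative only, not Keras's exact implementation; the `rho` and `eps` values here are assumptions matching common defaults):

```python
import numpy as np

def rmsprop_step(w, grad, avg_sq, lr=1e-4, rho=0.9, eps=1e-7):
    # Exponential moving average of squared gradients sets the per-parameter scale.
    avg_sq = rho * avg_sq + (1.0 - rho) * grad ** 2
    w = w - lr * grad / (np.sqrt(avg_sq) + eps)
    return w, avg_sq

# Minimize f(w) = w^2 (gradient 2w) from w = 1.0 for a few steps.
w, avg_sq = 1.0, 0.0
for _ in range(3):
    w, avg_sq = rmsprop_step(w, grad=2.0 * w, avg_sq=avg_sq)
```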

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT: we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
WARNING: Logging before flag parsing goes to stderr.
W0822 15:45:43.189764 140558350198528 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Train on 50000 samples, validate on 10000 samples
50000/50000 [==============================] - 8s 165us/sample - loss: 2.0690 - accuracy: 0.2356 - val_loss: 1.8167 - val_accuracy: 0.3660
Train on 50000 samples, validate on 10000 samples
Epoch 1/25
50000/50000 [==============================] - 4s 71us/sample - loss: 2.1359 - accuracy: 0.2055 - val_loss: 1.8895 - val_accuracy: 0.3431
Epoch 2/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.8133 - accuracy: 0.3435 - val_loss: 1.7066 - val_accuracy: 0.3866
Epoch 3/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.6796 - accuracy: 0.3911 - val_loss: 1.6057 - val_accuracy: 0.4215
Epoch 4/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.6016 - accuracy: 0.4204 - val_loss: 1.5170 - val_accuracy: 0.4541
Epoch 5/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.5370 - accuracy: 0.4454 - val_loss: 1.4445 - val_accuracy: 0.4893
Epoch 6/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.4825 - accuracy: 0.4666 - val_loss: 1.4415 - val_accuracy: 0.4851
Epoch 7/25
50000/50000 [==============================] - 3s 67us/sample - loss: 1.4349 - accuracy: 0.4846 - val_loss: 1.3473 - val_accuracy: 0.5192
Epoch 8/25
50000/50000 [==============================] - 4s 70us/sample - loss: 1.4002 - accuracy: 0.4989 - val_loss: 1.3864 - val_accuracy: 0.5124
Epoch 9/25
50000/50000 [==============================] - 3s 70us/sample - loss: 1.3657 - accuracy: 0.5144 - val_loss: 1.2864 - val_accuracy: 0.5451
Epoch 10/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.3353 - accuracy: 0.5251 - val_loss: 1.3246 - val_accuracy: 0.5274
Epoch 11/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.3070 - accuracy: 0.5356 - val_loss: 1.2589 - val_accuracy: 0.5610
Epoch 12/25
50000/50000 [==============================] - 3s 67us/sample - loss: 1.2798 - accuracy: 0.5469 - val_loss: 1.2657 - val_accuracy: 0.5472
Epoch 13/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.2597 - accuracy: 0.5534 - val_loss: 1.1791 - val_accuracy: 0.5886
Epoch 14/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.2314 - accuracy: 0.5629 - val_loss: 1.1600 - val_accuracy: 0.5905
Epoch 15/25
50000/50000 [==============================] - 3s 70us/sample - loss: 1.2203 - accuracy: 0.5695 - val_loss: 1.1607 - val_accuracy: 0.5956
Epoch 16/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.1954 - accuracy: 0.5778 - val_loss: 1.1804 - val_accuracy: 0.5862
Epoch 17/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.1744 - accuracy: 0.5851 - val_loss: 1.1036 - val_accuracy: 0.6126
Epoch 18/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.1574 - accuracy: 0.5931 - val_loss: 1.0941 - val_accuracy: 0.6136
Epoch 19/25
50000/50000 [==============================] - 3s 66us/sample - loss: 1.1367 - accuracy: 0.6003 - val_loss: 1.2117 - val_accuracy: 0.5734
Epoch 20/25
50000/50000 [==============================] - 4s 70us/sample - loss: 1.1224 - accuracy: 0.6063 - val_loss: 1.0510 - val_accuracy: 0.6355
Epoch 21/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.1040 - accuracy: 0.6104 - val_loss: 1.0527 - val_accuracy: 0.6294
Epoch 22/25
50000/50000 [==============================] - 3s 68us/sample - loss: 1.0890 - accuracy: 0.6162 - val_loss: 1.1197 - val_accuracy: 0.6027
Epoch 23/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.0667 - accuracy: 0.6255 - val_loss: 1.0609 - val_accuracy: 0.6243
Epoch 24/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.0542 - accuracy: 0.6299 - val_loss: 0.9987 - val_accuracy: 0.6526
Epoch 25/25
50000/50000 [==============================] - 3s 69us/sample - loss: 1.0423 - accuracy: 0.6345 - val_loss: 0.9914 - val_accuracy: 0.6546
CPU times: user 1min 45s, sys: 14.6 s, total: 1min 59s
Wall time: 1min 25s
10000/10000 [==============================] - 1s 92us/sample - loss: 0.9914 - accuracy: 0.6546
Test loss: 0.9913934736251832
Test accuracy: 0.6546
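The `%time` used above is IPython magic and only works in a notebook. Outside one, the same wall-clock measurement can be sketched with `time.perf_counter` (a minimal, hypothetical helper):

```python
import time

def timed(fn, *args, **kwargs):
    # Measure the wall-clock time of a single call, like IPython's %time.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f'Wall time: {elapsed:.3f}s')
    return result, elapsed

result, elapsed = timed(sum, range(1_000_000))
```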

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

tf.keras.backend.clear_session() # We need to clear the session to enable JIT in the middle of the program.
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
Train on 50000 samples, validate on 10000 samples
50000/50000 [==============================] - 14s 271us/sample - loss: 2.0335 - accuracy: 0.2544 - val_loss: 1.8362 - val_accuracy: 0.3457
Train on 50000 samples, validate on 10000 samples
Epoch 1/25
50000/50000 [==============================] - 4s 76us/sample - loss: 2.1292 - accuracy: 0.2124 - val_loss: 1.8313 - val_accuracy: 0.3682
Epoch 2/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.7945 - accuracy: 0.3507 - val_loss: 1.6596 - val_accuracy: 0.4054
Epoch 3/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.6761 - accuracy: 0.3947 - val_loss: 1.6023 - val_accuracy: 0.4284
Epoch 4/25
50000/50000 [==============================] - 4s 74us/sample - loss: 1.6056 - accuracy: 0.4182 - val_loss: 1.5538 - val_accuracy: 0.4438
Epoch 5/25
50000/50000 [==============================] - 4s 74us/sample - loss: 1.5468 - accuracy: 0.4388 - val_loss: 1.4656 - val_accuracy: 0.4741
Epoch 6/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.4969 - accuracy: 0.4576 - val_loss: 1.4076 - val_accuracy: 0.4900
Epoch 7/25
50000/50000 [==============================] - 4s 80us/sample - loss: 1.4523 - accuracy: 0.4785 - val_loss: 1.4082 - val_accuracy: 0.5020
Epoch 8/25
50000/50000 [==============================] - 4s 80us/sample - loss: 1.4178 - accuracy: 0.4924 - val_loss: 1.3819 - val_accuracy: 0.5088
Epoch 9/25
50000/50000 [==============================] - 4s 77us/sample - loss: 1.3805 - accuracy: 0.5024 - val_loss: 1.3053 - val_accuracy: 0.5329
Epoch 10/25
50000/50000 [==============================] - 4s 74us/sample - loss: 1.3494 - accuracy: 0.5187 - val_loss: 1.2919 - val_accuracy: 0.5400
Epoch 11/25
50000/50000 [==============================] - 4s 73us/sample - loss: 1.3198 - accuracy: 0.5267 - val_loss: 1.2374 - val_accuracy: 0.5602
Epoch 12/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.2940 - accuracy: 0.5354 - val_loss: 1.2407 - val_accuracy: 0.5611
Epoch 13/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.2688 - accuracy: 0.5481 - val_loss: 1.2562 - val_accuracy: 0.5593
Epoch 14/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.2492 - accuracy: 0.5566 - val_loss: 1.1979 - val_accuracy: 0.5758
Epoch 15/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.2265 - accuracy: 0.5653 - val_loss: 1.1577 - val_accuracy: 0.5947
Epoch 16/25
50000/50000 [==============================] - 4s 74us/sample - loss: 1.2015 - accuracy: 0.5734 - val_loss: 1.1507 - val_accuracy: 0.5968
Epoch 17/25
50000/50000 [==============================] - 4s 75us/sample - loss: 1.1774 - accuracy: 0.5835 - val_loss: 1.1283 - val_accuracy: 0.6043
Epoch 18/25
50000/50000 [==============================] - 4s 77us/sample - loss: 1.1652 - accuracy: 0.5865 - val_loss: 1.1042 - val_accuracy: 0.6143
Epoch 19/25
50000/50000 [==============================] - 4s 76us/sample - loss: 1.1419 - accuracy: 0.5969 - val_loss: 1.0941 - val_accuracy: 0.6155
Epoch 20/25
50000/50000 [==============================] - 4s 76us/sample - loss: 1.1202 - accuracy: 0.6045 - val_loss: 1.0853 - val_accuracy: 0.6134
Epoch 21/25
50000/50000 [==============================] - 4s 76us/sample - loss: 1.1043 - accuracy: 0.6107 - val_loss: 1.0532 - val_accuracy: 0.6284
Epoch 22/25
50000/50000 [==============================] - 4s 76us/sample - loss: 1.0899 - accuracy: 0.6148 - val_loss: 1.0368 - val_accuracy: 0.6358
Epoch 23/25
50000/50000 [==============================] - 4s 76us/sample - loss: 1.0734 - accuracy: 0.6221 - val_loss: 1.0859 - val_accuracy: 0.6177
Epoch 24/25
50000/50000 [==============================] - 4s 74us/sample - loss: 1.0607 - accuracy: 0.6258 - val_loss: 1.0262 - val_accuracy: 0.6418
Epoch 25/25
50000/50000 [==============================] - 4s 73us/sample - loss: 1.0441 - accuracy: 0.6328 - val_loss: 0.9776 - val_accuracy: 0.6606
CPU times: user 2min 20s, sys: 21.6 s, total: 2min 42s
Wall time: 1min 34s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speed-up is ~1.17x.