
Classifying CIFAR-10 with XLA



This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, then compiles the same model with XLA and compares the training times.

Load and normalize the dataset using the TensorFlow Datasets API:

pip install tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'], result['train']['label']
  (x_test, y_test) = result['test']['image'], result['test']['label']

  x_train = x_train.numpy().astype('float32') / 256
  x_test = x_test.numpy().astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
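Note that dividing by 256 (rather than 255) maps the 8-bit pixel values into the half-open interval [0, 1). A minimal plain-Python illustration with a few hypothetical pixel values:

```python
# Pixel intensities in CIFAR-10 images are 8-bit integers in [0, 255].
pixels = [0, 1, 128, 255]  # hypothetical sample values

# The same scaling used in load_data() above: divide by 256.
scaled = [p / 256 for p in pixels]

# Every scaled value lies in [0, 1); 255 maps to 0.99609375, not 1.0.
assert all(0.0 <= s < 1.0 for s in scaled)
print(scaled)  # [0.0, 0.00390625, 0.5, 0.99609375]
```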

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
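As a quick sanity check on the architecture, the spatial dimensions can be traced by hand: a `'same'`-padded stride-1 convolution preserves the size, a `'valid'` 3x3 convolution shrinks it by 2, and 2x2 max-pooling halves it (rounding down). A small sketch of that arithmetic (plain Python, not Keras):

```python
def conv2d_out(size, kernel=3, padding='valid'):
    """Spatial output size of a stride-1 Conv2D layer."""
    return size if padding == 'same' else size - kernel + 1

def maxpool_out(size, pool=2):
    """Spatial output size of a 2x2 MaxPooling2D layer."""
    return size // pool

s = 32                               # CIFAR-10 images are 32x32
s = conv2d_out(s, padding='same')    # 32
s = conv2d_out(s)                    # 30
s = maxpool_out(s)                   # 15
s = conv2d_out(s, padding='same')    # 15
s = conv2d_out(s)                    # 13
s = maxpool_out(s)                   # 6

flat = s * s * 64   # 64 channels after the second block
print(flat)         # 2304 inputs to the first Dense layer
```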

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)
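For reference, `categorical_crossentropy` on one-hot labels reduces to the negative log-probability the model assigns to the true class. A hand-rolled sketch with hypothetical values (illustrative only; Keras's actual implementation adds numerical safeguards such as clipping):

```python
import math

def categorical_crossentropy(y_true, y_pred):
    # -sum(t * log(p)) over classes; with one-hot y_true this picks out
    # the log-probability of the true class.
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred))

y_true = [0.0, 1.0, 0.0]   # one-hot label for class 1 (hypothetical)
y_pred = [0.1, 0.8, 0.1]   # softmax output (hypothetical)
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.2231
```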

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
2022-10-19 11:12:05.194702: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
196/196 [==============================] - 10s 17ms/step - loss: 2.0134 - accuracy: 0.2611 - val_loss: 1.7906 - val_accuracy: 0.3709
Epoch 1/25
196/196 [==============================] - 3s 16ms/step - loss: 2.0860 - accuracy: 0.2269 - val_loss: 1.9230 - val_accuracy: 0.2984
Epoch 2/25
196/196 [==============================] - 3s 14ms/step - loss: 1.7848 - accuracy: 0.3574 - val_loss: 1.6553 - val_accuracy: 0.4095
Epoch 3/25
196/196 [==============================] - 3s 14ms/step - loss: 1.6743 - accuracy: 0.3940 - val_loss: 1.5836 - val_accuracy: 0.4405
Epoch 4/25
196/196 [==============================] - 3s 14ms/step - loss: 1.5913 - accuracy: 0.4246 - val_loss: 1.4758 - val_accuracy: 0.4786
Epoch 5/25
196/196 [==============================] - 3s 14ms/step - loss: 1.5331 - accuracy: 0.4455 - val_loss: 1.4230 - val_accuracy: 0.4891
Epoch 6/25
196/196 [==============================] - 3s 14ms/step - loss: 1.4793 - accuracy: 0.4651 - val_loss: 1.4022 - val_accuracy: 0.4907
Epoch 7/25
196/196 [==============================] - 3s 15ms/step - loss: 1.4358 - accuracy: 0.4824 - val_loss: 1.3449 - val_accuracy: 0.5177
Epoch 8/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3958 - accuracy: 0.4983 - val_loss: 1.3449 - val_accuracy: 0.5217
Epoch 9/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3602 - accuracy: 0.5133 - val_loss: 1.2772 - val_accuracy: 0.5473
Epoch 10/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3312 - accuracy: 0.5244 - val_loss: 1.2398 - val_accuracy: 0.5595
Epoch 11/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3007 - accuracy: 0.5370 - val_loss: 1.2546 - val_accuracy: 0.5527
Epoch 12/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2766 - accuracy: 0.5448 - val_loss: 1.2044 - val_accuracy: 0.5770
Epoch 13/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2483 - accuracy: 0.5579 - val_loss: 1.2035 - val_accuracy: 0.5810
Epoch 14/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2295 - accuracy: 0.5630 - val_loss: 1.1886 - val_accuracy: 0.5853
Epoch 15/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2034 - accuracy: 0.5745 - val_loss: 1.2080 - val_accuracy: 0.5705
Epoch 16/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1851 - accuracy: 0.5796 - val_loss: 1.1166 - val_accuracy: 0.6123
Epoch 17/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1636 - accuracy: 0.5898 - val_loss: 1.1214 - val_accuracy: 0.6068
Epoch 18/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1479 - accuracy: 0.5946 - val_loss: 1.0858 - val_accuracy: 0.6180
Epoch 19/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1287 - accuracy: 0.6037 - val_loss: 1.0527 - val_accuracy: 0.6308
Epoch 20/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1108 - accuracy: 0.6097 - val_loss: 1.0911 - val_accuracy: 0.6198
Epoch 21/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0929 - accuracy: 0.6153 - val_loss: 1.0733 - val_accuracy: 0.6225
Epoch 22/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0772 - accuracy: 0.6220 - val_loss: 1.0110 - val_accuracy: 0.6454
Epoch 23/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0615 - accuracy: 0.6288 - val_loss: 1.0078 - val_accuracy: 0.6434
Epoch 24/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0462 - accuracy: 0.6343 - val_loss: 1.0154 - val_accuracy: 0.6413
Epoch 25/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0262 - accuracy: 0.6390 - val_loss: 0.9890 - val_accuracy: 0.6535
CPU times: user 1min 15s, sys: 8.56 s, total: 1min 24s
Wall time: 1min 11s
313/313 [==============================] - 1s 3ms/step - loss: 0.9890 - accuracy: 0.6535
Test loss: 0.9889576435089111
Test accuracy: 0.6535000205039978

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
2022-10-19 11:13:33.679960: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
196/196 [==============================] - 7s 18ms/step - loss: 2.0619 - accuracy: 0.2406 - val_loss: 1.8361 - val_accuracy: 0.3625
Epoch 1/25
196/196 [==============================] - 5s 24ms/step - loss: 2.1244 - accuracy: 0.2116 - val_loss: 1.8907 - val_accuracy: 0.3260
Epoch 2/25
196/196 [==============================] - 2s 12ms/step - loss: 1.8235 - accuracy: 0.3409 - val_loss: 1.6939 - val_accuracy: 0.4016
Epoch 3/25
196/196 [==============================] - 2s 12ms/step - loss: 1.6964 - accuracy: 0.3892 - val_loss: 1.6055 - val_accuracy: 0.4274
Epoch 4/25
196/196 [==============================] - 2s 12ms/step - loss: 1.6105 - accuracy: 0.4183 - val_loss: 1.5264 - val_accuracy: 0.4523
Epoch 5/25
196/196 [==============================] - 2s 12ms/step - loss: 1.5428 - accuracy: 0.4423 - val_loss: 1.4466 - val_accuracy: 0.4833
Epoch 6/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4896 - accuracy: 0.4635 - val_loss: 1.3995 - val_accuracy: 0.4985
Epoch 7/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4418 - accuracy: 0.4807 - val_loss: 1.3730 - val_accuracy: 0.5095
Epoch 8/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4047 - accuracy: 0.4954 - val_loss: 1.4248 - val_accuracy: 0.4890
Epoch 9/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3756 - accuracy: 0.5084 - val_loss: 1.3405 - val_accuracy: 0.5216
Epoch 10/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3413 - accuracy: 0.5184 - val_loss: 1.3007 - val_accuracy: 0.5301
Epoch 11/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3123 - accuracy: 0.5335 - val_loss: 1.2714 - val_accuracy: 0.5450
Epoch 12/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2840 - accuracy: 0.5421 - val_loss: 1.2179 - val_accuracy: 0.5659
Epoch 13/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2604 - accuracy: 0.5515 - val_loss: 1.2110 - val_accuracy: 0.5764
Epoch 14/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2352 - accuracy: 0.5627 - val_loss: 1.1681 - val_accuracy: 0.5877
Epoch 15/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2139 - accuracy: 0.5708 - val_loss: 1.1691 - val_accuracy: 0.5849
Epoch 16/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1900 - accuracy: 0.5800 - val_loss: 1.1291 - val_accuracy: 0.6039
Epoch 17/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1642 - accuracy: 0.5906 - val_loss: 1.1361 - val_accuracy: 0.5965
Epoch 18/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1476 - accuracy: 0.5970 - val_loss: 1.0760 - val_accuracy: 0.6213
Epoch 19/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1275 - accuracy: 0.6037 - val_loss: 1.0685 - val_accuracy: 0.6294
Epoch 20/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1054 - accuracy: 0.6101 - val_loss: 1.0396 - val_accuracy: 0.6365
Epoch 21/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0893 - accuracy: 0.6178 - val_loss: 1.0624 - val_accuracy: 0.6264
Epoch 22/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0788 - accuracy: 0.6221 - val_loss: 1.0447 - val_accuracy: 0.6351
Epoch 23/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0591 - accuracy: 0.6295 - val_loss: 1.0713 - val_accuracy: 0.6274
Epoch 24/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0441 - accuracy: 0.6361 - val_loss: 1.0098 - val_accuracy: 0.6490
Epoch 25/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0248 - accuracy: 0.6409 - val_loss: 0.9878 - val_accuracy: 0.6554
CPU times: user 43 s, sys: 8.53 s, total: 51.5 s
Wall time: 1min 1s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speed-up is ~1.17x.
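The ~1.17x figure is from that Titan V machine; the wall times logged in this particular run give a similar ratio:

```python
# Wall times from the two %time cells above.
baseline_s = 60 + 11   # 1min 11s without XLA
xla_s = 60 + 1         # 1min 1s with XLA

print(f"{baseline_s / xla_s:.2f}x speed-up")  # 1.16x speed-up
```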