Classifying CIFAR-10 with XLA

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

This tutorial trains a TensorFlow model to classify images from the CIFAR-10 dataset and compiles the model with the XLA compiler.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
# Use an explicit check rather than `assert`, which is silently stripped
# when Python runs with the -O flag.
if not tf.test.gpu_device_name():
  raise SystemError('GPU device not found; this tutorial requires a GPU.')

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False)  # Start with XLA disabled.

def load_data():
  """Load CIFAR-10 via TFDS and return normalized, one-hot-encoded splits.

  Returns:
    ((x_train, y_train), (x_test, y_test)) where the images are float32
    arrays scaled to [0, 1] and the labels are one-hot class matrices.
  """
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'], result['train']['label']
  (x_test, y_test) = result['test']['image'], result['test']['label']

  # Pixel intensities span 0-255, so divide by 255 (not 256) so that the
  # maximum intensity maps exactly to 1.0.
  x_train = x_train.numpy().astype('float32') / 255
  x_test = x_test.numpy().astype('float32') / 255

  # Convert integer class labels to binary (one-hot) class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

# Fetch the normalized train/test splits once at module level.
(x_train, y_train), (x_test, y_test) = load_data()

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  """Build the CIFAR-10 convnet, adapted from the Keras CIFAR-10 example."""
  return tf.keras.models.Sequential([
    # Declare the input with an explicit Input layer instead of passing
    # `input_shape` to the first Conv2D — Keras deprecates the latter and
    # emits a UserWarning for it (as seen in the original run's output).
    tf.keras.layers.Input(shape=x_train.shape[1:]),
    tf.keras.layers.Conv2D(32, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

# Instantiate the (not yet compiled) model.
model = generate_model()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)

We train the model using the RMSprop optimizer:

def compile_model(model):
  """Compile *model* in place with RMSprop and categorical cross-entropy.

  Returns the same model object for convenient chaining.
  """
  model.compile(
      loss='categorical_crossentropy',
      optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
      metrics=['accuracy'],
  )
  return model

# Attach optimizer, loss, and metrics to the freshly built model.
model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  """Fit *model* on the training split, validating on the test split."""
  model.fit(
      x_train,
      y_train,
      batch_size=256,
      epochs=epochs,
      validation_data=(x_test, y_test),
      shuffle=True,
  )

def warmup(model, x_train, y_train, x_test, y_test):
  """Run one throwaway epoch so JIT compilation time is excluded from timing."""
  snapshot = model.get_weights()  # Remember the pre-warmup parameters.
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(snapshot)  # Restore them so the timed run starts fresh.

# One throwaway epoch so graph/JIT compilation is not part of the timing below.
warmup(model, x_train, y_train, x_test, y_test)
# IPython magic: time the full 25-epoch training run (XLA disabled).
%time train_model(model, x_train, y_train, x_test, y_test)

# Report final loss/accuracy on the held-out test split.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1713831126.538478    8618 service.cc:145] XLA service 0x7fcd48001dc0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1713831126.538556    8618 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1713831126.538564    8618 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1713831126.538571    8618 service.cc:153]   StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1713831126.538576    8618 service.cc:153]   StreamExecutor device (3): Tesla T4, Compute Capability 7.5
7/196 ━━━━━━━━━━━━━━━━━━━━ 4s 21ms/step - accuracy: 0.0975 - loss: 2.3135
I0000 00:00:1713831133.385201    8618 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
196/196 ━━━━━━━━━━━━━━━━━━━━ 17s 46ms/step - accuracy: 0.1699 - loss: 2.1953 - val_accuracy: 0.3244 - val_loss: 1.9084
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.1543 - loss: 2.2268 - val_accuracy: 0.3133 - val_loss: 1.9342
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3071 - loss: 1.9102 - val_accuracy: 0.3746 - val_loss: 1.7545
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3650 - loss: 1.7487 - val_accuracy: 0.4006 - val_loss: 1.6579
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3972 - loss: 1.6603 - val_accuracy: 0.4409 - val_loss: 1.5517
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4190 - loss: 1.5980 - val_accuracy: 0.4323 - val_loss: 1.5570
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4409 - loss: 1.5415 - val_accuracy: 0.4733 - val_loss: 1.4858
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4570 - loss: 1.5003 - val_accuracy: 0.5021 - val_loss: 1.3964
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4738 - loss: 1.4547 - val_accuracy: 0.4938 - val_loss: 1.3903
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4898 - loss: 1.4147 - val_accuracy: 0.5029 - val_loss: 1.3742
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5021 - loss: 1.3834 - val_accuracy: 0.5224 - val_loss: 1.3277
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5129 - loss: 1.3485 - val_accuracy: 0.5415 - val_loss: 1.2864
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5252 - loss: 1.3278 - val_accuracy: 0.5436 - val_loss: 1.2834
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5293 - loss: 1.3120 - val_accuracy: 0.5313 - val_loss: 1.3585
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5444 - loss: 1.2797 - val_accuracy: 0.5781 - val_loss: 1.1945
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5542 - loss: 1.2521 - val_accuracy: 0.5770 - val_loss: 1.1954
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5607 - loss: 1.2361 - val_accuracy: 0.5677 - val_loss: 1.2293
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5717 - loss: 1.2116 - val_accuracy: 0.5725 - val_loss: 1.1966
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5778 - loss: 1.1943 - val_accuracy: 0.5952 - val_loss: 1.1552
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5836 - loss: 1.1833 - val_accuracy: 0.6149 - val_loss: 1.1007
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5935 - loss: 1.1486 - val_accuracy: 0.6104 - val_loss: 1.1072
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5982 - loss: 1.1439 - val_accuracy: 0.6284 - val_loss: 1.0721
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.6029 - loss: 1.1235 - val_accuracy: 0.6291 - val_loss: 1.0678
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.6074 - loss: 1.1109 - val_accuracy: 0.6219 - val_loss: 1.0965
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6188 - loss: 1.0930 - val_accuracy: 0.6464 - val_loss: 1.0155
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6220 - loss: 1.0710 - val_accuracy: 0.6421 - val_loss: 1.0260
CPU times: user 1min 20s, sys: 7.94 s, total: 1min 28s
Wall time: 1min 16s
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.6499 - loss: 1.0166
Test loss: 1.0260233879089355
Test accuracy: 0.6420999765396118

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
# Rebuild and recompile the model from scratch, then reload the data,
# so this run is directly comparable to the non-XLA run above.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

# Warm up again so XLA compilation time is excluded from the measurement.
warmup(model, x_train, y_train, x_test, y_test)
# IPython magic: time the full 25-epoch training run (XLA enabled).
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 ━━━━━━━━━━━━━━━━━━━━ 12s 37ms/step - accuracy: 0.1617 - loss: 2.2184 - val_accuracy: 0.3259 - val_loss: 1.9444
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.1497 - loss: 2.2450 - val_accuracy: 0.3096 - val_loss: 1.9600
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.2909 - loss: 1.9547 - val_accuracy: 0.3832 - val_loss: 1.7646
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3550 - loss: 1.7939 - val_accuracy: 0.4180 - val_loss: 1.6492
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.3872 - loss: 1.6986 - val_accuracy: 0.4243 - val_loss: 1.6334
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4151 - loss: 1.6277 - val_accuracy: 0.4520 - val_loss: 1.5540
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4322 - loss: 1.5696 - val_accuracy: 0.4836 - val_loss: 1.4479
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4567 - loss: 1.5175 - val_accuracy: 0.4957 - val_loss: 1.4186
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4676 - loss: 1.4793 - val_accuracy: 0.5045 - val_loss: 1.3905
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4846 - loss: 1.4421 - val_accuracy: 0.5204 - val_loss: 1.3442
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4974 - loss: 1.4100 - val_accuracy: 0.5098 - val_loss: 1.3783
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5041 - loss: 1.3878 - val_accuracy: 0.5427 - val_loss: 1.2956
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5144 - loss: 1.3562 - val_accuracy: 0.5374 - val_loss: 1.3066
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5257 - loss: 1.3247 - val_accuracy: 0.5591 - val_loss: 1.2552
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5332 - loss: 1.3055 - val_accuracy: 0.5428 - val_loss: 1.3233
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5440 - loss: 1.2828 - val_accuracy: 0.5673 - val_loss: 1.2201
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5533 - loss: 1.2651 - val_accuracy: 0.5898 - val_loss: 1.1761
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5659 - loss: 1.2319 - val_accuracy: 0.5831 - val_loss: 1.1876
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5702 - loss: 1.2190 - val_accuracy: 0.6020 - val_loss: 1.1407
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5758 - loss: 1.1987 - val_accuracy: 0.5948 - val_loss: 1.1382
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5881 - loss: 1.1728 - val_accuracy: 0.6161 - val_loss: 1.1033
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5979 - loss: 1.1506 - val_accuracy: 0.6116 - val_loss: 1.1065
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5980 - loss: 1.1464 - val_accuracy: 0.6242 - val_loss: 1.0741
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.6048 - loss: 1.1223 - val_accuracy: 0.6290 - val_loss: 1.0524
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.6122 - loss: 1.1000 - val_accuracy: 0.6257 - val_loss: 1.0723
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.6191 - loss: 1.0867 - val_accuracy: 0.6360 - val_loss: 1.0474
CPU times: user 1min 21s, sys: 7.29 s, total: 1min 29s
Wall time: 1min 18s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU the speed up is ~1.17x.