The XLA compile API


Import TensorFlow and the XLA library. The xla module contains xla.compile(), an experimental API that compiles part or all of a model with XLA. The tf.contrib namespace that provides it exists only in TensorFlow 1.x.

import tensorflow as tf

from tensorflow.contrib.compiler import xla

Define some necessary constants and prepare the MNIST dataset.

# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
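To make the pipeline arithmetic concrete, here is a pure-Python sketch (no TensorFlow) of what .batch(TRAIN_BATCH_SIZE).repeat() yields and how much training data the 1000-step loop later consumes. The helper batch_then_repeat is hypothetical and only mimics tf.data's ordering (batch each pass, then restart); the 60,000-image training split is the standard MNIST size.

```python
import itertools

NUM_TRAIN_EXAMPLES = 60000  # size of the standard MNIST training split
TRAIN_BATCH_SIZE = 100
TRAIN_STEPS = 1000

def batch_then_repeat(make_examples, batch_size):
    """Hypothetical helper: yields batches forever, restarting after each pass.

    Mirrors the order of .batch(...).repeat(): each pass is batched separately,
    so a pass whose size is not divisible by batch_size ends in a smaller batch.
    """
    while True:
        examples = iter(make_examples())
        produced = False
        # iter(callable, sentinel) keeps slicing until an empty batch appears.
        for batch in iter(lambda: list(itertools.islice(examples, batch_size)), []):
            produced = True
            yield batch
        if not produced:
            return  # empty dataset: stop instead of looping forever

batches_per_epoch = NUM_TRAIN_EXAMPLES // TRAIN_BATCH_SIZE  # divides evenly here
epochs_covered = TRAIN_STEPS * TRAIN_BATCH_SIZE / NUM_TRAIN_EXAMPLES
print(batches_per_epoch, round(epochs_covered, 3))  # 600 1.667

# With a 250-example stand-in dataset, the third batch of each pass is partial.
stream = batch_then_repeat(lambda: range(250), TRAIN_BATCH_SIZE)
print([len(next(stream)) for _ in range(4)])  # [100, 100, 50, 100]
```

So the training loop in this tutorial sees each training image roughly 1.67 times.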

Define the model-construction function

The following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.

When called, it returns two values: y, a tf.Tensor holding the predicted logits (unnormalized scores) for each target class, and train_step, a tf.Operation that applies the gradient-descent update to the model variables. Because no global_step argument is passed to minimize(), this train_step does not increment a global step counter.

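def build_mnist_model(x, y_):
  # Forward pass: a single dense layer producing one logit per class.
  y = tf.keras.layers.Dense(NUM_CLASSES)(x)

  # Backward pass: softmax cross-entropy loss and a gradient-descent update
  # with learning rate 0.5.
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  return y, train_step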
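For intuition about what y and train_step compute, the following pure-Python sketch reproduces the same math without TensorFlow: a dense layer with no activation, the mean sparse softmax cross-entropy, and one gradient-descent step with learning rate 0.5. It is an illustration, not how TensorFlow or XLA implement it; the function names, the zero initialization (Keras Dense actually defaults to Glorot-uniform weights), and the tiny synthetic batch are all assumptions made for this sketch.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dense(weights, bias, x):
    """Dense layer with no activation: logits[k] = sum_i x[i] * W[i][k] + b[k]."""
    return [sum(x[i] * weights[i][k] for i in range(len(x))) + bias[k]
            for k in range(len(bias))]

def mean_cross_entropy(weights, bias, xs, labels):
    """Mean sparse softmax cross-entropy; labels are integer class ids."""
    total = 0.0
    for x, label in zip(xs, labels):
        total -= math.log(softmax(dense(weights, bias, x))[label])
    return total / len(xs)

def gradient_descent_step(weights, bias, xs, labels, learning_rate=0.5):
    """Applies one in-place gradient-descent update to weights and bias."""
    n_in, n_out = len(weights), len(bias)
    grad_w = [[0.0] * n_out for _ in range(n_in)]
    grad_b = [0.0] * n_out
    for x, label in zip(xs, labels):
        probs = softmax(dense(weights, bias, x))
        for k in range(n_out):
            # d(loss)/d(logit_k) = p_k - [k == label], averaged over the batch.
            delta = (probs[k] - (1.0 if k == label else 0.0)) / len(xs)
            grad_b[k] += delta
            for i in range(n_in):
                grad_w[i][k] += delta * x[i]
    for i in range(n_in):
        for k in range(n_out):
            weights[i][k] -= learning_rate * grad_w[i][k]
    for k in range(n_out):
        bias[k] -= learning_rate * grad_b[k]

# Tiny synthetic batch: 8 examples with 4 features in [0, 1) and 3 classes.
rng = random.Random(0)
xs = [[rng.random() for _ in range(4)] for _ in range(8)]
labels = [rng.randrange(3) for _ in range(8)]
weights = [[0.0] * 3 for _ in range(4)]  # zero init for simplicity
bias = [0.0] * 3

loss_before = mean_cross_entropy(weights, bias, xs, labels)  # log(3) at zero init
gradient_descent_step(weights, bias, xs, labels)
loss_after = mean_cross_entropy(weights, bias, xs, labels)
print(loss_before, loss_after)
```

With all-zero parameters every class gets probability 1/3, so the starting loss is log(3); one step lowers it on this batch.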

Enable XLA

Use xla.compile with the build_mnist_model function to enable XLA. The following code block wraps the model with xla.compile(), so that the target function, applied to the provided inputs, is executed by XLA.

[y] = xla.compile(build_mnist_model, inputs=[images, labels])

When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.

xla.compile does not return any tf.Operation nodes that can be executed independently from the generated XLA ops. Instead, returned tf.Operation nodes from the target function are added as control dependencies of all returned tf.Tensor values. This triggers execution of the tf.Operation nodes when the returned tensors are evaluated.

In pseudo-code, xla.compile's implementation looks as follows:

# Ask TensorFlow to execute code in an XLA-friendly manner

y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
  y = tf.identity(y)

# Ask TensorFlow to STOP executing code in an XLA-friendly manner

xla.compile() always returns a list of tf.Tensor objects, even when there is only one element; this is why the call above unpacks its result as [y].
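The output handling described above can be mimicked with a few plain-Python stand-ins. In this sketch Op, Tensor, compile_like, and build_model are all hypothetical classes and functions, not TensorFlow APIs, and nothing is compiled; it only shows how returned operations become control dependencies of every returned tensor, and why the result is always a list.

```python
class Op:
    """Hypothetical stand-in for a tf.Operation: runs a side effect, yields no value."""
    def __init__(self, run):
        self.run = run

class Tensor:
    """Hypothetical stand-in for a tf.Tensor that may carry control dependencies."""
    def __init__(self, compute, control_deps=()):
        self.compute = compute
        self.control_deps = list(control_deps)

    def evaluate(self):
        # Control dependencies run before the tensor's value is produced.
        for op in self.control_deps:
            op.run()
        return self.compute()

def compile_like(target_fn, inputs):
    """Hypothetical: folds returned ops into returned tensors, as described above."""
    outputs = target_fn(*inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    ops = [o for o in outputs if isinstance(o, Op)]
    tensors = [o for o in outputs if isinstance(o, Tensor)]
    # Like tf.identity under tf.control_dependencies(ops): every returned
    # tensor depends on every returned op. The result is always a list.
    return [Tensor(t.evaluate, control_deps=ops) for t in tensors]

steps_run = []

def build_model(x, y_):
    train_step = Op(lambda: steps_run.append(1))  # stands in for the variable update
    y = Tensor(lambda: 2 * x + y_)                # stands in for the logits
    return y, train_step

[y] = compile_like(build_model, inputs=[3, 1])
for _ in range(5):
    value = y.evaluate()
print(value, len(steps_run))  # 7 5
```

Each evaluation of y runs the stand-in train_step once, which is the behavior the training loop below relies on.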

If you print the constructed graph now, you will see that it is not much different from a normal TensorFlow graph, and you won't find the XLA ops mentioned before. This is because the actual compilation happens later, when you execute the graph with sess.run(). At that point, TensorFlow triggers a series of graph rewrite passes that generate the XLA ops, which compile and execute the computation once all inputs are ready.
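The split between building the graph and compiling it at run time can be illustrated with a small deferred-compilation wrapper. DeferredCompiled is hypothetical: it only records the function at construction and "compiles" (here, just counts) on the first run. Keying the cache by input shape is an assumption chosen for this sketch, not a description of TensorFlow's internal caching.

```python
class DeferredCompiled:
    """Hypothetical: captures a function at build time, compiles it at first run."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}         # compiled versions keyed by input shape (assumption)
        self.compile_count = 0

    def _compile(self, shape):
        # A real compiler would emit optimized code for this shape;
        # this sketch only counts how often compilation happens.
        self.compile_count += 1
        return self.fn

    def run(self, batch):
        shape = (len(batch),)
        if shape not in self.cache:
            self.cache[shape] = self._compile(shape)
        return self.cache[shape](batch)

graph_fn = DeferredCompiled(lambda batch: sum(x * x for x in batch))
print(graph_fn.compile_count)                           # 0: nothing compiled yet
print(graph_fn.run([1, 2, 3]), graph_fn.compile_count)  # 14 1
print(graph_fn.run([4, 5, 6]), graph_fn.compile_count)  # 77 1: same shape, reused
print(graph_fn.run([1, 2]), graph_fn.compile_count)     # 5 2: new shape, recompiled
```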

Train and test the model

# Creates a session and initializes all variables.
# xla.compile() doesn't work with the Keras model.fit() API or TF eager mode yet.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

The following code block trains the model. Evaluating y also triggers its control dependency, the train_step node, which updates the model variables.

# Feeds training dataset
sess.run(iterator.make_initializer(train_ds))

# Runs TRAIN_STEPS steps; each evaluation of y also runs train_step
for _ in range(TRAIN_STEPS):
  sess.run(y)

print("Model trained for %s steps." % TRAIN_STEPS)
Model trained for 1000 steps.
# Tests trained model

# Feeds testing dataset
sess.run(iterator.make_initializer(test_ds))

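# Calculates accuracy on one test batch (TRAIN_BATCH_SIZE examples).
# Note: y carries a control dependency on train_step, so this sess.run()
# also applies one gradient update using that test batch.
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))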
Prediction accuracy after training: 0.93
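The single sess.run(accuracy) call consumes one batch of the test dataset, so the figure above covers 100 of the 10,000 MNIST test images. A common pattern for scoring the whole set is to keep running until the iterator is exhausted (in TensorFlow 1.x, get_next() then raises tf.errors.OutOfRangeError) and to weight each batch by its size. The pure-Python sketch below shows only that accumulation; accuracy_over_batches is a hypothetical helper and the batches are made-up labels.

```python
NUM_TEST_EXAMPLES = 10000  # size of the standard MNIST test split
TRAIN_BATCH_SIZE = 100

def accuracy_over_batches(batches):
    """Hypothetical helper: accuracy over all batches, weighted by batch size.

    Each batch is a (predicted_labels, true_labels) pair.
    """
    correct = total = 0
    for predicted, actual in batches:
        correct += sum(p == a for p, a in zip(predicted, actual))
        total += len(actual)
    return correct / total

# One sess.run(accuracy) sees a single batch of the test set.
fraction_evaluated = TRAIN_BATCH_SIZE / NUM_TEST_EXAMPLES
print(fraction_evaluated)  # 0.01

batches = [([1, 2, 3, 4], [1, 2, 3, 0]), ([5, 5], [5, 5])]
# 5 of 6 correct, about 0.833; averaging per-batch accuracies would give 0.875.
print(accuracy_over_batches(batches))
```

Weighting by batch size matters whenever the last batch is smaller than the rest.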
# Cleans up session
sess.close()