Import TensorFlow and the XLA library. The xla module contains xla.compile(), an experimental API that compiles part or all of a model with XLA.
import tensorflow as tf
from tensorflow.contrib.compiler import xla
Define some necessary constants and prepare the MNIST dataset.
# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                           train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
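To make the batching, reshaping, and casting above concrete, here is a minimal NumPy sketch of the same data flow. It is an illustration only: the random arrays are a synthetic stand-in for MNIST, and the batches() helper is a hypothetical name, not part of tf.data.

```python
import numpy as np

IMAGE_SIZE = 28 * 28
TRAIN_BATCH_SIZE = 100

# Synthetic stand-in for MNIST (assumption): 250 random 28x28 uint8 images.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(250, 28, 28), dtype=np.uint8)
fake_labels = rng.integers(0, 10, size=250, dtype=np.uint8)


def batches(images, labels, batch_size):
    """Yields consecutive (images, labels) batches, keeping a short last batch."""
    for start in range(0, len(images), batch_size):
        yield images[start:start + batch_size], labels[start:start + batch_size]


batch_shapes = []
for batch_images, batch_labels in batches(fake_images, fake_labels, TRAIN_BATCH_SIZE):
    # Mirrors tf.reshape(images, [-1, IMAGE_SIZE]) and the two tf.cast calls.
    flat = batch_images.reshape(-1, IMAGE_SIZE).astype(np.float32)
    ints = batch_labels.astype(np.int64)
    batch_shapes.append((flat.shape, ints.shape))

print(batch_shapes)
```

Each 28 x 28 image becomes one row of 784 floats, so a full batch has shape (100, 784); the final batch is shorter because 250 is not a multiple of 100.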
Define the model constructing function
The following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.
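When called, it returns two values:
y is a tf.Tensor holding the model's output for each target class (the logits, which the loss function converts to probabilities).
train_step is a tf.Operation that applies the variable update (it would also increment global_step if one were passed to minimize()).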
def build_mnist_model(x, y_):
  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)

  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  return y, train_step
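As a rough illustration of what the function above computes, the following NumPy sketch runs one dense layer, a sparse softmax cross-entropy loss, and a plain gradient-descent update with the same learning rate of 0.5. It is a sketch under assumptions, not TensorFlow's implementation: the random inputs, the zero initialization, the input scaling, and the helper names are all chosen for illustration.

```python
import numpy as np

IMAGE_SIZE = 28 * 28
NUM_CLASSES = 10
LEARNING_RATE = 0.5


def softmax_cross_entropy(logits, labels):
    """Returns the mean sparse softmax cross-entropy and its gradient w.r.t. logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    n = logits.shape[0]
    loss = -np.log(probs[np.arange(n), labels]).mean()
    grad = probs.copy()
    grad[np.arange(n), labels] -= 1.0
    return loss, grad / n


def train_step(x, labels, weights, bias):
    """Runs one forward pass, then updates weights and bias in place."""
    logits = x @ weights + bias
    loss, grad_logits = softmax_cross_entropy(logits, labels)
    weights -= LEARNING_RATE * (x.T @ grad_logits)
    bias -= LEARNING_RATE * grad_logits.sum(axis=0)
    return logits, loss


rng = np.random.default_rng(0)
# Random inputs scaled down so this illustrative step stays stable (assumption).
x = rng.random((100, IMAGE_SIZE)) / IMAGE_SIZE
labels = rng.integers(0, NUM_CLASSES, size=100)
weights = np.zeros((IMAGE_SIZE, NUM_CLASSES))
bias = np.zeros(NUM_CLASSES)

_, first_loss = train_step(x, labels, weights, bias)
_, second_loss = train_step(x, labels, weights, bias)
print(first_loss, second_loss)
```

With all-zero parameters every class gets probability 0.1, so the first loss equals ln(10); the second call sees the updated parameters and reports a lower loss.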
Use xla.compile with the build_mnist_model function to enable XLA. The following code block wraps the model with xla.compile(), which allows the target function to be executed by XLA with the provided inputs.
[y] = xla.compile(build_mnist_model, inputs=[images, labels])
When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.
xla.compile does not return any
tf.Operation nodes that can be executed independently from the generated XLA ops. Instead, returned
tf.Operation nodes from the target function are added as control dependencies of all returned
tf.Tensor values. This triggers execution of the
tf.Operation nodes when the returned tensors are evaluated.
In pseudo-code, xla.compile's implementation looks as follows:
# Ask TensorFlow to execute code in XLA-friendly manner

y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
  y = tf.identity(y)

# Ask TensorFlow to STOP executing code in XLA-friendly manner
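The control-dependency mechanism in the pseudo-code can be mimicked with a toy graph in plain Python. The Node class below is hypothetical and much simpler than TensorFlow's graph machinery; it only shows why evaluating the returned value also runs the attached operation.

```python
class Node:
    """A toy graph node: runs its control dependencies, then its own function."""

    def __init__(self, fn, control_deps=()):
        self.fn = fn
        self.control_deps = list(control_deps)

    def run(self):
        for dep in self.control_deps:
            dep.run()  # executed for its side effect; the result is discarded
        return self.fn()


state = {"weight": 0.0, "updates": 0}


def apply_update():
    """Stands in for a variable update performed by a training operation."""
    state["weight"] += 0.5
    state["updates"] += 1


train_step = Node(apply_update)                      # plays the tf.Operation
y = Node(lambda: state["weight"] * 2.0)              # plays the tf.Tensor
y_with_dep = Node(y.run, control_deps=[train_step])  # plays tf.identity(y)

results = [y_with_dep.run() for _ in range(3)]
print(results, state["updates"])  # [1.0, 2.0, 3.0] 3
```

Nothing in the loop mentions train_step, yet it runs once per evaluation, which is the same pattern the training loop in this tutorial relies on.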
xla.compile() always returns a list of tf.Tensor objects, even if the list has only one element.
If you print the constructed graph now, you will see that it is not much different from a normal TensorFlow graph, and you will not find the XLA ops mentioned above. This is because the actual compilation happens later, when you execute the graph with sess.run(). At that time, TensorFlow triggers a series of graph rewrite passes that generate the XLA ops, which compile and execute the computation once all inputs are ready.
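The idea of deferring compilation until the first execution can be illustrated with a small, self-contained Python analogy. The LazyCompiled class is hypothetical and does no real compilation; reusing the compiled result on later runs is an assumption of this toy, not a statement about XLA's internals.

```python
class LazyCompiled:
    """Toy stand-in for a graph region that is compiled on first execution."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled = None
        self.compile_count = 0

    def _compile(self):
        # A real compiler would lower and optimize the computation here.
        self.compile_count += 1
        return self.fn

    def run(self, *args):
        if self.compiled is None:  # compile only when first executed
            self.compiled = self._compile()
        return self.compiled(*args)


region = LazyCompiled(lambda a, b: a * b + 1)
count_before_run = region.compile_count  # still 0: building is not compiling

first = region.run(2, 3)
second = region.run(4, 5)
print(count_before_run, first, second, region.compile_count)  # 0 7 21 1
```

Constructing the object corresponds to building the graph, and the first run() call corresponds to the first sess.run(), which is when the work of compiling actually happens.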
Train and test the model
# Creates a session and initializes all variables.
# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
The following code block trains the model. Evaluating y also triggers its control dependency node train_step, which updates the model variables.
# Feeds training dataset
sess.run(iterator.make_initializer(train_ds))

# Runs TRAIN_STEPS steps
for i in range(TRAIN_STEPS):
  sess.run(y)

print("Model trained for %s steps." % TRAIN_STEPS)
Model trained for 1000 steps.
# Tests trained model

# Feeds testing dataset
sess.run(iterator.make_initializer(test_ds))

# Calculates accuracy on one batch of the test dataset.
# Because y has train_step as a control dependency, this run also applies
# one variable update.
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
Prediction accuracy after training: 0.91
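The accuracy computation used in the test step can be checked by hand on a tiny example. The NumPy sketch below uses made-up logits and labels (not MNIST data) to show the argmax, compare, and average steps.

```python
import numpy as np

# Hypothetical logits for four examples and three classes (not MNIST data).
logits = np.array([
    [2.0, 0.5, 0.1],
    [0.2, 1.5, 0.3],
    [0.1, 0.2, 3.0],
    [1.0, 2.0, 0.5],
])
labels = np.array([0, 1, 2, 0])

predicted = np.argmax(logits, axis=1)         # like tf.argmax(y, 1)
correct = predicted == labels                 # like tf.equal(...)
accuracy = correct.astype(np.float32).mean()  # like tf.reduce_mean(tf.cast(...))
print(accuracy)  # 0.75
```

Three of the four predictions match their labels, so the mean of the 0/1 values is 0.75.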
# Cleans up session
sess.close()