tf.function and AutoGraph in TensorFlow 2.0

TF 2.0 brings together the ease of eager execution and the power of TF 1.0. At the center of this merger is tf.function, which allows you to transform a subset of Python syntax into portable, high-performance TensorFlow graphs.

A key feature of tf.function is AutoGraph, which lets you write graph code using natural Python syntax. For a list of the Python features that you can use with AutoGraph, see AutoGraph Capabilities and Limitations. For more details about tf.function, see the RFC TF 2.0: Functions, not Sessions. For more details about AutoGraph, see tf.autograph.

This tutorial will walk you through the basic features of tf.function and AutoGraph.

Setup

Install and import TensorFlow 2.0 Beta:

from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
!pip install -q tensorflow==2.0.0-beta1
import tensorflow as tf

The tf.function decorator

When you annotate a function with tf.function, you can still call it like any other function. However, it is compiled into a graph, which brings faster execution, the ability to run on GPU or TPU, and export to SavedModel.

@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

simple_nn_layer(x, y)
<tf.Tensor: id=23, shape=(3, 3), dtype=float32, numpy=
array([[0.6566509 , 1.1459428 , 0.5263159 ],
       [0.38579437, 0.82408667, 0.22760163],
       [0.30971363, 0.52188736, 0.2356692 ]], dtype=float32)>
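To make the "compiled into a graph" step more concrete, here is a minimal sketch showing that the Python body of a decorated function runs only while TensorFlow traces it. The names double and traces are hypothetical and used only for illustration; the sketch assumes the default behavior of retracing when the input dtype changes.

```python
import tensorflow as tf

traces = []  # hypothetical log, appended to only while the function is traced


@tf.function
def double(x):
  # Python side effect: runs during tracing, not on every call.
  traces.append(x.dtype.name)
  return x * 2


double(tf.constant(1))
double(tf.constant(2))    # same dtype and shape: the existing graph is reused
double(tf.constant(1.5))  # new dtype: a second trace is created
print(traces)
```

Three calls produce only two log entries, because the second call reuses the graph traced by the first.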

If we examine the result of the annotation, we can see that it's a special callable that handles all interactions with the TensorFlow runtime.

simple_nn_layer
<tensorflow.python.eager.def_function.Function at 0x7f4d9a04d390>
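One way to look behind this callable is to ask it for a concrete function and list the ops in its graph. This is a sketch rather than a recommended workflow; the exact list of ops may differ between TensorFlow versions, so no output is shown.

```python
import tensorflow as tf


@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

# Trace the function for these input shapes and dtypes, without running it.
concrete = simple_nn_layer.get_concrete_function(x, y)

# Each graph operation has a type such as "MatMul" or "Relu".
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)
```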

If your code uses multiple functions, you don't need to annotate them all: any function called from an annotated function will also run in graph mode.

def linear_layer(x):
  return 2 * x + 1


@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))


deep_net(tf.constant((1, 2, 3)))
<tf.Tensor: id=36, shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>
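As a rough check of this behavior, the sketch below records tf.executing_eagerly() from inside the undecorated helper. The observed_modes list is a hypothetical name added for illustration.

```python
import tensorflow as tf

observed_modes = []  # hypothetical log of the execution mode seen by the helper


def linear_layer(x):
  # Not decorated: it runs in whichever mode its caller uses.
  observed_modes.append(tf.executing_eagerly())
  return 2 * x + 1


@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))


linear_layer(tf.constant([1, 2, 3]))  # called directly: eager
deep_net(tf.constant([1, 2, 3]))      # called from a tf.function: graph
print(observed_modes)
```

The first entry comes from the direct call and the second from the call traced inside deep_net.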

Functions can be faster than eager code for graphs with many small ops. For graphs with a few expensive ops (like convolutions), you may not see much speedup.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.2711559560000296
Function conv: 0.2252887610000016
Note how there's not much difference in performance for convolutions

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(inputs, state):
  return lstm_cell(inputs, state)

# `inputs` avoids shadowing the Python built-in `input`.
inputs = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(inputs, state); lstm_fn(inputs, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(inputs, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(inputs, state), number=10))
eager lstm: 0.006397080999931859
function lstm: 0.004643363000013778

Use Python control flow

When using data-dependent control flow inside tf.function, you can use Python control flow statements and AutoGraph will convert them into appropriate TensorFlow ops. For example, if statements will be converted into tf.cond() if they depend on a Tensor.

In the example below, x is a Tensor but the if statement works as expected:

@tf.function
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x


print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))
print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))
square_if_positive(2) = 4
square_if_positive(-2) = 0
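For comparison, the same logic can be written by hand with tf.cond. This is only a sketch of the kind of op AutoGraph targets, not the exact code it generates; the name square_if_positive_cond is hypothetical.

```python
import tensorflow as tf


@tf.function
def square_if_positive_cond(x):
  # Both branches must return tensors of the same dtype and shape.
  return tf.cond(x > 0, lambda: x * x, lambda: tf.zeros_like(x))


print(square_if_positive_cond(tf.constant(2)))
print(square_if_positive_cond(tf.constant(-2)))
```

The Python if version is usually easier to read, which is the main reason to let AutoGraph do this conversion.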

AutoGraph supports common Python statements like while, for, if, break, continue and return, with support for nesting. That means you can use Tensor expressions in the condition of while and if statements, or iterate over a Tensor in a for loop.

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s


sum_even(tf.constant([10, 12, 15, 20]))
<tf.Tensor: id=606, shape=(), dtype=int32, numpy=42>
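A while loop with a Tensor condition works the same way. The sketch below counts Collatz steps; collatz_steps is a hypothetical example function, and it assumes a scalar int32 input.

```python
import tensorflow as tf


@tf.function
def collatz_steps(n):
  steps = tf.constant(0)
  # The condition depends on a Tensor, so AutoGraph builds a graph loop.
  while n > 1:
    if tf.equal(n % 2, 0):
      n = n // 2
    else:
      n = 3 * n + 1
    steps += 1
  return steps


print(collatz_steps(tf.constant(6)))  # 6 reaches 1 after 8 steps
```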

AutoGraph also provides a low-level API for advanced users. For example, you can use it to inspect the generated code.

print(tf.autograph.to_code(sum_even.python_function))
def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  s = 0

  def loop_body(loop_vars, s_2):
    c = loop_vars
    continue_ = False
    cond = c % 2 > 0

    def get_state():
      return ()

    def set_state(_):
      pass

    def if_true():
      continue_ = True
      return continue_

    def if_false():
      return continue_
    continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
    cond_1 = ag__.not_(continue_)

    def get_state_1():
      return ()

    def set_state_1(_):
      pass

    def if_true_1():
      s_1, = s_2,
      s_1 += c
      return s_1

    def if_false_1():
      return s_2
    s_2 = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1)
    return s_2,
  s, = ag__.for_stmt(items, None, loop_body, (s,))
  do_return = True
  retval_ = s
  cond_2 = ag__.is_undefined_return(retval_)

  def get_state_2():
    return ()

  def set_state_2(_):
    pass

  def if_true_2():
    retval_ = None
    return retval_

  def if_false_2():
    return retval_
  retval_ = ag__.if_stmt(cond_2, if_true_2, if_false_2, get_state_2, set_state_2)
  return retval_

Here's an example of more complicated control flow:

@tf.function
def fizzbuzz(n):
  for i in tf.range(n):
    if tf.equal(i % 3, 0):
      tf.print('Fizz')
    elif tf.equal(i % 5, 0):
      tf.print('Buzz')
    else:
      tf.print(i)

fizzbuzz(tf.constant(15))
Fizz
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14

Keras and AutoGraph

AutoGraph is available by default in non-dynamic Keras models. For more information, see tf.keras.

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      return input_data // 2


model = CustomModel()

model(tf.constant([-2, -4]))
<tf.Tensor: id=723, shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>

Side effects

Just like in eager mode, you can use operations with side effects, such as tf.Variable.assign or tf.print, inside a tf.function. TensorFlow inserts the necessary control dependencies to ensure they execute in order.

v = tf.Variable(5)

@tf.function
def find_next_odd():
  v.assign(v + 1)
  if tf.equal(v % 2, 0):
    v.assign(v + 1)


find_next_odd()
v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>
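The ordering guarantee is easiest to see when two assignments do not commute. In this sketch (the counter variable and double_then_add_three are hypothetical names), swapping the two updates would give a different result.

```python
import tensorflow as tf

counter = tf.Variable(1)


@tf.function
def double_then_add_three():
  # The two updates run in program order, even inside a graph.
  counter.assign(counter * 2)
  counter.assign(counter + 3)
  return counter.read_value()


first = double_then_add_three()   # (1 * 2) + 3 = 5
second = double_then_add_three()  # (5 * 2) + 3 = 13
print(first, second)
```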

Debugging

tf.function and AutoGraph work by generating code and tracing it into TensorFlow graphs. This mechanism does not yet support step-by-step debuggers like pdb. However, you can call tf.config.experimental_run_functions_eagerly(True) to temporarily enable eager execution inside the tf.function and use your favorite debugger:

@tf.function
def f(x):
  if x > 0:
    # Try setting a breakpoint here!
    # Example:
    #   import pdb
    #   pdb.set_trace()
    x = x + 1
  return x

tf.config.experimental_run_functions_eagerly(True)

# You can now set breakpoints and run the code in a debugger.
f(tf.constant(1))

tf.config.experimental_run_functions_eagerly(False)
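When toggling this setting in a script, it may help to restore it in a finally block so a failure inside the function does not leave eager mode switched on. The sketch below assumes that newer releases expose the switch as tf.config.run_functions_eagerly and falls back to the experimental name used above; the helper alias set_run_eagerly, the function g and the observed list are hypothetical.

```python
import tensorflow as tf

# Assumption: prefer the non-experimental name if this release provides it.
set_run_eagerly = (getattr(tf.config, 'run_functions_eagerly', None)
                   or tf.config.experimental_run_functions_eagerly)

observed = []  # hypothetical log of the mode seen inside the function


@tf.function
def g(x):
  observed.append(tf.executing_eagerly())
  if x > 0:
    x = x + 1
  return x


set_run_eagerly(True)
try:
  result = g(tf.constant(1))  # runs as plain eager Python; breakpoints work
finally:
  set_run_eagerly(False)      # always restore graph execution

g(tf.constant(1))             # traced into a graph again
print(observed)
```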

Download data

def prepare_mnist_features_and_labels(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def mnist_dataset():
  (x, y), _ = tf.keras.datasets.mnist.load_data()
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  ds = ds.map(prepare_mnist_features_and_labels)
  ds = ds.take(20000).shuffle(20000).batch(100)
  return ds

train_dataset = mnist_dataset()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Define the model

model = tf.keras.Sequential((
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10)))
model.build()
optimizer = tf.keras.optimizers.Adam()

Define the training loop

compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss


@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
Step 10 : loss 1.81350708 ; accuracy 0.36
Step 20 : loss 1.20090556 ; accuracy 0.526
Step 30 : loss 0.871289372 ; accuracy 0.605
Step 40 : loss 0.485336393 ; accuracy 0.66125
Step 50 : loss 0.455551505 ; accuracy 0.7038
Step 60 : loss 0.410298198 ; accuracy 0.730666637
Step 70 : loss 0.241451189 ; accuracy 0.753714263
Step 80 : loss 0.308346 ; accuracy 0.771
Step 90 : loss 0.274941027 ; accuracy 0.786
Step 100 : loss 0.539047658 ; accuracy 0.7969
Step 110 : loss 0.387483239 ; accuracy 0.806272745
Step 120 : loss 0.497935444 ; accuracy 0.814083338
Step 130 : loss 0.318222284 ; accuracy 0.819846153
Step 140 : loss 0.222806469 ; accuracy 0.826071441
Step 150 : loss 0.320912778 ; accuracy 0.8316
Step 160 : loss 0.226743415 ; accuracy 0.837125
Step 170 : loss 0.269675553 ; accuracy 0.841882348
Step 180 : loss 0.149641976 ; accuracy 0.846055567
Step 190 : loss 0.378610432 ; accuracy 0.850631595
Step 200 : loss 0.172260076 ; accuracy 0.8549
Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.17226008, shape=(), dtype=float32) ; accuracy tf.Tensor(0.8549, shape=(), dtype=float32)
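The core of train_one_step, a GradientTape followed by a variable update, can also be tried on its own without downloading MNIST. The sketch below is a simplified stand-in, not the tutorial's model: it fits a single weight w to synthetic data with a hand-written gradient descent update instead of a Keras optimizer. All names in it are hypothetical.

```python
import tensorflow as tf

w = tf.Variable(0.0)  # single trainable weight


@tf.function
def sgd_step(x, y, learning_rate):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(w * x - y))
  grad = tape.gradient(loss, w)
  # Plain gradient descent update, applied inside the graph.
  w.assign_sub(learning_rate * grad)
  return loss


x = tf.constant([1.0, 2.0, 3.0, 4.0])
y = 3.0 * x  # synthetic targets: the best weight is 3
learning_rate = tf.constant(0.05)

for _ in range(20):
  loss = sgd_step(x, y, learning_rate)

print('w =', w.numpy(), 'loss =', loss.numpy())
```

Passing the learning rate as a Tensor rather than a Python float avoids retracing if the value changes between calls.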

Batching

In real applications, batching is essential for performance. The best code to convert to AutoGraph is code where the control flow is decided at the batch level. If you need to make decisions at the level of individual examples, try to use batch APIs to maintain performance.

For example, if you have the following code in Python:

def square_if_positive(x):
  return [i ** 2 if i > 0 else i for i in x]


square_if_positive(range(-5, 5))
[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]

You may be tempted to write it in TensorFlow as such (and this would work!):

@tf.function
def square_if_positive_naive(x):
  result = tf.TensorArray(tf.int32, size=x.shape[0])
  for i in tf.range(x.shape[0]):
    if x[i] > 0:
      result = result.write(i, x[i] ** 2)
    else:
      result = result.write(i, x[i])
  return result.stack()


square_if_positive_naive(tf.range(-5, 5))
<tf.Tensor: id=1840, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>

But in this case, it turns out you can write the following:

def square_if_positive_vectorized(x):
  return tf.where(x > 0, x ** 2, x)


square_if_positive_vectorized(tf.range(-5, 5))
<tf.Tensor: id=1850, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>
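The same idea applies to the sum_even loop shown earlier: the per-element if can be replaced by a mask. This is a sketch of one possible vectorized form; the name sum_even_vectorized is hypothetical.

```python
import tensorflow as tf


def sum_even_vectorized(items):
  # Keep even entries, replace odd entries with zero, then sum once.
  is_even = tf.equal(items % 2, 0)
  return tf.reduce_sum(tf.where(is_even, items, tf.zeros_like(items)))


print(sum_even_vectorized(tf.constant([10, 12, 15, 20])))  # 10 + 12 + 20 = 42
```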