AutoGraph: Easy control flow for graphs

View on TensorFlow.org | Run in Google Colab | View source on GitHub

AutoGraph helps you write complicated graph code using normal Python. Behind the scenes, AutoGraph automatically transforms your code into the equivalent TensorFlow graph code. AutoGraph already supports much of the Python language, and that coverage continues to grow. For a list of supported Python language features, see AutoGraph capabilities and limitations.


Import TensorFlow, AutoGraph, and any supporting modules:

from __future__ import division, print_function, absolute_import

import tensorflow as tf
layers = tf.keras.layers
from tensorflow import contrib
autograph = contrib.autograph

import numpy as np
import matplotlib.pyplot as plt

We'll enable eager execution (`tf.enable_eager_execution()`) for demonstration purposes, but AutoGraph works in both eager and graph execution environments:


Automatically convert Python control flow

AutoGraph will convert much of the Python language into the equivalent TensorFlow graph building code.

AutoGraph converts a function like:

def square_if_positive(x):
  """Return x squared when x is positive, otherwise 0.0.

  Written as plain Python so AutoGraph can turn the `if` into a graph
  conditional; the same code also runs directly on Python numbers and on
  eager tensors.
  """
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

To a function that uses graph building:

from __future__ import print_function
import tensorflow as tf

def tf__square_if_positive(x):
    """AutoGraph's generated graph version of `square_if_positive`.

    This is the converter's output reproduced for illustration; it is not
    meant to be run on its own. `ag__` is not defined in this listing --
    presumably AutoGraph binds it to its runtime-support module when it
    loads the generated code.
    """
    with ag__.function_scope('square_if_positive'):

      # The body of the Python `if`, lifted into a function so it can be
      # passed to a conditional. It copies the captured `x` into a
      # branch-local name rather than rebinding the outer variable, and
      # returns its result as a 1-tuple.
      def if_true():
        with ag__.function_scope('if_true'):
          x_1, = x,
          x_1 = x_1 * x_1
          return x_1,

      # The `else` body, lifted the same way.
      def if_false():
        with ag__.function_scope('if_false'):
          x_2, = x,
          x_2 = 0.0
          return x_2,
      # Python `x > 0` has been rewritten as `tf.greater(x, 0)`; `run_cond`
      # selects one of the two branch functions (presumably via `tf.cond`
      # when the predicate is a tensor).
      x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
      return x

# Metadata attribute attached to the generated function; empty here.
tf__square_if_positive.autograph_info__ = {}

Code written for eager execution can run in a tf.Graph with the same results, but with the benefits of graph execution:

# Enable eager execution before any TensorFlow op is created (it must be
# done at program startup). The later `with tf.Graph().as_default():`
# blocks still build and run ordinary graphs.
tf.enable_eager_execution()

print('Eager results: %2.2f, %2.2f' % (square_if_positive(tf.constant(9.0)),
                                       square_if_positive(tf.constant(-9.0))))
Eager results: 81.00, 0.00

Generate a graph-version and call it:

tf_square_if_positive = autograph.to_graph(square_if_positive)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_out1 = tf_square_if_positive(tf.constant( 9.0))
  g_out2 = tf_square_if_positive(tf.constant(-9.0))
  with tf.Session() as sess:
    # Evaluate both graph outputs in a single session run.
    out1, out2 = sess.run((g_out1, g_out2))
    print('Graph results: %2.2f, %2.2f\n' % (out1, out2))
Graph results: 81.00, 0.00

AutoGraph supports common Python statements like while, for, if, break, and return, with support for nesting. Compare this function with the complicated graph version displayed in the following code blocks:

# Continue in a loop
def sum_even(items):
  """Return the sum of the even elements of `items`.

  Odd elements are skipped with `continue`, which shows that AutoGraph can
  convert `continue` inside a `for` loop.
  """
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s

print('Eager result: %d' % sum_even(tf.constant([10,12,15,20])))

tf_sum_even = autograph.to_graph(sum_even)

with tf.Graph().as_default(), tf.Session() as sess:
  print('Graph result: %d\n\n' %
        sess.run(tf_sum_even(tf.constant([10,12,15,20]))))
Eager result: 42
Graph result: 42

from __future__ import print_function
import tensorflow as tf

def tf__sum_even(items):
    """AutoGraph's generated graph version of `sum_even`.

    Converter output reproduced for illustration; it is not meant to be run
    on its own. `ag__` is not defined in this listing -- presumably AutoGraph
    binds it to its runtime-support module when it loads the generated code.
    """
    with ag__.function_scope('sum_even'):
      s = 0

      # Extra loop-continuation test handed to `for_stmt`; it always returns
      # True here. (Presumably it would carry a `break` condition if the
      # source loop had one -- `sum_even` has no `break`.)
      def extra_test(s_2):
        with ag__.function_scope('extra_test'):
          return True

      # One loop iteration: `loop_vars` is the current element of `items`
      # and `s_2` is the loop-carried running sum.
      def loop_body(loop_vars, s_2):
        with ag__.function_scope('loop_body'):
          c = loop_vars
          # `continue` is lowered to a boolean flag instead of a jump.
          continue_ = tf.constant(False)

          # Set the flag when `c % 2 > 0`, i.e. the element is odd.
          def if_true():
            with ag__.function_scope('if_true'):
              continue__1, = continue_,
              continue__1 = tf.constant(True)
              return continue__1,

          # Otherwise leave the flag unchanged.
          def if_false():
            with ag__.function_scope('if_false'):
              return continue_,
          continue_ = ag__.utils.run_cond(tf.greater(c % 2, 0), if_true, if_false)

          # The statement after `continue` (`s += c`) only takes effect when
          # the flag is not set.
          def if_true_1():
            with ag__.function_scope('if_true_1'):
              s_1, = s_2,
              s_1 += c
              return s_1,

          def if_false_1():
            with ag__.function_scope('if_false_1'):
              return s_2,
          s_2 = ag__.utils.run_cond(tf.logical_not(continue_), if_true_1, if_false_1)
          return s_2,
      # Drive the loop over `items`, threading `s` through as loop state.
      s = ag__.for_stmt(items, extra_test, loop_body, (s,))
      return s

# Metadata attribute attached to the generated function; empty here.
tf__sum_even.autograph_info__ = {}


If you don't need easy access to the original Python function, use the `@autograph.convert()` decorator:

@autograph.convert()
def fizzbuzz(i, n):
  """Print FizzBuzz for each integer in [i, n) and return the final counter.

  The `autograph.convert()` decorator makes calls with tensors build the
  equivalent graph: a while loop containing conditionals and in-graph
  printing.
  """
  while i < n:
    msg = ''
    if i % 3 == 0:
      msg += 'Fizz'
    if i % 5 == 0:
      msg += 'Buzz'
    if msg == '':
      msg = tf.as_string(i)
    print(msg)
    i += 1
  return i

with tf.Graph().as_default():
  final_i = fizzbuzz(tf.constant(10), tf.constant(16))
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  with tf.Session() as sess:
    # Running `final_i` executes the whole loop, including its prints.
    sess.run(final_i)



Let's demonstrate some useful Python language features.


AutoGraph automatically converts the Python assert statement into the equivalent tf.Assert code:

@autograph.convert()
def inverse(x):
  """Return 1.0 / x, checking in-graph that x is non-zero.

  Once converted, the `assert` becomes a `tf.Assert` op, so a zero input
  fails at run time with `tf.errors.InvalidArgumentError`.
  """
  assert x != 0.0, 'Do not pass zero!'
  return 1.0 / x

with tf.Graph().as_default(), tf.Session() as sess:
  try:
    # Deliberately pass zero to trigger the in-graph assertion.
    print(sess.run(inverse(tf.constant(0.0))))
  except tf.errors.InvalidArgumentError as e:
    print('Got error message:\n    %s' % e.message)
Got error message:
    assertion failed: [Do not pass zero!]
     [[node inverse/Assert/Assert (defined at /tmp/  = Assert[T=[DT_STRING], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](inverse/NotEqual, inverse/Assert/Assert/data_0)]]


Use the Python print function in-graph:

@autograph.convert()
def count(n):
  """Print 0, 1, ..., n - 1 using in-graph `print` calls, then return n."""
  i = 0
  while i < n:
    print(i)
    i += 1
  return n
with tf.Graph().as_default(), tf.Session() as sess:
  sess.run(count(tf.constant(5)))


Append to lists in loops (tensor list ops are automatically created):

@autograph.convert()
def arange(n):
  """Return a tensor holding 0, 1, ..., n - 1, built by appending in a loop."""
  z = []
  # We ask you to tell us the element dtype of the list
  autograph.set_element_type(z, tf.int32)
  for i in tf.range(n):
    z.append(i)
  # when you're done with the list, stack it
  # (this is just like np.stack)
  return autograph.stack(z)

with tf.Graph().as_default(), tf.Session() as sess:
  print(sess.run(arange(tf.constant(10))))

Nested control flow

@autograph.convert()
def nearest_odd_square(x):
  """Square a positive x and add 1 if the square is even.

  Non-positive inputs are returned unchanged. Demonstrates a nested `if`.
  """
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(nearest_odd_square(tf.constant(4))))
    print(sess.run(nearest_odd_square(tf.constant(5))))
    print(sess.run(nearest_odd_square(tf.constant(6))))

While loop

@autograph.convert()
def square_until_stop(x, y):
  """Repeatedly square x while it is below y and return the final value.

  Note: the loop never ends if squaring cannot reach y (for example x == 1
  or x == 0 with a larger y).
  """
  while x < y:
    x = x * x
  return x
with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))

For loop

@autograph.convert()
def squares(nums):
  """Return a tensor of the element-wise squares of `nums`.

  The squares are built by appending to a Python list inside a `for` loop.
  The list's element type is declared as tf.int64, so `nums` should hold
  int64 values.
  """
  result = []
  autograph.set_element_type(result, tf.int64)

  for num in nums:
    result.append(num * num)
  return autograph.stack(result)
with tf.Graph().as_default():
  with tf.Session() as sess:
    # np.arange yields int64, matching the list element type in `squares`.
    print(sess.run(squares(tf.constant(np.arange(10)))))
[ 0  1  4  9 16 25 36 49 64 81]


@autograph.convert()
def argwhere_cumsum(x, threshold):
  """Return the first index i at which sum(x[:i]) >= threshold.

  The loop exits early with `break`. If the threshold is never reached, the
  last index is returned (or 0 when x is empty).
  """
  current_sum = 0.0
  idx = 0
  for i in tf.range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))
    print(sess.run(idx))

Interoperation with tf.Keras

Now that you've seen the basics, let's build some model components with AutoGraph.

It's relatively simple to integrate AutoGraph with tf.keras.

Stateless functions

For stateless functions, like `collatz` shown below, the easiest way to include them in a Keras model is to wrap them up as a layer using tf.keras.layers.Lambda.

import numpy as np

def collatz(x):
  """Return the number of Collatz steps needed to reach 1 from x.

  x is reshaped to a scalar, so it must contain exactly one element. The
  step count is returned as a shape-(1,) tensor. Written as plain Python so
  it can be converted with `autograph.to_graph`.
  """
  x = tf.reshape(x,())
  assert x > 0
  n = tf.convert_to_tensor((0,))
  while x!=1:
    n += 1
    if x%2 == 0:
      x = x // 2
    else:
      x = 3 * x + 1
  return n

with tf.Graph().as_default():
  model = tf.keras.Sequential([
    # Convert `collatz` first: its `while`/`if` test tensors, which plain
    # Python control flow cannot do in graph mode.
    tf.keras.layers.Lambda(autograph.to_graph(collatz), input_shape=(1,),
                           output_shape=())
  ])
  result = model.predict(np.array([6171]))
  print(result)

Custom Layers and Models

The easiest way to use AutoGraph with Keras layers and models is to decorate the call method with @autograph.convert(). See the TensorFlow Keras guide for details on how to build on these classes.

Here is a simple example of the stochastic network depth technique:

# `K` is used to check if we're in train or test mode.
K = tf.keras.backend

class StochasticNetworkDepth(tf.keras.Sequential):
  """A `Sequential` model that randomly skips layers in training mode.

  In training mode, layer i runs with probability `plims[i]`. Those
  probabilities are linearly spaced from `pfirst` toward `plast`; `plast`
  itself is excluded by the `[:-1]` slice. In test mode every layer runs.
  `call` returns both the output and the number of layers executed.
  """

  def __init__(self, layers, pfirst=1.0, plast=0.5, **kwargs):
    self.pfirst = pfirst
    self.plast = plast
    super(StochasticNetworkDepth, self).__init__(layers, **kwargs)

  def build(self, input_shape):
    # `input_shape` must be a `tf.TensorShape` (`.as_list()` is called on it).
    self.depth = len(self.layers)
    self.plims = np.linspace(self.pfirst, self.plast, self.depth + 1)[:-1]
    super(StochasticNetworkDepth, self).build(input_shape.as_list())

  # `call` branches on tensors (the learning phase and the random keep
  # mask), so AutoGraph must convert it to produce valid graph code.
  @autograph.convert()
  def call(self, inputs):
    training = tf.cast(K.learning_phase(), dtype=bool)
    if not training:
      count = self.depth
      return super(StochasticNetworkDepth, self).call(inputs), count
    # One uniform sample per layer; layer i runs when p[i] <= plims[i].
    p = tf.random_uniform((self.depth,))
    keeps = (p <= self.plims)
    x = inputs
    count = tf.reduce_sum(tf.cast(keeps, tf.int32))
    for i in range(self.depth):
      if keeps[i]:
        x = self.layers[i](x)
    # return both the final-layer output and the number of layers executed.
    return x, count

Let's try it on MNIST-shaped data:

# A batch of 64 random float32 arrays shaped like MNIST images (28x28x1).
train_batch = np.asarray(np.random.randn(64, 28, 28, 1), dtype=np.float32)

Build a simple stack of convolutional layers in the stochastic depth model:

with tf.Graph().as_default() as g:
  model = StochasticNetworkDepth(
      [layers.Conv2D(filters=16, activation=tf.nn.relu,
                     kernel_size=(3, 3), padding='same')
       for _ in range(20)],
      pfirst=1.0, plast=0.5)
  # Build explicitly: `StochasticNetworkDepth.build` calls `.as_list()`, so
  # it needs a TensorShape. Batch, height and width are left unspecified;
  # the inputs have a single channel.
  model.build(tf.TensorShape((None, None, None, 1)))
  init = tf.global_variables_initializer()

Now test it to ensure it behaves as expected in train and test modes:

# Use an explicit session here so we can set the train/test switch, and
# inspect the layer count returned by `call`
with tf.Session(graph=g) as sess:
  init.run()
  for phase, name in enumerate(['test','train']):
    # phase 0 = test, 1 = train; `call` reads this via K.learning_phase().
    K.set_learning_phase(phase)
    # `layer_count` (not `count`) avoids shadowing the earlier `count` demo.
    result, layer_count = model(
        tf.convert_to_tensor(train_batch, dtype=tf.float32))

    # Run the same ops twice: in train mode the random layer dropping should
    # make the outputs and layer counts differ between runs.
    result1, count1 = sess.run((result, layer_count))
    result2, count2 = sess.run((result, layer_count))

    delta = (result1 - result2)
    print(name, "sum abs delta: ", abs(delta).mean())
    print("    layers 1st call: ", count1)
    print("    layers 2nd call: ", count2)
    print()
test sum abs delta:  0.0
    layers 1st call:  20
    layers 2nd call:  20

train sum abs delta:  0.0005258316
    layers 1st call:  15
    layers 2nd call:  14

Advanced example: An in-graph training loop

The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.

Since writing control flow in AutoGraph is easy, running a training loop in a TensorFlow graph should also be easy.

This example shows how to train a simple Keras model on MNIST with the entire training process—loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating for a fixed number of steps—performed in-graph.

Download data

# Fetch MNIST as ((train_images, train_labels), (test_images, test_labels)).
_mnist_train, _mnist_test = tf.keras.datasets.mnist.load_data()
train_images, train_labels = _mnist_train
test_images, test_labels = _mnist_test
del _mnist_train, _mnist_test
Downloading data from
11493376/11490434 [==============================] - 0s 0us/step

Define the model

def mlp_model(input_shape):
  """Build a 100-100-10 dense classifier with ReLU hidden layers.

  Args:
    input_shape: shape of a single input example, passed to the first Dense
      layer's `input_shape` (e.g. `(28 * 28,)` for flattened MNIST).
  Returns:
    An uncompiled `tf.keras.Sequential` model with a softmax output.
  """
  hidden = [
      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),
      tf.keras.layers.Dense(100, activation='relu'),
  ]
  head = tf.keras.layers.Dense(10, activation='softmax')
  return tf.keras.Sequential(hidden + [head])

def predict(m, x, y):
  """Evaluate model `m` on one batch; return (mean loss, mean accuracy).

  Args:
    m: a Keras model; `x` is reshaped to (-1, 28 * 28) before calling it.
    x: a batch of images.
    y: labels in the form expected by the categorical loss/metric
      (one-hot, as produced by `get_next_batch`).
  Returns:
    Scalar tensors: mean categorical cross-entropy and mean categorical
    accuracy over the batch.
  """
  y_pred = m(tf.reshape(x, (-1, 28 * 28)))
  mean_loss = tf.reduce_mean(
      tf.keras.losses.categorical_crossentropy(y, y_pred))
  mean_accuracy = tf.reduce_mean(
      tf.keras.metrics.categorical_accuracy(y, y_pred))
  return mean_loss, mean_accuracy

def fit(m, x, y, opt):
  """Take one `opt` step on the batch's mean loss; return (loss, accuracy).

  The returned loss and accuracy are the tensors computed by `predict` for
  this batch.
  """
  l, accuracy = predict(m, x, y)
  # When `fit` is converted by AutoGraph, it automatically adds the necessary
  # `tf.control_dependencies` here.
  # (Without them nothing depends on `opt.minimize`, so it doesn't run.)
  # This makes it much more like eager-code.
  opt.minimize(l)
  return l, accuracy

def setup_mnist_data(is_training, batch_size):
  """Return a repeating, batched `tf.data.Dataset` of MNIST (image, label) pairs.

  Args:
    is_training: if True, use the (shuffled) training split; otherwise use
      the (unshuffled) test split.
    batch_size: number of examples per batch.
  """
  if is_training:
    ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  # Repeat indefinitely so the training loop never runs out of batches.
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds

def get_next_batch(ds):
  """Return one (images, one_hot_labels) batch from `ds`.

  Images are converted to float and divided by 255.0 (presumably mapping
  8-bit pixels into [0, 1]); labels are one-hot encoded over 10 classes.
  Note: each call creates a new one-shot iterator over `ds`.
  """
  iterator = ds.make_one_shot_iterator()
  images, labels = iterator.get_next()
  scaled_images = tf.to_float(images) / 255.0
  one_hot_labels = tf.one_hot(tf.squeeze(labels), 10)
  return scaled_images, one_hot_labels

Define the training loop

# Use `recursive = True` to recursively convert functions called by this one.
@autograph.convert(recursive=True)
def train(train_ds, test_ds, hp):
  """Train an MLP on `train_ds`, evaluating on `test_ds` at every step.

  Args:
    train_ds: training dataset, as returned by `setup_mnist_data`.
    test_ds: evaluation dataset, as returned by `setup_mnist_data`.
    hp: hyperparameters; `hp.learning_rate` and `hp.max_steps` are read.
  Returns:
    Four stacked tensors with one entry per step: train losses, test
    losses, train accuracies and test accuracies.
  """
  m = mlp_model((28 * 28,))
  opt = tf.train.AdamOptimizer(hp.learning_rate)
  # We'd like to save our losses to a list. In order for AutoGraph
  # to convert these lists into their graph equivalent,
  # we need to specify the element type of the lists.
  train_losses = []
  autograph.set_element_type(train_losses, tf.float32)
  test_losses = []
  autograph.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  autograph.set_element_type(train_accuracies, tf.float32)
  test_accuracies = []
  autograph.set_element_type(test_accuracies, tf.float32)
  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)

    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % 50 == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    # Record this step's metrics; these lists are what gets stacked below.
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1
  # We've recorded our loss values and accuracies
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor,
  # we need to stack them before returning them.
  return (autograph.stack(train_losses), autograph.stack(test_losses),
          autograph.stack(train_accuracies), autograph.stack(test_accuracies))

Now build the graph and run the training loop:

with tf.Graph().as_default() as g:
  hp = tf.contrib.training.HParams(
      learning_rate=0.005,
      max_steps=500,
  )
  train_ds = setup_mnist_data(True, 50)
  test_ds = setup_mnist_data(False, 1000)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = train(train_ds, test_ds, hp)

  init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
  sess.run(init)
  # One run executes the entire in-graph training loop and returns the
  # stacked per-step metrics as NumPy arrays.
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                test_accuracies])
Step 0 train loss: 2.336302 test loss: 2.3516374 train accuracy: 0.12 test accuracy: 0.091
Step 50 train loss: 0.44578597 test loss: 0.5693618 train accuracy: 0.8 test accuracy: 0.824
Step 100 train loss: 0.28974783 test loss: 0.3529715 train accuracy: 0.86 test accuracy: 0.889
Step 150 train loss: 0.594558 test loss: 0.3318378 train accuracy: 0.86 test accuracy: 0.901
Step 200 train loss: 0.22583099 test loss: 0.29405388 train accuracy: 0.92 test accuracy: 0.906
Step 250 train loss: 0.33764172 test loss: 0.26474905 train accuracy: 0.94 test accuracy: 0.919
Step 300 train loss: 0.15363316 test loss: 0.31644264 train accuracy: 0.98 test accuracy: 0.905
Step 350 train loss: 0.32430214 test loss: 0.22792776 train accuracy: 0.94 test accuracy: 0.928
Step 400 train loss: 0.17822695 test loss: 0.21533121 train accuracy: 0.94 test accuracy: 0.928
Step 450 train loss: 0.4252268 test loss: 0.22602391 train accuracy: 0.82 test accuracy: 0.929
# Draw losses and accuracies on separate figures so neither plot overwrites
# the other's title or lines.
plt.figure()
plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')

plt.figure()
plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()