AutoGraph: Easy control flow for graphs

View on Run in Google Colab View source on GitHub

AutoGraph helps you write complicated graph code using normal Python. Behind the scenes, AutoGraph automatically transforms your code into the equivalent TensorFlow graph code. AutoGraph already supports much of the Python language, and that coverage continues to grow. For a list of supported Python language features, see the Autograph capabilities and limitations.


To use AutoGraph, install the latest version of TensorFlow:

! pip install -q -U tf-nightly

Import TensorFlow, AutoGraph, and any supporting modules:

from __future__ import division, print_function, absolute_import

import tensorflow as tf
layers = tf.keras.layers
from tensorflow.contrib import autograph

import numpy as np
import matplotlib.pyplot as plt

We'll enable eager execution for demonstration purposes by calling `tf.enable_eager_execution()` once, right after the imports; AutoGraph itself works in both eager and graph execution environments.


Automatically convert Python control flow

AutoGraph will convert much of the Python language into the equivalent TensorFlow graph building code.

AutoGraph converts a function like:

def square_if_positive(x):
  """Return `x` squared when `x` is positive, otherwise `0.0`.

  Args:
    x: The value to test and square.

  Returns:
    `x * x` if `x > 0`, else `0.0`.
  """
  if x > 0:
    x = x * x
  else:
    # Non-positive inputs map to zero.
    x = 0.0
  return x

To a function that uses graph building:

from __future__ import print_function
import tensorflow as tf

def tf__square_if_positive(x):
    # Graph-building counterpart of `square_if_positive`, shown here as the
    # code AutoGraph produces. NOTE: `ag__` is not defined in this listing;
    # presumably it is supplied by AutoGraph when the code is generated.
    with tf.name_scope('square_if_positive'):

      # Each branch of the original `if` becomes a function that works on its
      # own copy of `x` and returns the result as a 1-tuple.
      def if_true():
        with tf.name_scope('if_true'):
          x_1, = x,
          x_1 = x_1 * x_1
          return x_1,

      def if_false():
        with tf.name_scope('if_false'):
          x_2, = x,
          x_2 = 0.0
          return x_2,
      # The Python test `x > 0` is expressed as `tf.greater(x, 0)` and handed,
      # together with both branch functions, to `run_cond`.
      x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
      return x

tf__square_if_positive.autograph_info__ = {}

Code written for eager execution can run in a tf.Graph with the same results, but with the benefits of graph execution:

# With eager execution enabled the plain Python function accepts tensors
# directly; one positive and one negative input exercise both branches.
print('Eager results: %2.2f, %2.2f' % (square_if_positive(tf.constant(9.0)),
                                       square_if_positive(tf.constant(-9.0))))
Eager results: 81.00, 0.00

Generate a graph-version and call it:

tf_square_if_positive = autograph.to_graph(square_if_positive)

with tf.Graph().as_default():  
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_out1 = tf_square_if_positive(tf.constant( 9.0))
  g_out2 = tf_square_if_positive(tf.constant(-9.0))
  with tf.Session() as sess:
    # The outputs are graph tensors, so evaluate them in the session before
    # formatting.
    print('Graph results: %2.2f, %2.2f\n' % (sess.run(g_out1), sess.run(g_out2)))
Graph results: 81.00, 0.00

AutoGraph supports common Python statements like while, for, if, break, and return, with support for nesting. Compare this function with the complicated graph version displayed in the following code blocks:

# Continue in a loop
def sum_even(items):
  """Return the sum of the even elements of `items`.

  Args:
    items: An iterable of integers.

  Returns:
    The total of all elements `c` for which `c % 2` is not greater than 0;
    `0` when there are none.
  """
  s = 0
  for c in items:
    if c % 2 > 0:
      # Odd element: skip it.
      continue
    s += c
  return s

print('Eager result: %d' % sum_even(tf.constant([10,12,15,20])))

tf_sum_even = autograph.to_graph(sum_even)

with tf.Graph().as_default(), tf.Session() as sess:
    # Same input as the eager call above, evaluated through the graph version.
    print('Graph result: %d\n\n' % sess.run(tf_sum_even(tf.constant([10,12,15,20]))))
Eager result: 42
Graph result: 42

from __future__ import print_function
import tensorflow as tf

def tf__sum_even(items):
    # Graph-building counterpart of `sum_even`, shown here as the code
    # AutoGraph produces. NOTE: `ag__` is not defined in this listing;
    # presumably it is supplied by AutoGraph when the code is generated.
    with tf.name_scope('sum_even'):
      s = 0

      # Additional loop condition; always true here.
      def extra_test(s_2):
        with tf.name_scope('extra_test'):
          return True

      # One iteration of the original `for` body: takes the current element
      # and the running sum, returns the updated sum as a 1-tuple.
      def loop_body(loop_vars, s_2):
        with tf.name_scope('loop_body'):
          c = loop_vars
          # The `continue` statement is modelled as a boolean flag.
          continue_ = tf.constant(False)

          def if_true():
            with tf.name_scope('if_true'):
              continue__1, = continue_,
              continue__1 = tf.constant(True)
              return continue__1,

          def if_false():
            with tf.name_scope('if_false'):
              return continue_,
          # The flag becomes True when `c % 2 > 0`.
          continue_ = ag__.utils.run_cond(tf.greater(c % 2, 0), if_true, if_false)

          def if_true_1():
            with tf.name_scope('if_true_1'):
              s_1, = s_2,
              s_1 += c
              return s_1,

          def if_false_1():
            with tf.name_scope('if_false_1'):
              return s_2,
          # The rest of the body (`s += c`) is guarded by the negated flag.
          s_2 = ag__.utils.run_cond(tf.logical_not(continue_), if_true_1, if_false_1)
          return s_2,
      # `(s,)` is the initial loop state passed to `for_stmt`.
      s = ag__.for_stmt(items, extra_test, loop_body, (s,))
      return s

tf__sum_even.autograph_info__ = {}


If you don't need easy access to the original Python function, use the convert decorator:

@autograph.convert()
def fizzbuzz(i, n):
  """Print the FizzBuzz line for every integer from `i` up to, but not including, `n`.

  The `autograph.convert()` decorator replaces the function with its
  converted version, so it can be called directly while building a graph.

  Args:
    i: First value to process.
    n: Exclusive upper bound.

  Returns:
    The value of `i` after the loop has finished.
  """
  while i < n:
    msg = ''
    if i % 3 == 0:
      msg += 'Fizz'
    if i % 5 == 0:
      msg += 'Buzz'
    if msg == '':
      # Neither multiple applies: fall back to the number itself.
      msg = tf.as_string(i)
    print(msg)
    i += 1
  return i

with tf.Graph().as_default():
  final_i = fizzbuzz(tf.constant(10), tf.constant(16))
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  with tf.Session() as sess:
    sess.run(final_i)



Let's demonstrate some useful Python language features.


AutoGraph automatically converts the Python assert statement into the equivalent tf.Assert code:

@autograph.convert()
def inverse(x):
  """Return `1.0 / x`, asserting that `x` is not zero.

  Args:
    x: The value to invert.

  Returns:
    The reciprocal of `x`.
  """
  assert x != 0.0, 'Do not pass zero!'
  return 1.0 / x

with tf.Graph().as_default(), tf.Session() as sess:
  # Passing zero trips the converted assertion when the graph is run.
  try:
    print(sess.run(inverse(tf.constant(0.0))))
  except tf.errors.InvalidArgumentError as e:
    print('Got error message:\n    %s' % e.message)
Got error message:
    assertion failed: [Do not pass zero!]
     [[node inverse/Assert/Assert (defined at /tmp/  = Assert[T=[DT_STRING], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](inverse/NotEqual, inverse/Assert/Assert/data_0)]]


Use the Python print function in-graph:

@autograph.convert()
def count(n):
  """Print the integers from 0 up to, but not including, `n`.

  Args:
    n: Exclusive upper bound of the count.

  Returns:
    `n`, unchanged.
  """
  i = 0
  while i < n:
    print(i)
    i += 1
  return n

with tf.Graph().as_default(), tf.Session() as sess:
  sess.run(count(tf.constant(5)))


Append to lists in loops (tensor list ops are automatically created):

@autograph.convert()
def arange(n):
  """Collect the values of `tf.range(n)` in a list and return them stacked.

  Args:
    n: Exclusive upper bound passed to `tf.range`.

  Returns:
    The result of `autograph.stack` on the collected `tf.int32` elements.
  """
  z = []
  # We ask you to tell us the element dtype of the list
  autograph.set_element_type(z, tf.int32)
  for i in tf.range(n):
    z.append(i)
  # when you're done with the list, stack it
  # (this is just like np.stack)
  return autograph.stack(z) 

with tf.Graph().as_default(), tf.Session() as sess:
  sess.run(arange(tf.constant(10)))

Nested control flow

@autograph.convert()
def nearest_odd_square(x):
  """Square a positive `x`, adding one if the square is even.

  Args:
    x: The input value.

  Returns:
    For `x > 0`: `x * x`, incremented by one when that square is even.
    Otherwise `x` unchanged.
  """
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():  
  with tf.Session() as sess:
    # An even and an odd input exercise both paths of the nested `if`.
    print(sess.run(nearest_odd_square(tf.constant(4))))
    print(sess.run(nearest_odd_square(tf.constant(5))))

While loop

@autograph.convert()
def square_until_stop(x, y):
  """Square `x` repeatedly while it is smaller than `y`.

  Args:
    x: Starting value.
    y: Threshold at which squaring stops.

  Returns:
    The first value in the sequence x, x*x, ... that is not less than `y`.
  """
  while x < y:
    x = x * x
  return x

with tf.Graph().as_default():  
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))

For loop

@autograph.convert()
def squares(nums):
  """Return the element-wise squares of `nums`, stacked.

  Args:
    nums: The values to iterate over and square.

  Returns:
    The result of `autograph.stack` on the list of squared `tf.int64`
    elements.
  """
  result = []
  autograph.set_element_type(result, tf.int64)

  for num in nums: 
    result.append(num * num)
  return autograph.stack(result)

with tf.Graph().as_default():  
  with tf.Session() as sess:
    print(sess.run(squares(tf.constant(np.arange(10)))))
[ 0  1  4  9 16 25 36 49 64 81]


@autograph.convert()
def argwhere_cumsum(x, threshold):
  """Return the index at which the running sum of `x` reaches `threshold`.

  Args:
    x: The values to accumulate, indexable by position.
    threshold: Value the running sum must reach to stop the scan.

  Returns:
    The first index `i` at which the sum of the elements before `i` is at
    least `threshold`; the last index if the threshold is never reached.
  """
  current_sum = 0.0
  idx = 0
  for i in tf.range(len(x)):
    idx = i
    if current_sum >= threshold:
      # Threshold reached: stop scanning and keep the current index.
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():  
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))
    print(sess.run(idx))

Interoperation with tf.Keras

Now that you've seen the basics, let's build some model components with autograph.

It's relatively simple to integrate autograph with tf.keras.

Stateless functions

For stateless functions, like collatz shown below, the easiest way to include them in a keras model is to wrap them up as a layer using tf.keras.layers.Lambda.

import numpy as np

@autograph.convert()
def collatz(x):
  """Count the Collatz steps needed to bring `x` down to 1.

  Args:
    x: A positive integer tensor; it is reshaped to a scalar first.

  Returns:
    A one-element tensor holding the number of steps taken.
  """
  x = tf.reshape(x,())
  assert x > 0
  n = tf.convert_to_tensor((0,)) 
  while not tf.equal(x, 1):
    n += 1
    if tf.equal(x%2, 0):
      x = x // 2
    else:
      x = 3 * x + 1
  return n

with tf.Graph().as_default():
  # Wrap the stateless function in a Lambda layer so Keras can call it.
  model = tf.keras.Sequential([
    tf.keras.layers.Lambda(collatz, input_shape=(1,), output_shape=())
  ])

result = model.predict(np.array([6171]))
array([261], dtype=int32)

Custom Layers and Models

The easiest way to use AutoGraph with Keras layers and models is to @autograph.convert() the call method. See the TensorFlow Keras guide for details on how to build on these classes.

Here is a simple example of the stochastic network depth technique:

# `K` is used to check if we're in train or test mode.
K = tf.keras.backend

class StochasticNetworkDepth(tf.keras.Sequential):
  """A `tf.keras.Sequential` that randomly skips layers while training.

  During training, layer `i` runs only when a uniform random draw is at most
  `plims[i]`. `plims` is spaced linearly from `pfirst` towards `plast`, with
  the `plast` endpoint itself dropped. Outside of training every layer runs.

  `call` returns a pair: the output and the number of layers executed.
  """

  def __init__(self, pfirst=1.0, plast=0.5, *args,**kwargs):
    """Store the keep-probability endpoints and initialize the base model.

    Args:
      pfirst: Keep probability for the first layer.
      plast: Endpoint the keep probabilities are spaced towards.
      *args: Forwarded to `tf.keras.Sequential`.
      **kwargs: Forwarded to `tf.keras.Sequential`.
    """
    self.pfirst = pfirst
    self.plast = plast
    super(StochasticNetworkDepth, self).__init__(*args, **kwargs)

  def build(self,input_shape):
    """Build the underlying model and derive the per-layer keep probabilities."""
    # NOTE(review): assumes `input_shape` is a `tf.TensorShape`; confirm
    # against callers.
    super(StochasticNetworkDepth, self).build(input_shape.as_list())
    self.depth = len(self.layers)
    self.plims = np.linspace(self.pfirst, self.plast, self.depth + 1)[:-1]

  @autograph.convert()
  def call(self, inputs):
    """Run the layers, skipping some at random when training.

    Returns:
      A tuple `(output, count)`, where `count` is the number of layers that
      were executed.
    """
    training = tf.cast(K.learning_phase(), dtype=bool)  
    if not training: 
      count = self.depth
      return super(StochasticNetworkDepth, self).call(inputs), count
    p = tf.random_uniform((self.depth,))
    keeps = (p <= self.plims)
    x = inputs
    count = tf.reduce_sum(tf.cast(keeps, tf.int32))
    for i in range(self.depth):
      if keeps[i]:
        x = self.layers[i](x)
    # return both the final-layer output and the number of layers executed.
    return x, count

Let's try it on mnist-shaped data:

# A batch of 64 random float32 arrays of shape 28x28x1, standing in for MNIST
# images.
train_batch = np.random.randn(64, 28, 28, 1).astype(np.float32)

Build a simple stack of conv layers, in the stochastic depth model:

with tf.Graph().as_default() as g:
  model = StochasticNetworkDepth(
        pfirst=1.0, plast=0.5)

  # Stack 20 identical conv layers.
  for n in range(20):
    model.add(
          layers.Conv2D(filters=16, activation=tf.nn.relu,
                        kernel_size=(3, 3), padding='same'))

  # Batch and spatial dimensions are left unspecified; inputs have 1 channel.
  model.build(tf.TensorShape((None, None, None, 1)))

  init = tf.global_variables_initializer()

Now test it to ensure it behaves as expected in train and test modes:

# Use an explicit session here so we can set the train/test switch, and
# inspect the layer count returned by `call`
with tf.Session(graph=g) as sess:
  init.run()

  for phase, name in enumerate(['test','train']):
    # `phase` is 0 for 'test' and 1 for 'train'.
    K.set_learning_phase(phase)
    result, count = model(tf.convert_to_tensor(train_batch, dtype=tf.float32))

    # Evaluate the same tensors twice so the two runs can be compared.
    result1, count1 = sess.run((result, count))
    result2, count2 = sess.run((result, count))

    delta = (result1 - result2)
    print(name, "sum abs delta: ", abs(delta).mean())
    print("    layers 1st call: ", count1)
    print("    layers 2nd call: ", count2)
    print()
test sum abs delta:  0.0
    layers 1st call:  20
    layers 2nd call:  20

train sum abs delta:  0.0004654545
    layers 1st call:  17
    layers 2nd call:  15

Advanced example: An in-graph training loop

The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.

Since writing control flow in AutoGraph is easy, running a training loop in a TensorFlow graph should also be easy.

This example shows how to train a simple Keras model on MNIST, where the entire training process—loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence—is performed in-graph.

Download data

# Load the MNIST train and test splits; these module-level names are what the
# data-setup code below is expected to read.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
Downloading data from
11493376/11490434 [==============================] - 0s 0us/step

Define the model

def mlp_model(input_shape):
  """Build a small fully connected classifier.

  Args:
    input_shape: Forwarded as `input_shape` to the first `Dense` layer.

  Returns:
    A `tf.keras.Sequential` made of two 100-unit ReLU layers followed by a
    10-unit softmax layer.
  """
  first_hidden = tf.keras.layers.Dense(
      100, activation='relu', input_shape=input_shape)
  second_hidden = tf.keras.layers.Dense(100, activation='relu')
  class_probs = tf.keras.layers.Dense(10, activation='softmax')
  return tf.keras.Sequential((first_hidden, second_hidden, class_probs))

def predict(m, x, y):
  """Run model `m` on `x` and score the result against labels `y`.

  Args:
    m: A callable model; it receives `x` reshaped to rows of 28 * 28 values.
    x: Input batch.
    y: Target labels, compared with the model output.

  Returns:
    A tuple `(loss, accuracy)`: the mean categorical cross-entropy and the
    mean categorical accuracy over the batch.
  """
  flat_inputs = tf.reshape(x, (-1, 28 * 28))
  predictions = m(flat_inputs)
  mean_loss = tf.reduce_mean(
      tf.keras.losses.categorical_crossentropy(y, predictions))
  mean_accuracy = tf.reduce_mean(
      tf.keras.metrics.categorical_accuracy(y, predictions))
  return mean_loss, mean_accuracy

def fit(m, x, y, opt):
  """Evaluate model `m` on one batch and apply one optimizer step.

  Args:
    m: The model, passed on to `predict`.
    x: Input batch.
    y: Target labels.
    opt: Optimizer whose `minimize` is called on the batch loss.

  Returns:
    The `(loss, accuracy)` pair computed by `predict` for this batch.
  """
  l, accuracy = predict(m, x, y)
  # Autograph automatically adds the necessary `tf.control_dependencies` here.
  # (Without them nothing depends on `opt.minimize`, so it doesn't run.)
  # This makes it much more like eager-code.
  opt.minimize(l)
  return l, accuracy

def setup_mnist_data(is_training, batch_size):
  """Create a repeating, batched dataset from the MNIST arrays.

  Reads the module-level `train_images`/`train_labels` or
  `test_images`/`test_labels`.

  Args:
    is_training: If true, use the training split and shuffle it; otherwise
      use the test split unshuffled.
    batch_size: Number of examples per batch. The shuffle buffer holds
      `batch_size * 10` examples.

  Returns:
    The dataset after `repeat()` and `batch(batch_size)`.
  """
  if is_training:
    ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds

def get_next_batch(ds):
  """Fetch one `(x, y)` batch from dataset `ds`.

  Args:
    ds: A dataset providing `make_one_shot_iterator()`; each element is an
      `(image, label)` pair.

  Returns:
    A tuple `(x, y)`: the images cast to float and divided by 255.0, and the
    squeezed labels one-hot encoded with depth 10.
  """
  iterator = ds.make_one_shot_iterator()
  image, label = iterator.get_next()
  scaled_image = tf.to_float(image) / 255.0
  one_hot_label = tf.one_hot(tf.squeeze(label), 10)
  return scaled_image, one_hot_label

Define the training loop

# Use `recursive = True` to recursively convert functions called by this one.
@autograph.convert(recursive=True)
def train(train_ds, test_ds, hp):
  """Train an MLP for `hp.max_steps` steps, recording metrics at every step.

  Args:
    train_ds: Dataset that training batches are drawn from.
    test_ds: Dataset that evaluation batches are drawn from.
    hp: Hyperparameters; `learning_rate` and `max_steps` are read.

  Returns:
    A tuple of four stacked tensors: train losses, test losses, train
    accuracies and test accuracies, one entry per step.
  """
  m = mlp_model((28 * 28,))
  opt = tf.train.AdamOptimizer(hp.learning_rate)
  # We'd like to save our losses to a list. In order for AutoGraph
  # to convert these lists into their graph equivalent,
  # we need to specify the element type of the lists.
  train_losses = []
  autograph.set_element_type(train_losses, tf.float32)
  test_losses = []
  autograph.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  autograph.set_element_type(train_accuracies, tf.float32)
  test_accuracies = []
  autograph.set_element_type(test_accuracies, tf.float32)
  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)

    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    # Report progress ten times over the course of training.
    if i % (hp.max_steps // 10) == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1
  # We've recorded our loss values and accuracies 
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor, 
  # we need to stack them before returning them.
  return (autograph.stack(train_losses), autograph.stack(test_losses),  
          autograph.stack(train_accuracies), autograph.stack(test_accuracies))

Now build the graph and run the training loop:

with tf.Graph().as_default() as g:
  # `max_steps=500` matches the progress lines printed every 50 steps.
  # NOTE(review): the learning-rate value is a reconstruction; confirm it.
  hp = tf.contrib.training.HParams(
      learning_rate=0.005,
      max_steps=500,
  )
  train_ds = setup_mnist_data(True, 50)
  test_ds = setup_mnist_data(False, 1000)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = train(train_ds, test_ds, hp)

  init = tf.global_variables_initializer()
with tf.Session(graph=g) as sess:
  sess.run(init)
  # Running the stacked tensors executes the whole in-graph training loop and
  # replaces the tensor handles with their evaluated values.
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                test_accuracies])
# Losses and accuracies go on separate figures: `plt.show()` finishes the
# first chart before the second one is started.
plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.show()
plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()
Step 0 train loss: 2.359655 test loss: 2.3189092 train accuracy: 0.06 test accuracy: 0.105
Step 50 train loss: 0.59539515 test loss: 0.57915956 train accuracy: 0.78 test accuracy: 0.825
Step 100 train loss: 0.59692603 test loss: 0.33623025 train accuracy: 0.86 test accuracy: 0.89
Step 150 train loss: 0.37055424 test loss: 0.30959117 train accuracy: 0.88 test accuracy: 0.903
Step 200 train loss: 0.24066256 test loss: 0.306852 train accuracy: 0.94 test accuracy: 0.897
Step 250 train loss: 0.34461474 test loss: 0.27985114 train accuracy: 0.9 test accuracy: 0.914
Step 300 train loss: 0.2871255 test loss: 0.2508909 train accuracy: 0.92 test accuracy: 0.927
Step 350 train loss: 0.14871798 test loss: 0.21972342 train accuracy: 0.96 test accuracy: 0.927
Step 400 train loss: 0.34599018 test loss: 0.22240749 train accuracy: 0.88 test accuracy: 0.925
Step 450 train loss: 0.23274632 test loss: 0.2049714 train accuracy: 0.92 test accuracy: 0.935