AutoGraph: Easy control flow for graphs

View on TensorFlow.org · Run in Google Colab · View source on GitHub

AutoGraph helps you write complicated graph code using normal Python. Behind the scenes, AutoGraph automatically transforms your code into the equivalent TensorFlow graph code. AutoGraph already supports much of the Python language, and that coverage continues to grow. For a list of supported Python language features, see the AutoGraph capabilities and limitations document.


To use AutoGraph, install the latest version of TensorFlow:

! pip install -q -U tf-nightly

Import TensorFlow, AutoGraph, and any supporting modules:

from __future__ import division, print_function, absolute_import

import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.contrib import autograph

import numpy as np
import matplotlib.pyplot as plt

We'll enable eager execution for demonstration purposes, but AutoGraph works in both eager and graph execution environments:


Automatically convert Python control flow

AutoGraph will convert much of the Python language into the equivalent TensorFlow graph building code.

AutoGraph converts a function like:

def square_if_positive(x):
  """Return `x * x` when `x` is positive, otherwise `0.0`.

  Written as plain Python so that `autograph.to_graph` can turn the
  data-dependent `if`/`else` into graph control flow.
  """
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

To a function that uses graph building:

from __future__ import print_function
import tensorflow as tf

def tf__square_if_positive(x):
    """AutoGraph's graph-building form of `square_if_positive` (display only).

    The Python `if`/`else` is split into two branch functions whose results
    are selected by `ag__.utils.run_cond` on `tf.greater(x, 0)` (presumably a
    wrapper around `tf.cond` -- TODO confirm). `ag__` is not imported in this
    snippet, so the code is shown for illustration rather than run directly.
    """
    with tf.name_scope('square_if_positive'):

      # Branch for `x > 0`: copy the captured `x` into a local and square it.
      # Branches return 1-tuples of their outputs.
      def if_true():
        with tf.name_scope('if_true'):
          x_1, = x,
          x_1 = x_1 * x_1
          return x_1,

      # Branch for the `else` case: the result is the constant 0.0.
      def if_false():
        with tf.name_scope('if_false'):
          x_2, = x,
          x_2 = 0.0
          return x_2,
      x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
      return x

# Metadata attribute attached to the generated function (empty here).
tf__square_if_positive.autograph_info__ = {}

Code written for eager execution can run in a tf.Graph with the same results, but with the benefits of graph execution:

# Eager execution has to be switched on before any TensorFlow op is created;
# the graph examples below build their ops inside `tf.Graph().as_default()`.
tf.enable_eager_execution()
print('Eager results: %2.2f, %2.2f' % (square_if_positive(tf.constant(9.0)),
                                       square_if_positive(tf.constant(-9.0))))
Eager results: 81.00, 0.00

Generate a graph-version and call it:

# Convert once; the returned function builds graph ops each time it is called.
tf_square_if_positive = autograph.to_graph(square_if_positive)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_out1 = tf_square_if_positive(tf.constant( 9.0))
  g_out2 = tf_square_if_positive(tf.constant(-9.0))
  with tf.Session() as sess:
    print('Graph results: %2.2f, %2.2f\n' % (sess.run(g_out1),
                                             sess.run(g_out2)))
Graph results: 81.00, 0.00

AutoGraph supports common Python statements like while, for, if, break, continue, and return, with support for nesting. Compare this function with the complicated graph version displayed in the following code blocks:

# Continue in a loop
def sum_even(items):
  """Return the sum of the even elements of `items`.

  Odd elements are skipped with `continue`, which AutoGraph can convert into
  graph control flow.
  """
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s

print('Eager result: %d' % sum_even(tf.constant([10,12,15,20])))

tf_sum_even = autograph.to_graph(sum_even)

with tf.Graph().as_default(), tf.Session() as sess:
  print('Graph result: %d\n\n'
        % sess.run(tf_sum_even(tf.constant([10, 12, 15, 20]))))
Eager result: 42
Graph result: 42

from __future__ import print_function
import tensorflow as tf

def tf__sum_even(items):
    """AutoGraph's graph-building form of `sum_even` (display only).

    Each `if` becomes a pair of branch functions passed to `run_cond`, and the
    `continue` becomes a boolean `continue_` flag that guards the rest of the
    loop body. `ag__` (AutoGraph's runtime helpers) is not imported in this
    snippet, so it is not meant to be executed standalone.
    """
    with tf.name_scope('sum_even'):
      s = 0

      # Presumably the extra loop-termination check used by `for_stmt`;
      # constant True because the source loop has no `break`.
      def extra_test(s_2):
        with tf.name_scope('extra_test'):
          return True

      # One loop iteration: `c` is the current item, `s_2` the running sum.
      def loop_body(c, s_2):
        with tf.name_scope('loop_body'):
          continue_ = tf.constant(False)

          def if_true():
            with tf.name_scope('if_true'):
              continue__1, = continue_,
              continue__1 = tf.constant(True)
              return continue__1,

          def if_false():
            with tf.name_scope('if_false'):
              return continue_,
          # `if c % 2 > 0: continue` -> raise the flag for odd items.
          continue_ = ag__.utils.run_cond(tf.greater(c % 2, 0), if_true,
              if_false)

          def if_true_1():
            with tf.name_scope('if_true_1'):
              s_1, = s_2,
              s_1 += c
              return s_1,

          def if_false_1():
            with tf.name_scope('if_false_1'):
              return s_2,
          # Statements after `continue` only take effect when the flag is unset.
          s_2 = ag__.utils.run_cond(tf.logical_not(continue_), if_true_1,
              if_false_1)
          return s_2,
      s = ag__.for_stmt(items, extra_test, loop_body, (s,))
      return s

# Metadata attribute attached to the generated function (empty here).
tf__sum_even.autograph_info__ = {}


If you don't need easy access to the original Python function, use the convert decorator:

@autograph.convert()
def fizzbuzz(i, n):
  """Print the FizzBuzz word for each integer from `i` up to `n - 1`.

  Decorated with `autograph.convert()`, so calling it with tensors builds the
  equivalent graph: the `while` loop and `if` checks become graph control
  flow and `print` runs in-graph.

  Returns:
    The final value of `i`.
  """
  while i < n:
    msg = ''
    if i % 3 == 0:
      msg += 'Fizz'
    if i % 5 == 0:
      msg += 'Buzz'
    if msg == '':
      msg = tf.as_string(i)
    print(msg)
    i += 1
  return i

with tf.Graph().as_default():
  final_i = fizzbuzz(tf.constant(10), tf.constant(16))
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  with tf.Session() as sess:
    # Evaluating `final_i` runs the whole loop, including its prints.
    sess.run(final_i)



Let's demonstrate some useful Python language features.


AutoGraph automatically converts the Python assert statement into the equivalent tf.Assert code:

@autograph.convert()
def inverse(x):
  """Return `1.0 / x`, asserting that `x` is non-zero.

  Under `autograph.convert()` the `assert` becomes a `tf.Assert` op, so a zero
  input fails at `sess.run` time with `tf.errors.InvalidArgumentError`.
  """
  assert x != 0.0, 'Do not pass zero!'
  return 1.0 / x

with tf.Graph().as_default(), tf.Session() as sess:
  try:
    # Dividing by zero trips the in-graph assertion when the op runs.
    print(sess.run(inverse(tf.constant(0.0))))
  except tf.errors.InvalidArgumentError as e:
    print('Got error message:\n    %s' % e.message)
Got error message:
    assertion failed: [Do not pass zero!]
     [[{ {node inverse/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](inverse/NotEqual, inverse/Assert/Assert/data_0)]]


Use the Python print function in-graph:

@autograph.convert()
def count(n):
  """Print 0, 1, ..., n - 1 from inside the graph, then return `n`."""
  i = 0
  while i < n:
    print(i)
    i += 1
  return n
with tf.Graph().as_default(), tf.Session() as sess:
  sess.run(count(tf.constant(5)))


Append to lists in loops (tensor list ops are automatically created):

@autograph.convert()
def arange(n):
  """Return the integers `0, ..., n - 1` stacked into an int32 tensor.

  Demonstrates `list.append` inside a loop, which AutoGraph turns into
  tensor-list operations.
  """
  z = []
  # We ask you to tell us the element dtype of the list
  autograph.set_element_type(z, tf.int32)
  for i in range(n):
    z.append(i)
  # when you're done with the list, stack it
  # (this is just like np.stack)
  return autograph.stack(z)

with tf.Graph().as_default(), tf.Session() as sess:
  print(sess.run(arange(tf.constant(10))))

Nested control flow

@autograph.convert()
def nearest_odd_square(x):
  """Square a positive `x`, bumping an even square up to the next odd number.

  Non-positive `x` is returned unchanged. The nested `if`s become nested
  graph conditionals under `autograph.convert()`.
  """
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(nearest_odd_square(tf.constant(4))))
    print(sess.run(nearest_odd_square(tf.constant(5))))
    print(sess.run(nearest_odd_square(tf.constant(-1))))

While loop

@autograph.convert()
def square_until_stop(x, y):
  """Repeatedly square `x` until it is no longer below `y`; return it.

  Note: never terminates when `x < y` and `|x| <= 1`, since squaring then
  cannot grow `x` past `y`.
  """
  while x < y:
    x = x * x
  return x
with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))

For loop

@autograph.convert()
def squares(nums):
  """Return the element-wise squares of `nums` stacked into an int64 tensor."""
  result = []
  # The element dtype must be declared so the list can become a tensor list.
  autograph.set_element_type(result, tf.int64)

  for num in nums:
    result.append(num * num)
  return autograph.stack(result)
with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(squares(tf.constant(np.arange(10)))))
[ 0  1  4  9 16 25 36 49 64 81]


@autograph.convert()
def argwhere_cumsum(x, threshold):
  """Return the first index `i` with `sum(x[:i]) >= threshold`.

  The running sum covers the elements *before* `i`. If no index satisfies
  the condition, the last index `len(x) - 1` is returned (0 for empty `x`).
  Demonstrates `break`, which AutoGraph converts into a loop-exit condition.
  """
  current_sum = 0.0
  idx = 0
  for i in range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))
    print(sess.run(idx))

Interoperation with tf.Keras

Now that you've seen the basics, let's build some model components with AutoGraph.

It's relatively simple to integrate AutoGraph with tf.keras.

Stateless functions

For stateless functions, like collatz shown below, the easiest way to include them in a Keras model is to wrap them up as a layer using tf.keras.layers.Lambda.

import numpy as np

@autograph.convert()
def collatz(x):
  """Count the Collatz steps needed to bring `x` down to 1.

  Args:
    x: a positive integer tensor. It is reshaped to a scalar, so it must hold
      exactly one element.

  Returns:
    A shape-`(1,)` int32 tensor holding the step count.
  """
  x = tf.reshape(x, ())
  assert x > 0
  n = tf.convert_to_tensor((0,))
  while not tf.equal(x, 1):
    n += 1
    if tf.equal(x % 2, 0):
      x = x // 2
    else:
      x = 3 * x + 1
  return n

with tf.Graph().as_default():
  # Wrap the stateless function as a Keras layer.
  model = tf.keras.Sequential([
    tf.keras.layers.Lambda(collatz, input_shape=(1,), output_shape=())
  ])
result = model.predict(np.array([6171]))
array([261], dtype=int32)

Custom Layers and Models

The easiest way to use AutoGraph with Keras layers and models is to decorate their call method with @autograph.convert(). See the TensorFlow Keras guide for details on how to build on these classes.

Here is a simple example of the stochastic network depth technique:

# `K` is used to check if we're in train or test mode.
import tensorflow.keras.backend as K

class StocasticNetworkDepth(tf.keras.Sequential):
  """Sequential model that randomly skips layers at train time (stochastic depth).

  In training mode, layer `i` runs only if a uniform draw is `<= plims[i]`.
  The keep probabilities fall linearly from `pfirst` toward `plast`. In test
  mode every layer runs.

  Args:
    pfirst: keep probability of the first layer.
    plast: end point of the linear keep-probability schedule (the schedule
      stops one step short of it; see `build`).
    *args, **kwargs: forwarded to `tf.keras.Sequential`.
  """

  def __init__(self, pfirst=1.0, plast=0.5, *args, **kwargs):
    # Initialize the Keras base class before attaching attributes to it.
    super(StocasticNetworkDepth, self).__init__(*args, **kwargs)
    self.pfirst = pfirst
    self.plast = plast

  def build(self, input_shape):
    # Let Sequential build (and create variables for) its sublayers first, so
    # they exist before any `tf.global_variables_initializer()` is created.
    super(StocasticNetworkDepth, self).build(
        tf.TensorShape(input_shape).as_list())
    self.depth = len(self.layers)
    # `depth + 1` evenly spaced points with the last dropped: layer i keeps
    # with probability pfirst + i * (plast - pfirst) / depth.
    self.plims = np.linspace(self.pfirst, self.plast, self.depth + 1)[:-1]

  @autograph.convert()
  def call(self, inputs):
    """Run the stack; return `(output, number_of_layers_executed)`."""
    training = tf.cast(K.learning_phase(), dtype=bool)
    if not training:
      count = self.depth
      return super(StocasticNetworkDepth, self).call(inputs), count
    p = tf.random_uniform((self.depth,))
    keeps = (p <= self.plims)
    x = inputs
    count = tf.reduce_sum(tf.cast(keeps, tf.int32))
    for i in range(self.depth):
      if keeps[i]:
        x = self.layers[i](x)
    # return both the final-layer output and the number of layers executed.
    return x, count

Let's try it on MNIST-shaped data:

# A random batch of 64 single-channel 28x28 inputs (MNIST-shaped), as float32.
train_batch = np.asarray(np.random.randn(64, 28, 28, 1), dtype=np.float32)

Build a simple stack of conv layers in the stochastic depth model:

with tf.Graph().as_default() as g:
  model = StocasticNetworkDepth(
        pfirst=1.0, plast=0.5)

  for _ in range(20):
    model.add(
          layers.Conv2D(filters=16, activation=tf.nn.relu,
                        kernel_size=(3, 3), padding='same'))

  # Build explicitly so the layer variables exist before the initializer.
  model.build(tf.TensorShape((None, None, None, 1)))
  init = tf.global_variables_initializer()

Now test it to ensure it behaves as expected in train and test modes:

# Use an explicit session here so we can set the train/test switch, and
# inspect the layer count returned by `call`
with tf.Session(graph=g) as sess:
  sess.run(init)

  for phase, name in enumerate(['test','train']):
    # Learning phase 0 = test, 1 = train; `call` reads it via K.learning_phase().
    K.set_learning_phase(phase)
    result, layer_count = model(
        tf.convert_to_tensor(train_batch, dtype=tf.float32))

    # Evaluate the same tensors twice to compare two independent runs.
    result1, count1 = sess.run((result, layer_count))
    result2, count2 = sess.run((result, layer_count))

    delta = (result1 - result2)
    print(name, "sum abs delta: ", abs(delta).mean())
    print("    layers 1st call: ", count1)
    print("    layers 2nd call: ", count2)
    print()
test sum abs delta:  0.0
    layers 1st call:  20
    layers 2nd call:  20

train sum abs delta:  0.00089311943
    layers 1st call:  14
    layers 2nd call:  17

Advanced example: An in-graph training loop

The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.

Since writing control flow in AutoGraph is easy, running a training loop in a TensorFlow graph should also be easy.

This example shows how to train a simple Keras model on MNIST where the entire training process—loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence—is performed in-graph.

Download data

# Download MNIST and unpack it into train/test image and label arrays.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data())
Downloading data from
11493376/11490434 [==============================] - 0s 0us/step

Define the model

def mlp_model(input_shape):
  """Return a fully connected classifier: two 100-unit ReLU layers, 10-way softmax.

  Args:
    input_shape: shape of one input example, passed to the first `Dense`
      layer's `input_shape` (e.g. `(28 * 28,)` for flattened MNIST images).
  """
  hidden_1 = tf.keras.layers.Dense(100, activation='relu',
                                   input_shape=input_shape)
  hidden_2 = tf.keras.layers.Dense(100, activation='relu')
  output = tf.keras.layers.Dense(10, activation='softmax')
  return tf.keras.Sequential([hidden_1, hidden_2, output])

def predict(m, x, y):
  """Run model `m` on a batch and return its mean loss and mean accuracy.

  Args:
    m: a Keras model accepting inputs of shape `(-1, 28 * 28)`.
    x: a batch of images; flattened here to `(-1, 28 * 28)`.
    y: labels in the format `categorical_crossentropy` expects (one-hot, as
      produced by `get_next_batch`).

  Returns:
    `(loss, accuracy)`: scalar tensors averaged over the batch.
  """
  flat_x = tf.reshape(x, (-1, 28 * 28))
  y_pred = m(flat_x)
  mean_loss = tf.reduce_mean(
      tf.keras.losses.categorical_crossentropy(y, y_pred))
  mean_accuracy = tf.reduce_mean(
      tf.keras.metrics.categorical_accuracy(y, y_pred))
  return mean_loss, mean_accuracy

def fit(m, x, y, opt):
  """Take one optimizer step for `m` on a batch; return its loss and accuracy.

  Args:
    m: the Keras model being trained.
    x, y: a batch of images and labels, in the form `predict` expects.
    opt: a TensorFlow optimizer used to minimize the batch loss.

  Returns:
    `(loss, accuracy)` as computed by `predict` on this batch.
  """
  loss, accuracy = predict(m, x, y)
  # Autograph automatically adds the necessary `tf.control_dependencies` here.
  # (Without them nothing depends on `opt.minimize`, so it doesn't run.)
  # This makes it much more like eager-code.
  opt.minimize(loss)
  return loss, accuracy

def setup_mnist_data(is_training, batch_size):
  """Build an endlessly repeating, batched MNIST `tf.data.Dataset`.

  Reads the module-level `train_images`/`train_labels` or
  `test_images`/`test_labels` arrays loaded earlier.

  Args:
    is_training: if true, use the training split and shuffle it with a buffer
      of `batch_size * 10` examples; otherwise use the test split unshuffled.
    batch_size: number of examples per batch.
  """
  if is_training:
    ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds

def get_next_batch(ds):
  """Pull the next `(image, label)` batch from `ds` and prepare it for the MLP.

  Returns:
    `(x, y)`: the images as floats divided by 255 (MNIST pixels are presumably
    uint8 in [0, 255] -- TODO confirm), and the squeezed labels one-hot
    encoded over 10 classes.
  """
  image, label = ds.make_one_shot_iterator().get_next()
  return tf.to_float(image) / 255.0, tf.one_hot(tf.squeeze(label), 10)

Define the training loop

# Use `recursive = True` to recursively convert functions called by this one.
@autograph.convert(recursive=True)
def train(train_ds, test_ds, hp):
  """Train an MLP on MNIST entirely in-graph and return per-step histories.

  Args:
    train_ds, test_ds: datasets yielding `(image, label)` batches, as built by
      `setup_mnist_data`.
    hp: hyperparameters exposing `learning_rate` and `max_steps`.

  Returns:
    Four 1-D float32 tensors with one entry per step: train losses, test
    losses, train accuracies, and test accuracies.
  """
  m = mlp_model((28 * 28,))
  opt = tf.train.AdamOptimizer(hp.learning_rate)
  # We'd like to save our losses to a list. In order for AutoGraph
  # to convert these lists into their graph equivalent,
  # we need to specify the element type of the lists.
  train_losses = []
  autograph.set_element_type(train_losses, tf.float32)
  test_losses = []
  autograph.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  autograph.set_element_type(train_accuracies, tf.float32)
  test_accuracies = []
  autograph.set_element_type(test_accuracies, tf.float32)
  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)

    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % (hp.max_steps // 10) == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    # Record this step's metrics; without these the returned stacks are empty.
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1
  # We've recorded our loss values and accuracies
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor,
  # we need to stack them before returning them.
  return (autograph.stack(train_losses), autograph.stack(test_losses),
          autograph.stack(train_accuracies), autograph.stack(test_accuracies))

Now build the graph and run the training loop:

with tf.Graph().as_default() as g:
  hp = tf.contrib.training.HParams(
      learning_rate=0.005,
      max_steps=500,
  )
  train_ds = setup_mnist_data(True, 50)
  test_ds = setup_mnist_data(False, 1000)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = train(train_ds, test_ds, hp)

  # Created after `train` so it also covers the model and optimizer variables.
  init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
  sess.run(init)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                test_accuracies])

plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.show()

plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()
Step 0 train loss: 2.3754308 test loss: 2.3728185 train accuracy: 0.08 test accuracy: 0.124
Step 50 train loss: 0.2660943 test loss: 0.41135326 train accuracy: 0.88 test accuracy: 0.875
Step 100 train loss: 0.45980245 test loss: 0.33724102 train accuracy: 0.84 test accuracy: 0.892
Step 150 train loss: 0.28456023 test loss: 0.32270715 train accuracy: 0.92 test accuracy: 0.895
Step 200 train loss: 0.1482509 test loss: 0.2729213 train accuracy: 0.94 test accuracy: 0.915
Step 250 train loss: 0.36422098 test loss: 0.32462135 train accuracy: 0.88 test accuracy: 0.901
Step 300 train loss: 0.15080144 test loss: 0.26890406 train accuracy: 0.92 test accuracy: 0.916
Step 350 train loss: 0.20911835 test loss: 0.18777144 train accuracy: 0.94 test accuracy: 0.934
Step 400 train loss: 0.3479176 test loss: 0.21253592 train accuracy: 0.92 test accuracy: 0.929
Step 450 train loss: 0.14176667 test loss: 0.1943164 train accuracy: 0.96 test accuracy: 0.947