Better performance with tf.function


In TensorFlow 2.0, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

To get peak performance and to make your model deployable anywhere, use tf.function to make graphs out of your programs. Thanks to AutoGraph, a surprising amount of Python code just works with tf.function, but there are still pitfalls to be wary of.

The main takeaways and recommendations are:

  • Don't rely on Python side effects like object mutation or list appends.
  • tf.function works best with TensorFlow ops, rather than NumPy ops or Python primitives.
  • When in doubt, use the for x in y idiom.
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}: {}'.format(error_class, e))
  except Exception as e:
    print('Got unexpected exception \n  {}: {}'.format(type(e), e))
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

A tf.function you define is just like a core TensorFlow operation: You can execute it eagerly; you can use it in a graph; it has gradients; and so on.

# A function is like an op

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
# Functions have gradients

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: id=38, shape=(), dtype=float32, numpy=1.0>
# You can use functions inside functions

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: id=64, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Tracing and polymorphism

Python's dynamic typing means that you can call functions with a variety of argument types, and Python will do something different in each scenario.

On the other hand, TensorFlow graphs require static dtypes and shape dimensions. tf.function bridges this gap by retracing the function when necessary to generate the correct graphs. Most of the subtlety of tf.function usage stems from this retracing behavior.

You can call a function with arguments of different types to see what is happening.

# Functions are polymorphic

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

To control the tracing behavior, use the following techniques:

  • Create a new tf.function. Separate tf.function objects are guaranteed not to share traces (see the sketch after the examples below).
  • Use the get_concrete_function method to get a specific trace.
  • Specify input_signature when calling tf.function to trace only once for all inputs compatible with that signature.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Obtaining concrete trace
Tracing with Tensor("a:0", dtype=string)
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_91 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_91]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
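
The first technique above can be sketched as follows (the square helper here is illustrative, not one of the original examples): wrapping the same Python function in two separate tf.function objects gives each wrapper its own, independent trace cache.

def square(a):
  print("Tracing with", a)
  return a * a

square_1 = tf.function(square)
square_2 = tf.function(square)

square_1(tf.constant(2))  # Traces.
square_1(tf.constant(3))  # Reuses square_1's trace (same shape and dtype).
square_2(tf.constant(4))  # Traces again: square_2 does not share square_1's cache.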

When to retrace?

A polymorphic tf.function keeps a cache of concrete functions generated by tracing. The cache keys are effectively tuples of keys generated from the function args and kwargs. The key generated for a tf.Tensor argument is its shape and type. The key generated for a Python primitive is its value. For all other Python types, the keys are based on the object id() so that methods are traced independently for each instance of a class. In the future, TensorFlow may add more sophisticated caching for Python objects that can be safely converted to tensors.
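
As a rough sketch of these rules (the function name here is illustrative), a Python integer argument is keyed by its value and so retraces for each new value, while scalar tensors of the same shape and dtype share a single trace:

@tf.function
def double_it(a):
  print("Tracing with", a)
  return a + a

double_it(1)               # Traces for the Python value 1.
double_it(2)               # Retraces: a different Python value means a new cache key.
double_it(1)               # Reuses the existing trace for the value 1.

double_it(tf.constant(1))  # Traces once for an int32 scalar tensor.
double_it(tf.constant(2))  # Reuses that trace: same shape and dtype.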

Python or Tensor args?

Often, Python arguments are used to control hyperparameters and graph construction - for example, num_layers=10 or training=True or nonlinearity='relu'. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so this is a bit inefficient.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)
Tracing with num_steps = 10
Tracing with num_steps = 20

The simple workaround here is to cast your arguments to Tensors if they do not affect the shape of the generated graph.

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

Side effects in tf.function

In general, Python side effects (like printing or mutating objects) only happen during tracing. So how can you reliably trigger side effects from tf.function?

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like tf.Variable.assign, tf.print, and tf.summary are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call. In general using a functional style will yield the best results.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an escape hatch. The drawback of tf.py_function is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1
Python side effect
Python side effect
Python side effect

Beware of Python state

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in Eager mode, many unexpected things can happen inside a tf.function due to tracing behavior.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

If an iterator is generated and consumed entirely within the tf.function, then it should work correctly. However, the entire iterator is probably being traced, which can lead to a giant graph. This may be what you want. But if you're training on a large in-memory dataset represented as a Python list, then this can generate a very large graph, and tf.function is unlikely to yield a speedup.
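
For instance, here is a minimal sketch (not from the original notebook) of an iterator created and consumed entirely inside a tf.function; the loop is simply unrolled into the graph during tracing, which is fine for a handful of elements but scales poorly:

@tf.function
def sum_all(values):
  total = tf.constant(0)
  # The iterator is created and exhausted during tracing, so each element
  # becomes its own set of nodes in the graph (static unrolling).
  for v in iter(values):
    total += v
  return total

sum_all([1, 2, 3])  # tf.Tensor(6)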

If you want to iterate over Python data, the safest way is to wrap it in a tf.data.Dataset and use the for x in y idiom. AutoGraph has special support for safely converting for loops when y is a tensor or tf.data.Dataset.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph

When wrapping Python/Numpy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors. The former will keep the data in Python and fetch it via tf.py_function which can have performance implications, whereas the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications.
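
As a rough sketch of that trade-off (the variable names here are illustrative; note that from_tensors wraps the whole list as a single dataset element):

python_data = [(1, 1)] * 10

# from_generator keeps the data in Python; each element is fetched through
# tf.py_function when the dataset is iterated, which costs performance.
ds_gen = tf.data.Dataset.from_generator(
    lambda: python_data, (tf.int32, tf.int32))

# from_tensors embeds a copy of the data as one large constant node inside
# the consuming graph, which costs memory. Here the whole list becomes a
# single element of shape (10, 2).
ds_const = tf.data.Dataset.from_tensors(tf.constant(python_data))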

Reading data from files via TFRecordDataset/CsvDataset/etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python.
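
A minimal sketch, assuming TFRecord files already exist at the (hypothetical) paths below; TensorFlow streams and prefetches the records without bouncing back to Python:

# The file paths here are hypothetical placeholders.
dataset = (tf.data.TFRecordDataset(["/path/to/train-00000.tfrecord",
                                    "/path/to/train-00001.tfrecord"])
           .batch(32)
           .prefetch(tf.data.experimental.AUTOTUNE))

@tf.function
def count_records(dataset):
  n = tf.constant(0, dtype=tf.int64)
  for batch in dataset:  # Each batch is a string tensor of serialized records.
    n += tf.cast(tf.shape(batch)[0], tf.int64)
  return n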

Automatic Control Dependencies

A very appealing property of functions as a programming model, compared to a general dataflow graph, is that functions can give the runtime more information about the code's intended behavior.

For example, when writing code which has multiple reads and writes to the same variables, a dataflow graph might not naturally encode the originally intended order of operations. In tf.function, we resolve ambiguities in execution order by referring to the execution order of statements in the original Python code. This way, ordering of stateful operations in a tf.function replicates the semantics of Eager mode.

This means there's no need to add manual control dependencies; tf.function is smart enough to add the minimal set of necessary and sufficient control dependencies for your code to run correctly.

# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0
<tf.Tensor: id=418, shape=(), dtype=float32, numpy=10.0>

Variables

We can use the same idea of leveraging the intended execution order of the code to make variable creation and utilization very easy in tf.function. There is one very important caveat, though, which is that with variables it's possible to write code which behaves differently in eager mode and graph mode.

Specifically, this will happen when you create a new Variable with each call. Due to tracing semantics, tf.function will reuse the same variable each call, but eager mode will create a new variable with each call. To guard against this mistake, tf.function will raise an error if it detects dangerous variable creation behavior.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

# Non-ambiguous code is ok though

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
# You can also create variables inside a tf.function as long as we can prove
# that those variables are created only the first time the function is executed.

class C: pass
obj = C(); obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
# Variable initializers can depend on function arguments and on values of other
# variables. We can figure out the right initialization order using the same
# method we use to generate control dependencies.

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)

Using AutoGraph

The autograph library is fully integrated with tf.function, and it will rewrite conditionals and loops which depend on Tensors to run dynamically in the graph.

tf.cond and tf.while_loop continue to work with tf.function, but code with control flow is often easier to write and understand when written in imperative style.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.991675735 0.770152 0.265438318 0.926479578 0.270098567]
[0.758075953 0.647017837 0.259375095 0.728948355 0.263716549]
[0.639942288 0.569659 0.253710955 0.622421563 0.257768452]
[0.564860284 0.515108824 0.248403832 0.552811861 0.2522071]
[0.511574626 0.473916173 0.24341765 0.50262475 0.246992245]
[0.471171141 0.441358089 0.238721251 0.46417886 0.242089257]
[0.439145088 0.41476953 0.23428756 0.433484 0.237468168]
[0.412935585 0.392514914 0.230092898 0.408228815 0.233102903]
[0.390962422 0.373526275 0.226116508 0.386967748 0.228970662]
[0.372189641 0.357072234 0.222340047 0.368743211 0.225051373]
[0.355905473 0.342632562 0.218747273 0.352891922 0.22132732]
[0.341602385 0.32982561 0.215323746 0.33893773 0.217782795]
[0.328907162 0.318364054 0.212056547 0.326528698 0.214403793]
[0.31753847 0.30802694 0.208934113 0.315398216 0.211177796]
[0.307279497 0.298641056 0.205946043 0.305340052 0.208093569]
[0.297960132 0.290068477 0.203082964 0.296191841 0.205141023]
[0.289444745 0.282197833 0.200336367 0.287823766 0.202311009]
[0.281623662 0.274938 0.197698563 0.280130565 0.199595287]
[0.274407148 0.26821363 0.19516255 0.2730259 0.196986347]
[0.267720908 0.261961818 0.192721918 0.266438186 0.194477364]
[0.261502862 0.256129563 0.190370843 0.260307461 0.192062095]
[0.255700678 0.250671834 0.188103959 0.254583091 0.189734846]
[0.25026986 0.245550096 0.185916349 0.249221966 0.187490389]
[0.245172322 0.24073115 0.183803499 0.244187161 0.185323924]
[0.240375236 0.236186236 0.181761235 0.239446774 0.183231026]
[0.23585014 0.231890276 0.179785714 0.234973133 0.181207627]
[0.231572226 0.22782129 0.177873373 0.230742082 0.179249957]
[0.227519721 0.223959938 0.176020905 0.226732403 0.177354515]
[0.223673478 0.220289111 0.174225256 0.22292541 0.175518081]
[0.220016524 0.216793597 0.172483563 0.219304562 0.173737645]

<tf.Tensor: id=675, shape=(5,), dtype=float32, numpy=
array([0.21653381, 0.21345986, 0.17079319, 0.21585512, 0.17201042],
      dtype=float32)>
# If you're curious you can inspect the code autograph generates.
# It feels like reading assembly language, though.

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))
def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'f_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as f_scope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, f_scope.callopts, (x,), None, f_scope)
      x = ag__.converted_call(tf.tanh, f_scope.callopts, (x,), None, f_scope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, f_scope.callopts, (x,), None, f_scope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = f_scope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)

AutoGraph: Conditionals

AutoGraph will convert if statements into the equivalent tf.cond calls.

This substitution is made if the condition is a Tensor. Otherwise, the conditional is executed during tracing.

def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))
@tf.function
def hyperparam_cond(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

@tf.function
def maybe_tensor_cond(x):
  if x < 0:
    x = -x
  return x

test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
test_tf_cond(maybe_tensor_cond, tf.constant(-1))
test_tf_cond(maybe_tensor_cond, -1)
hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) executes normally.
maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) uses tf.cond.
maybe_tensor_cond(-1) executes normally.

tf.cond has a number of subtleties.

  • It works by tracing both sides of the conditional, and then choosing the appropriate branch at runtime, depending on the condition. Tracing both sides can result in unexpected execution of Python code.
  • It requires that if one branch creates a tensor used downstream, the other branch must also create that tensor.
@tf.function
def f():
  x = tf.constant(0)
  if tf.constant(True):
    x = x + 1
    print("Tracing `then` branch")
  else:
    x = x - 1
    print("Tracing `else` branch")
  return x

f()
Tracing `then` branch
Tracing `else` branch

<tf.Tensor: id=747, shape=(), dtype=int32, numpy=1>
@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# Throws an error because both branches need to define `x`.
with assert_raises(ValueError):
  f()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-26-3ae6a92ff50d>:3 f  *
        if tf.constant(True):
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:893 if_stmt
        basic_symbol_names, composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:931 tf_if_stmt
        error_checking_orelse)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/control_flow_ops.py:1174 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/cond_v2.py:91 cond_v2
        op_return_value=pred)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/func_graph.py:915 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:924 error_checking_orelse
        result[orelse_branch] = orelse()
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:962 wrapper
        new_vars = func()
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:988 wrapper
        tuple(s.symbol_name for s in undefined)))

    ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.

AutoGraph and loops

AutoGraph has a few simple rules for converting loops.

  • for: Convert if the iterable is a tensor
  • while: Convert if the while condition depends on a tensor

If a loop is converted, it will be dynamically unrolled with tf.while_loop, or in the special case of a for x in tf.data.Dataset, transformed into tf.data.Dataset.reduce.

If a loop is not converted, it will be statically unrolled.

def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))
@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)
for_in_range() gets unrolled.
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)
for_in_tfrange() uses tf.while_loop.
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)
for_in_tfdataset() uses tf.data.Dataset.reduce.
@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)
while_py_cond() gets unrolled.
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x


test_dynamically_unrolled(while_tf_cond)
while_tf_cond() uses tf.while_loop.

If you have a break or early return clause that depends on a tensor, the top-level condition or iterable should also be a tensor.

Compare the following examples:

@tf.function
def while_py_true_py_break(x):
  while True:  # py true
    if x == 0: # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)
while_py_true_py_break(5) gets unrolled.
@tf.function
def buggy_while_py_true_tf_break(x):
  while True:   # py true
    if tf.equal(x, 0): # tf break
      break
    x -= 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)
Caught expected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-34-453240ea98e6>:3 buggy_while_py_true_tf_break  *
        while True:   # py true
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:730 while_stmt
        return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:845 _py_while_stmt
        while test(*loop_vars):
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:765 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:531 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:518 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True): # tf true
    if x == 0:  # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)
while_tf_true_tf_break(5) uses tf.while_loop.
@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py for
    if tf.equal(i, 3): # tf break
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)
Caught expected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-36-82742b0a14d0>:4 buggy_py_for_tf_break  *
        for i in range(5):  # py for
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:339 for_stmt
        return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:348 _py_for_stmt
        if extra_test is not None and not extra_test(*state):
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:765 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:531 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py:518 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

@tf.function
def tf_for_py_break():
  x = 0
  for i in tf.range(5): # tf for
    if i == 3:  # py break
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_py_break)
tf_for_py_break() uses tf.while_loop.

In order to accumulate results from a dynamically unrolled loop, you'll want to use tf.TensorArray.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: id=1254, shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.20169294, 0.61629033, 0.1433531 , 0.4829756 ],
        [1.1323476 , 1.0190183 , 0.8961072 , 0.9289988 ],
        [1.8864099 , 1.3378067 , 1.4223019 , 1.3112907 ]],

       [[0.6474533 , 0.06091189, 0.8633839 , 0.65446174],
        [0.9992776 , 1.0039562 , 1.1452049 , 1.5506929 ],
        [1.4500049 , 1.0567949 , 1.5264317 , 2.1867056 ]]], dtype=float32)>

As with tf.cond, tf.while_loop also comes with a number of subtleties.

  • Since a loop can execute 0 times, all tensors used downstream of the while_loop must be initialized above the loop.
  • The shape and dtype of all loop variables must stay consistent across iterations.
@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i
  return x

with assert_raises(ValueError):
  buggy_loop_var_uninitialized()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-39-815fd6bba8cc>:3 buggy_loop_var_uninitialized  *
        for i in tf.range(3):
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:419 _tf_range_for_stmt
        _disallow_undefs_into_loop(*init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:97 _disallow_undefs_into_loop
        ' before the loop: {}'.format(tuple(s.symbol_name for s in undefined)))

    ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)

@tf.function
def f():
  x = tf.constant(0)
  for i in tf.range(3):
    x = i
  return x

f()
<tf.Tensor: id=1309, shape=(), dtype=int32, numpy=2>
@tf.function
def buggy_loop_type_changes():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3): # Yields tensors of type tf.int32...
    x = i
  return x

with assert_raises(TypeError):
  buggy_loop_type_changes()
Caught expected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-41-46359fb065eb>:4 buggy_loop_type_changes  *
        for i in tf.range(3): # Yields tensors of type tf.int32...
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:769 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/while_v2.py:198 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/func_graph.py:915 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/while_v2.py:176 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:759 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/util/nest.py:535 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/util/nest.py:535 <listcomp>
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:179 _check_same_type
        first_iter_var.dtype.name,

    TypeError: "x" has dtype float32 before the loop, but dtype int32 after one iteration. TensorFlow control flow requires it stays the same.

@tf.function
def buggy_concat():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

with assert_raises(ValueError):
  buggy_concat()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-42-bae298a1ce41>:4 buggy_concat  *
        for i in tf.range(5):
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:769 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/while_v2.py:198 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/func_graph.py:915 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/while_v2.py:176 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:759 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/util/nest.py:535 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/util/nest.py:535 <listcomp>
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:191 _check_same_type
        first_iter_shape))

    ValueError: "x" has shape (0, 10) before the loop, but shape (1, 10) after one iteration. TensorFlow control flow requires it stays the same or be more specific.

@tf.function
def concat_with_padding():
  x = tf.zeros([5, 10])
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
    x.set_shape([5, 10])
  return x

concat_with_padding()
<tf.Tensor: id=1432, shape=(5, 10), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>