Better performance with tf.function


In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

To get performant and portable models, use tf.function to make graphs out of your programs. However, there are pitfalls to be wary of: tf.function is not a magic make-it-faster bullet!

This document will help you conceptualize what tf.function is doing under the hood, so that you can master its use.

The main takeaways and recommendations are:

  • Debug in Eager mode, then decorate with @tf.function.
  • Don't rely on Python side effects like object mutation or list appends.
  • tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.


import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))


A tf.function you define is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use functions inside functions.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.004070537999723456
Function conv: 0.0023154040000008536
Note how there's not much difference in performance for convolutions


In general, debugging code is easier in Eager mode than inside a tf.function. You should ensure that your code executes error-free in Eager mode before decorating it with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable tf.function, and tf.config.run_functions_eagerly(False) to re-enable it.
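As a minimal sketch of that switch (the function `square` and the `python_calls` list are illustrative names, not from this guide), the list below records each time the Python body actually runs. With eager execution forced, the body runs on every call; once the flag is reset, it runs only while tracing.

```python
import tensorflow as tf

python_calls = []  # hypothetical bookkeeping: one entry per execution of the Python body

@tf.function
def square(x):
  python_calls.append(1)  # Python side effect
  return x * x

tf.config.run_functions_eagerly(True)  # tf.function now calls the Python code directly
try:
  square(tf.constant(2.0))
  square(tf.constant(3.0))  # the Python body ran on both calls: 2 entries
finally:
  tf.config.run_functions_eagerly(False)  # restore graph execution

square(tf.constant(4.0))  # traced once: 3 entries
square(tf.constant(5.0))  # same dtype and shape, so the cached graph is reused
print(len(python_calls))  # 3
```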

When tracking down issues that only appear within tf.function, here are some tips:

  • Plain old Python print calls only execute during tracing, helping you track down when your functions get (re)traced.
  • tf.print calls will execute every time, and can help you track down intermediate values during execution.
  • tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Infs are created.
  • pdb can help you understand what's going on during tracing. (Caveat: PDB will drop you into AutoGraph-transformed source code.)
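The snippet below sketches the check-numerics tip above, assuming a hypothetical `log_of` function that produces a NaN for negative inputs. Because `tf.debugging.enable_check_numerics()` is active when the function is traced, executing the graph should raise `tf.errors.InvalidArgumentError` naming the op that produced the bad value.

```python
import tensorflow as tf

@tf.function
def log_of(x):
  return tf.math.log(x)  # log of a negative number is NaN

tf.debugging.enable_check_numerics()  # instruments ops created from now on
try:
  log_of(tf.constant(-1.0))  # tracing happens here, with the checks in place
  caught_nan = False
except tf.errors.InvalidArgumentError as e:
  caught_nan = True
  print("Check numerics caught:", type(e).__name__)
finally:
  tf.debugging.disable_check_numerics()
```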

Tracing and polymorphism

Python's dynamic typing means that you can call functions with a variety of argument types, and Python will do something different in each scenario.

On the other hand, TensorFlow graphs require static dtypes and shape dimensions. tf.function bridges this gap by retracing the function when necessary to generate the correct graphs. Most of the subtlety of tf.function usage stems from this retracing behavior.

You can call a function with arguments of different types to see what is happening.

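# Functions are polymorphic

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))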

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

To control the tracing behavior, you can use the following techniques:

Create a new tf.function. Separate tf.function objects are guaranteed not to share traces.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()


Use the get_concrete_function method to obtain a specific trace.

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

Obtaining concrete trace
Tracing with Tensor("a:0", dtype=string)
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
  File "<ipython-input-10-5351d0a2eda2>", line 8, in <module>
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_183 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_183]
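A concrete trace does not have to pin down every dimension. As a sketch (the `double` function is redefined here without the print so the snippet stands alone), a `tf.TensorSpec` with `shape=[None]` yields one concrete function that accepts float32 vectors of any length:

```python
import tensorflow as tf

@tf.function
def double(a):
  return a + a

# shape=[None]: a single trace covering float32 vectors of any length
double_vectors = double.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))

print(double_vectors(tf.constant([1.0, 2.0])).numpy())       # [2. 4.]
print(double_vectors(tf.constant([1.0, 2.0, 3.0])).numpy())  # [2. 4. 6.]
```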

Specify input_signature in tf.function to limit tracing.

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
  File "<ipython-input-11-9939c82c1507>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
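To confirm that the signature prevents retracing, one option is to record each trace with a Python side effect. In this sketch, `traces` is a hypothetical helper list; because every call matches the `[None]`-shaped `int32` signature, only one trace is expected even though the vector length changes.

```python
import tensorflow as tf

traces = []  # grows only when tf.function (re)traces

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  traces.append(x.shape)  # Python side effect: runs at trace time only
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

for n in [2, 5, 9]:
  print(next_collatz(tf.range(1, n + 1)).numpy())  # vectors of length 2, 5 and 9
print("Number of traces:", len(traces))  # 1
```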

When to retrace?

A polymorphic tf.function keeps a cache of concrete functions generated by tracing. The cache keys are effectively tuples of keys generated from the function args and kwargs. The key generated for a tf.Tensor argument is its number of dimensions and type. The key generated for a Python primitive is its value. For all other Python types, the keys are based on the object id() so that methods are traced independently for each instance of a class. In the future, TensorFlow may add more sophisticated caching for Python objects that can be safely converted to tensors.

For more details, see the Concrete functions guide.
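A small experiment can make these cache keys visible. The sketch below uses a hypothetical `scale` function and a `traces` list that is appended to at trace time; it varies only the tensor dtype and the Python value, and keeps the tensor shape fixed.

```python
import tensorflow as tf

traces = []  # one entry per trace, recording the Python value it was traced for

@tf.function
def scale(x, factor):
  traces.append(factor)  # Python side effect: runs only while tracing
  return x * factor

scale(tf.constant(1.0), 2)  # trace 1: float32 scalar tensor, Python value 2
scale(tf.constant(5.0), 2)  # reuses trace 1: same tensor dtype, same Python value
scale(tf.constant(1.0), 3)  # trace 2: the Python primitive changed
scale(tf.constant(1), 2)    # trace 3: the tensor dtype changed to int32
print(len(traces))  # 3
```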

Python or Tensor args?

Often, Python arguments are used to control hyperparameters and graph construction, for example num_layers=10, training=True, or nonlinearity='relu'. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so this is a bit inefficient.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)

Tracing with num_steps = 10
Tracing with num_steps = 20

The simple workaround here is to cast your arguments to Tensors if they do not affect the shape of the generated graph.

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

Side effects in tf.function

In general, Python side effects (like printing or mutating objects) only happen during tracing. So how can you reliably trigger side effects from tf.function?

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like tf.Variable.assign, tf.print, and tf.summary are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call. In general, using a functional style will yield the best results.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
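To contrast a Python side effect with a TensorFlow one, the following sketch counts calls both ways (`call_count`, `python_count`, and `step` are illustrative names). The `tf.Variable.assign_add` op runs on every call, while the Python increment only happens during the single trace.

```python
import tensorflow as tf

call_count = tf.Variable(0)  # updated by a TensorFlow op on every call
python_count = [0]           # updated by Python code, so only while tracing

@tf.function
def step(x):
  python_count[0] += 1
  call_count.assign_add(1)
  return x + 1

for _ in range(3):
  step(tf.constant(1))

print(call_count.numpy(), python_count[0])  # 3 1
```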

If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an escape hatch. The drawback of tf.py_function is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph for differentiability, it casts all inputs/outputs to tensors.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Beware of Python state

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in Eager mode, many unexpected things can happen inside a tf.function due to tracing behavior.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.

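external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)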

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0
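One way to avoid this, sketched below, is to pass a tf.data iterator instead of a Python iterator: advancing it is a TensorFlow op, so it happens on each call rather than only during tracing. The dataset contents are illustrative.

```python
import tensorflow as tf

external_var = tf.Variable(0)

@tf.function
def consume_next(iterator):
  # next() on a tf.data iterator becomes a graph op, so it runs on every call
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3])
iterator = iter(dataset)
consume_next(iterator)  # Value of external_var: 0
consume_next(iterator)  # Value of external_var: 1
consume_next(iterator)  # Value of external_var: 3
```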


Variables

tf.function runs stateful operations in the order they appear in the code, which makes creating and using variables inside it straightforward. There is one very important caveat, though: with variables, it's possible to write code that behaves differently in eager mode and graph mode.

Specifically, this will happen when you create a new Variable with each call. Due to tracing semantics, tf.function will reuse the same variable each call, but eager mode will create a new variable with each call. To guard against this mistake, tf.function will raise an error if it detects dangerous variable creation behavior.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
  File "<ipython-input-17-73e410646579>", line 8, in <module>
ValueError: in user code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/ __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/ _variable_v2_call
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/ getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/ invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Non-ambiguous code is ok, though.

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

You can also create variables inside a tf.function, as long as those variables are created only the first time the function is executed.

class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
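The same first-call-only pattern is commonly written as a tf.Module that creates its variable lazily. This is a sketch with a hypothetical `Accumulator` class, not an API from this guide; tensor arguments are used to avoid retracing for each new Python value.

```python
import tensorflow as tf

class Accumulator(tf.Module):  # hypothetical example class
  def __init__(self):
    super().__init__()
    self.v = None  # the variable is created lazily, on the first call

  @tf.function
  def __call__(self, x):
    if self.v is None:  # only true the first time the function is traced
      self.v = tf.Variable(1.0)
    return self.v.assign_add(x)

acc = Accumulator()
print(acc(tf.constant(1.0)).numpy())  # 2.0
print(acc(tf.constant(2.0)).numpy())  # 4.0
```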

Variable initializers can depend on function arguments and on the values of other variables. tf.function works out the correct initialization order using the same mechanism it uses to generate automatic control dependencies.

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))

tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)

AutoGraph Transformations

AutoGraph is a library that is on by default in tf.function, and transforms a subset of Python Eager code into graph-compatible TensorFlow ops. This includes control flow like if, for, while.

TensorFlow ops like tf.cond and tf.while_loop continue to work, but control flow is often easier to write and understand when written in Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

[0.42992723 0.425026417 0.735794306 0.224515557 0.623353]
[0.405260503 0.401156455 0.626597464 0.2208177 0.553458273]
[0.384441048 0.380938053 0.555704892 0.217297286 0.503107667]
[0.366558045 0.363521844 0.50478363 0.213940546 0.464557648]
[0.350977391 0.348312378 0.465870887 0.210735157 0.433791548]
[0.337242037 0.334878027 0.434857041 0.207670063 0.408485085]
[0.325012982 0.322897047 0.409372419 0.204735294 0.387185633]
[0.314032614 0.312124074 0.387939692 0.201921865 0.368931442]
[0.304101259 0.302368194 0.369582683 0.199221611 0.353056699]
[0.29506132 0.29347834 0.353626639 0.19662714 0.339083582]
[0.286786556 0.285333127 0.339587897 0.194131717 0.326659]
[0.279174477 0.27783379 0.327109426 0.191729173 0.315515548]
[0.272140861 0.270899028 0.315921068 0.18941389 0.305446446]
[0.265615791 0.26446119 0.305814087 0.187180698 0.296288908]
[0.259540617 0.258463472 0.296624243 0.185024858 0.287912786]
[0.253865808 0.252857804 0.288220286 0.182941988 0.280212611]
[0.248549119 0.247603148 0.280495942 0.180928066 0.273101836]
[0.243554324 0.242664263 0.273364 0.178979367 0.266508728]
[0.238850132 0.238010675 0.266752273 0.177092433 0.260373205]
[0.234409362 0.233615875 0.260600239 0.17526406 0.254644573]
[0.230208248 0.229456678 0.254856855 0.173491284 0.249279633]
[0.226225957 0.225512728 0.249478713 0.171771348 0.244241387]
[0.222444087 0.22176604 0.244428575 0.170101658 0.2394979]
[0.218846336 0.218200669 0.23967433 0.16847983 0.235021442]
[0.215418205 0.214802414 0.235188112 0.1669036 0.230787814]
[0.212146759 0.21155861 0.230945602 0.165370882 0.22677578]
[0.209020391 0.208457872 0.226925448 0.163879707 0.222966641]
[0.20602867 0.205489963 0.223108858 0.162428215 0.219343811]
[0.203162178 0.202645645 0.219479173 0.161014691 0.215892553]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20041241, 0.19991657, 0.2160216 , 0.1596375 , 0.21259971],
      dtype=float32)>

If you're curious, you can inspect the code AutoGraph generates.

print(tf.autograph.to_code(f.python_function))

def tf__f(x):
    do_return = False
    retval_ = ag__.UndefinedReturnValue()
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

        def get_state():
            return (x,)

        def set_state(loop_vars):
            nonlocal x
            (x,) = loop_vars

        def loop_body():
            nonlocal x
            ag__.converted_call(tf.print, (x,), None, fscope)
            x = ag__.converted_call(tf.tanh, (x,), None, fscope)

        def loop_test():
            return (ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = fscope.mark_return_value(x)
        except:
            do_return = False
            raise
    return ag__.retval(retval_)
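For comparison, here is a sketch of the same loop written directly with tf.while_loop, next to the AutoGraph version (`f_autograph` and `f_explicit` are illustrative names, and the input values are arbitrary). Both should produce the same result; the AutoGraph form is simply easier to read.

```python
import numpy as np
import tensorflow as tf

@tf.function
def f_autograph(x):
  while tf.reduce_sum(x) > 1:  # AutoGraph converts this into a TensorFlow loop
    x = tf.tanh(x)
  return x

@tf.function
def f_explicit(x):
  # The same loop, spelled out with tf.while_loop; the result has the
  # same structure as loop_vars (a one-element list), hence the [0]
  return tf.while_loop(
      cond=lambda x: tf.reduce_sum(x) > 1,
      body=lambda x: [tf.tanh(x)],
      loop_vars=[x])[0]

x = tf.constant([0.9, 0.8, 0.7, 0.6, 0.5])
print(np.allclose(f_autograph(x), f_explicit(x)))  # True
```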


Conditionals

AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see AutoGraph tracing effects for more.

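@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))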

Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch

See the reference documentation for additional restrictions on AutoGraph-converted if statements.
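The following sketch makes the "both branches are traced" behavior of tf.cond visible. It assumes a hypothetical `abs_value` function whose branches append to a Python list at trace time; since `x > 0` is a Tensor, AutoGraph converts the `if` into tf.cond and both appends run during the single trace, even though only the `else` branch is taken at execution time.

```python
import tensorflow as tf

branches_traced = []  # Python side effects: record which branches were traced

@tf.function
def abs_value(x):
  if x > 0:  # x is a Tensor, so AutoGraph converts this to tf.cond
    branches_traced.append("then")
    result = x
  else:
    branches_traced.append("else")
    result = -x
  return result

print(abs_value(tf.constant(-3.0)).numpy())  # 3.0
print(sorted(set(branches_traced)))  # ['else', 'then']
```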


Loops

AutoGraph will convert some for and while statements into the equivalent TensorFlow looping ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.

This substitution is made in the following situations:

  • for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops is generated.
  • while <condition>: if <condition> is a Tensor, convert to tf.while_loop.

A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph.

See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.
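To see the difference between a Python loop and a TensorFlow loop in graph terms, this sketch counts top-level graph nodes for two hypothetical functions that differ only in using `range` versus `tf.range`. The unrolled version grows with the iteration count; the tf.while_loop version does not.

```python
import tensorflow as tf

def graph_size(fn, *args):
  # Number of nodes in the top-level graph of the traced function
  return len(fn.get_concrete_function(*args).graph.as_graph_def().node)

@tf.function
def unrolled(x, n):
  for _ in range(n):     # Python range: runs at trace time, one add per iteration
    x = x + 1
  return x

@tf.function
def rolled(x, n):
  for _ in tf.range(n):  # tf.range: converted to a single TensorFlow loop
    x = x + 1
  return x

x = tf.constant(0)
print(graph_size(unrolled, x, 3), graph_size(unrolled, x, 10))  # grows with n
print(graph_size(rolled, x, 3), graph_size(rolled, x, 10))      # stays the same
```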

Looping over Python data

A common pitfall is to loop over Python/NumPy data within a tf.function. This loop will execute during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.

If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph will dynamically unroll the training loop.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

When wrapping Python/NumPy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors. The former will keep the data in Python and fetch it via tf.py_function, which can have performance implications, whereas the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications.
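As a sketch of the two options (the arrays and the `total_abs_error` function are illustrative), the same training-style loop can consume either dataset. This example uses `from_tensor_slices`, the element-wise counterpart of `from_tensors`, for the in-graph copy, and passes explicit `output_shapes` to `from_generator` so the element shapes are known.

```python
import numpy as np
import tensorflow as tf

xs = np.array([1, 2, 3], dtype=np.int32)
ys = np.array([1, 4, 9], dtype=np.int32)

@tf.function
def total_abs_error(dataset):
  loss = tf.constant(0)
  for x, y in dataset:  # a tf.data.Dataset, so AutoGraph builds a dataset loop
    loss += tf.abs(y - x)
  return loss

# Copies the arrays into the dataset graph as constants.
in_graph = tf.data.Dataset.from_tensor_slices((xs, ys))
# Keeps the data in Python; elements are fetched at run time.
in_python = tf.data.Dataset.from_generator(
    lambda: zip(xs, ys),
    output_types=(tf.int32, tf.int32),
    output_shapes=((), ()))

print(total_abs_error(in_graph).numpy())   # 8
print(total_abs_error(in_python).numpy())  # 8
```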

Reading data from files via TFRecordDataset/CsvDataset/etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the tf.data guide.

Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  # [time, batch, features] -> [batch, time, features]
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.96471524, 0.233114  , 0.1417228 , 0.14083493],
        [1.6257136 , 0.9389272 , 0.73989546, 0.8011714 ],
        [2.233508  , 1.827873  , 1.1567426 , 1.5585394 ]],

       [[0.67377114, 0.42712367, 0.5697857 , 0.71173656],
        [1.5520021 , 0.806401  , 0.9260858 , 1.5265073 ],
        [1.8115815 , 1.6316041 , 1.2245122 , 1.9724467 ]]], dtype=float32)>
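tf.TensorArray also works when the number of iterations is not known in advance. In this sketch (a hypothetical `collatz_sequence` function, reusing the Collatz step from earlier in this guide), `dynamic_size=True` lets the array grow inside a `while` loop that AutoGraph converts to tf.while_loop.

```python
import tensorflow as tf

@tf.function
def collatz_sequence(n):
  # The final length is unknown up front, so let the TensorArray grow
  seq = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  i = tf.constant(0)
  while n > 1:  # Tensor condition: AutoGraph builds a TensorFlow loop
    seq = seq.write(i, n)
    n = tf.where(n % 2 == 0, n // 2, 3 * n + 1)
    i += 1
  seq = seq.write(i, n)  # append the final 1
  return seq.stack()

print(collatz_sequence(tf.constant(6)).numpy().tolist())  # [6, 3, 10, 5, 16, 8, 4, 2, 1]
```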

Further reading

To learn more about graph optimizations that are performed after tracing a tf.function, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.