
Better performance with tf.function

In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use SavedModel.

This guide will help you conceptualize how tf.function works under the hood, so you can use it effectively.

The main takeaways and recommendations are:

• Debug in eager mode, then decorate with @tf.function.
• Don't rely on Python side effects like object mutation or list appends.
• tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.

Setup

import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Basics

Usage

A Function you define (for example by applying the @tf.function decorator) is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use Functions inside other Functions.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions
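For contrast, here is a minimal sketch of the opposite case: a function built from many tiny element-wise ops, where per-op Python dispatch tends to dominate eager execution. The function name `many_small_ops` and the op count are illustrative choices; actual timings depend on your hardware and TensorFlow version, so no particular speedup is implied.

```python
import timeit

import tensorflow as tf

def many_small_ops(x):
  # 200 tiny element-wise ops; in eager mode each one is dispatched from Python.
  for _ in range(100):
    x = x * 1.0001 + 0.0001
  return x

# The Python loop is unrolled once, at trace time, into a single graph.
many_small_ops_fn = tf.function(many_small_ops)

x = tf.ones([10])
many_small_ops_fn(x)  # Warm up: trace once before timing.

print("Eager:", timeit.timeit(lambda: many_small_ops(x), number=10))
print("Function:", timeit.timeit(lambda: many_small_ops_fn(x), number=10))
```

Both versions compute the same values; only the execution strategy differs.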

Tracing

This section exposes how Function works under the hood, including implementation details which may change in the future. However, once you understand why and when tracing happens, it's much easier to use tf.function effectively!

What is "tracing"?

A Function runs your program in a TensorFlow Graph. However, a tf.Graph cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but tf.Graph requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a tf.Graph.

Function bridges this gap by separating your code into two stages:

1) In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.

2) In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.

Depending on its inputs, Function will not always run the first stage when it is called. See "Rules of tracing" below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.

When Function does decide to trace, the tracing stage is immediately followed by the second stage, so calling the Function both creates and runs the tf.Graph. Later you will see how you can run only the tracing stage with get_concrete_function.

When you pass arguments of different types into a Function, both stages are run:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Note that if you repeatedly call a Function with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
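Because Python code runs only while tracing, a Python side effect makes a convenient trace counter. The sketch below is a variant of `double` (renamed `double_counted`, with a hypothetical `dtype_traces` list) that records each trace instead of printing it.

```python
import tensorflow as tf

dtype_traces = []  # Appended to only while the function is being traced.

@tf.function
def double_counted(a):
  dtype_traces.append(a.dtype)  # Python side effect: runs during tracing only.
  return a + a

double_counted(tf.constant(1))    # int32 scalar: new trace
double_counted(tf.constant(2))    # int32 scalar again: trace reused
double_counted(tf.constant(3.0))  # float32 scalar: new trace
print(dtype_traces)  # [tf.int32, tf.float32]
```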

You can use pretty_printed_concrete_signatures() to see all of the available traces:

print(double.pretty_printed_concrete_signatures())
double(a)
Args:
a: int32 Tensor, shape=()
Returns:
int32 Tensor, shape=()

double(a)
Args:
a: float32 Tensor, shape=()
Returns:
float32 Tensor, shape=()

double(a)
Args:
a: string Tensor, shape=()
Returns:
string Tensor, shape=()

So far, you've seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

• A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
• A ConcreteFunction wraps a tf.Graph.
• A Function manages a cache of ConcreteFunctions and picks the right one for your inputs.
• tf.function wraps a Python function, returning a Function object.
• Tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.
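The layers in this terminology can be inspected directly. In this sketch (`square_fn` is a hypothetical name), tf.function produces a Function, tracing it for one input signature yields a ConcreteFunction, and that ConcreteFunction exposes its tf.Graph.

```python
import tensorflow as tf

@tf.function  # tf.function wraps the Python function and returns a Function.
def square_fn(x):
  return x * x

# Tracing for a scalar float32 input yields a ConcreteFunction...
concrete = square_fn.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
# ...which wraps a tf.Graph.
graph = concrete.graph

print(isinstance(graph, tf.Graph))         # True
print(concrete(tf.constant(3.0)).numpy())  # 9.0
# The original, undecorated Python function is still reachable.
print(square_fn.python_function(4))        # 16
```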

Rules of tracing

A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key is a key that identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (which may change):

• The key generated for a tf.Tensor is its shape and dtype.
• The key generated for a tf.Variable is a unique variable id.
• The key generated for a Python primitive (like int, float, str) is its value.
• The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf-keys (see nest.flatten). (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError).
• For all other Python types the key is unique to the object. This way a function or method is traced independently for each instance it is called with.
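The sketch below exercises two of these rules: a tf.Tensor argument is keyed by shape and dtype, while a Python primitive is keyed by its value. The function `scale` and the `key_traces` list are hypothetical; the list is appended to only during tracing, so it shows which calls created new traces.

```python
import tensorflow as tf

key_traces = []  # Records the key components seen at trace time.

@tf.function
def scale(x, factor):
  key_traces.append((tuple(x.shape), factor))  # Runs only while tracing.
  return x * factor

scale(tf.ones([2]), 2)   # Tensor shape (2,), Python value 2: new trace
scale(tf.zeros([2]), 2)  # Same shape, dtype, and Python value: trace reused
scale(tf.ones([2]), 3)   # Different Python value: new trace
scale(tf.ones([3]), 3)   # Different Tensor shape: new trace
print(key_traces)  # [((2,), 2), ((2,), 3), ((3,), 3)]
```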

Controlling retracing

Retracing, which is when your Function creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your Function retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use tf.function.

To control the tracing behavior, you can use the following techniques:

• Specify input_signature in tf.function to limit tracing.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception
<class 'ValueError'>:
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
inputs: (
tf.Tensor(
[[1 2]
[3 4]], shape=(2, 2), dtype=int32))
input_signature: (
TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
inputs: (
tf.Tensor([1. 2.], shape=(2,), dtype=float32))
input_signature: (
TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
• Specify a [None] dimension in tf.TensorSpec to allow for flexibility in trace reuse.

Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard will allow Functions to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch (See the Transformer and Deep Dream tutorials for example).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
• Cast Python arguments to Tensors to reduce retracing.

Often, Python arguments are used to control hyperparameters and graph constructions - for example, num_layers=10 or training=True or nonlinearity='relu'. So, if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

If you need to force retracing, create a new Function. Separate Function objects are guaranteed not to share traces.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Obtaining concrete functions

Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using get_concrete_function.

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on a tf.TensorSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.

print(double_strings)
ConcreteFunction double(a)
Args:
a: string Tensor, shape=()
Returns:
string Tensor, shape=()

You can also directly retrieve a concrete function's signature.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Using a concrete trace with incompatible types will throw an error:

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
Args:
a: float32 Tensor, shape=<unknown>
Returns:
float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
cancellation_manager)
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

Obtaining graphs

Each concrete function is a callable wrapper around a tf.Graph. Although retrieving the actual tf.Graph object is not something you'll normally need to do, you can obtain it easily from any concrete function.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a

Debugging

In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable tf.function, and tf.config.run_functions_eagerly(False) to reenable it.

When tracking down issues that only appear within tf.function, here are some tips:

• Plain old Python print calls only execute during tracing, helping you track down when your function gets (re)traced.
• tf.print calls will execute every time, and can help you track down intermediate values during execution.
• tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Inf are created.
• pdb (the Python debugger) can help you understand what's going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)
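As a sketch of the first debugging tool, the function below (a hypothetical `scaled`) calls `.numpy()`, which works only on eager tensors and would fail during tracing. With tf.config.run_functions_eagerly(True), the Function body runs as ordinary eager Python, so the call succeeds; the try/finally restores graph execution afterwards.

```python
import tensorflow as tf

@tf.function
def scaled(x):
  y = x * 2
  # .numpy() only works on eager tensors; in a traced graph this line would fail.
  print("y =", y.numpy())
  return y

tf.config.run_functions_eagerly(True)  # Every Function now runs as plain eager Python.
try:
  result = scaled(tf.constant([1, 2]))  # prints: y = [2 4]
finally:
  tf.config.run_functions_eagerly(False)  # Restore normal graph execution.
```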

AutoGraph transformations

AutoGraph is a library that is on by default in tf.function, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like if, for, while.

TensorFlow ops like tf.cond and tf.while_loop continue to work, but control flow is often easier to write and understand when written in Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
dtype=float32)>

If you're curious, you can inspect the code AutoGraph generates.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)
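For comparison, here is a sketch of the same loop written once with a Python `while` (left to AutoGraph) and once with tf.while_loop directly. The names `tanh_autograph` and `tanh_while_loop` are hypothetical, and the extra `steps` loop variable is added only to show that tf.while_loop can carry several values at once.

```python
import tensorflow as tf

@tf.function
def tanh_autograph(x):
  # Python `while` on a Tensor condition: AutoGraph emits a while loop.
  while tf.reduce_sum(x) > 1:
    x = tf.tanh(x)
  return x

@tf.function
def tanh_while_loop(x):
  # The same loop written by hand with tf.while_loop.
  def cond(steps, x):
    return tf.reduce_sum(x) > 1

  def body(steps, x):
    return steps + 1, tf.tanh(x)

  steps, x = tf.while_loop(cond, body, (tf.constant(0), x))
  return steps, x

x0 = tf.fill([5], 0.9)
steps, manual = tanh_while_loop(x0)
print(int(steps), manual.numpy())
```

The AutoGraph version is shorter and reads like ordinary Python; the explicit version makes the loop state visible.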

Conditionals

AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out AutoGraph tracing effects for more information.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz
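A smaller sketch of the same idea: because `x > 0` is a Tensor, AutoGraph converts the `if` into tf.cond, so both branches live in one graph and a single trace serves positive and negative inputs alike. The function `abs_value` and the `abs_traces` list are hypothetical; the list is appended to only during tracing.

```python
import tensorflow as tf

abs_traces = []  # Top-level Python side effect: appended only while tracing.

@tf.function
def abs_value(x):
  abs_traces.append(x.dtype)
  if x > 0:  # Tensor condition: AutoGraph converts this to tf.cond.
    result = x
  else:
    result = -x
  return result

print(abs_value(tf.constant(3.0)).numpy())   # 3.0
print(abs_value(tf.constant(-2.0)).numpy())  # 2.0
print(len(abs_traces))  # 1: both branches live in the single traced graph
```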

See the reference documentation for additional restrictions on AutoGraph-converted if statements.

Loops

AutoGraph will convert some for and while statements into the equivalent TensorFlow looping ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.

This substitution is made in the following situations:

• for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops are generated.
• while <condition>: if <condition> is a Tensor, convert to tf.while_loop.

A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph.

See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.
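The difference in graph size can be observed directly. This sketch (hypothetical names `add_n_python`, `add_n_tensor`, `graph_size`) counts the top-level nodes of each traced graph: the Python `range` loop grows with the iteration count because it is unrolled, while the tf.range loop keeps a single loop body. No exact node counts are claimed, since they vary across TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def add_n_python(x, n):
  for _ in range(n):  # n is a Python int: the loop is unrolled during tracing.
    x = x + 1
  return x

@tf.function
def add_n_tensor(x, n):
  for _ in tf.range(n):  # n is a Tensor: AutoGraph emits a single while loop.
    x = x + 1
  return x

def graph_size(fn, *args):
  # Number of top-level nodes in the tf.Graph traced for these arguments.
  return len(fn.get_concrete_function(*args).graph.as_graph_def().node)

zero = tf.constant(0)
print(graph_size(add_n_python, zero, 5), graph_size(add_n_python, zero, 50))
print(graph_size(add_n_tensor, zero, tf.constant(5)),
      graph_size(add_n_tensor, zero, tf.constant(50)))
```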

Looping over Python data

A common pitfall is to loop over Python/NumPy data within a tf.function. This loop will execute during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.

If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph will dynamically unroll the training loop.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

When wrapping Python/NumPy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors. The former will keep the data in Python and fetch it via tf.py_function which can have performance implications, whereas the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications.

Reading data from files via TFRecordDataset, CsvDataset, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the tf.data: Build TensorFlow input pipelines guide.
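The two in-memory options can be compared with a small sketch. It uses tf.data.Dataset.from_tensor_slices (a sibling of from_tensors that slices along the first dimension, and likewise embeds the data in the graph) and tf.data.Dataset.from_generator with `output_signature`, which assumes TensorFlow 2.4 or later. The function `total_abs_diff` and the toy data are hypothetical.

```python
import tensorflow as tf

@tf.function
def total_abs_diff(dataset):
  loss = tf.constant(0)
  for x, y in dataset:  # Iterating a tf.data.Dataset: converted to dataset ops.
    loss += tf.abs(y - x)
  return loss

xs = [1] * 10
ys = [3] * 10

# The data is copied into the graph as constant tensors (memory cost grows with data size).
in_graph = tf.data.Dataset.from_tensor_slices((xs, ys))

# The data stays in Python and is pulled through a generator (Python overhead per element).
from_python = tf.data.Dataset.from_generator(
    lambda: zip(xs, ys),
    output_signature=(tf.TensorSpec(shape=[], dtype=tf.int32),
                      tf.TensorSpec(shape=[], dtype=tf.int32)))

print(total_abs_diff(in_graph).numpy(), total_abs_diff(from_python).numpy())  # 20 20
```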

Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
[0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
[0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

[[0.04946005, 0.69127274, 0.56848884, 0.22406638],
[0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
[0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

Limitations

TensorFlow Function has a few limitations by design that you should be aware of when converting a Python function to a Function.

Executing Python side effects

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not at all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.

The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

If you would like to execute Python code during each invocation of a Function, tf.py_function is an escape hatch. The drawback of tf.py_function is that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.
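A minimal sketch of tf.py_function: the wrapped Python function (a hypothetical `log_and_double`) runs on every call of the Function, not just during tracing, so its Python side effect happens each time. Inside the wrapped function the input arrives as an eager tensor.

```python
import tensorflow as tf

calls = []  # Python state updated on every execution, not just during tracing.

def log_and_double(x):
  # Runs eagerly (as Python) each time the graph executes the py_function op.
  calls.append(float(x))
  return x * 2

@tf.function
def double_with_log(x):
  return tf.py_function(log_and_double, inp=[x], Tout=tf.float32)

double_with_log(tf.constant(1.0))
double_with_log(tf.constant(2.0))
print(calls)  # [1.0, 2.0]
```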

Changing Python global and free variables

Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Sometimes unexpected behaviors are very hard to notice. In the example below, the counter is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the tf.function is used, the assign_add is recorded unconditionally in the underlying graph, so v increases by 1 every time the tf.function is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using tf.function decorators, when Python side effects (the counter in the example) are used to determine which ops to run (assign_add in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

One workaround is to use tf.init_scope to lift the operations out of the function graph. This ensures that the variable increment is done only once, at tracing time. Note that init_scope has other side effects, such as clearing the control-flow context and the gradient tape. Sometimes the use of init_scope can become too complex to manage realistically.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the Function. Instead, use arguments and TF objects. For example, the section "Accumulating values in a loop" shows how list-like operations can be implemented.

You can, in some cases, capture and manipulate state if it is a tf.Variable. This is how the weights of Keras models are updated with repeated calls to the same ConcreteFunction.
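For example, the guarded increment from the `Model` example above can keep its guard in a tf.Variable instead of a Python integer. Because the condition is then a Tensor, AutoGraph turns the `if` into tf.cond and the guard is evaluated by the graph on every call. This is a sketch of an alternative to the tf.init_scope workaround, not a pattern prescribed by this guide; `GuardedModel` is a hypothetical name.

```python
import tensorflow as tf

class GuardedModel(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    # The guard lives in a tf.Variable, so it is read at execution time,
    # not frozen at trace time like a Python int.
    self.counter = tf.Variable(0)

  @tf.function
  def __call__(self):
    if tf.equal(self.counter, 0):  # Tensor condition -> tf.cond
      self.counter.assign_add(1)
      self.v.assign_add(1)
    return self.v.read_value()

m = GuardedModel()
values = [int(m()) for _ in range(3)]
print(values)  # [1, 1, 1]
```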

Using Python iterators and generators

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Just like how TensorFlow has a specialized tf.TensorArray for list constructs, it has a specialized tf.data.Iterator for iteration constructs. See the section on AutoGraph transformations for an overview. Also, the tf.data API can help implement generator patterns:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

All outputs of a tf.function must be return values

With the exception of tf.Variables, a tf.function must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes "leaks".

For example, the function below "leaks" the tensor a through the Python global x:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

This is true even if the leaked value is also returned:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Usually, leaks such as these occur when you use Python statements or data structures. Besides leaking inaccessible tensors, such statements are also likely wrong: they count as Python side effects and are not guaranteed to run on every function call.

Other common ways to leak local tensors include mutating an external Python collection or an object:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

@tf.function
def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor
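tf.Variables are the stated exception to this rule. One way to keep state on an external object is therefore to store a tf.Variable on it and assign to that variable, rather than rebinding a Python attribute to a tensor. This is a minimal sketch; StatefulClass and non_leaky_function are hypothetical names:

```python
import tensorflow as tf

class StatefulClass:

  def __init__(self):
    # Create the variable eagerly, outside any tf.function.
    self.field = tf.Variable(0)

external_object = StatefulClass()

@tf.function
def non_leaky_function():
  a = tf.constant(1)
  external_object.field.assign(a)  # Updates the variable's value; no tensor escapes.
  return a  # Anything else the caller needs goes out through the return value.

result = non_leaky_function()
print(result.numpy(), external_object.field.numpy())  # 1 1
```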

Recursive tf.functions are not supported

Recursive Functions are not supported and could cause infinite loops. For example,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception
<class 'Exception'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
return recursive_fn(n - 1)
File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
if n > 0:
File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
return _abc_instancecheck(cls, instance)

RecursionError: maximum recursion depth exceeded while calling a Python object

Even if a recursive Function seems to work, the Python function will be traced multiple times, which can hurt performance. For example,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
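A common way around both problems is to express the recursion as a loop over tensors. When the loop condition depends on a tensor, AutoGraph converts a Python while statement into a tf.while_loop, so the function is traced once regardless of the input value. The sketch below is minimal and uses a hypothetical factorial function. It also calls experimental_get_tracing_count, the Function method that reports how many times the function has been traced:

```python
import tensorflow as tf

@tf.function
def factorial(n):
  result = tf.ones_like(n)
  # `n` is a tensor, so AutoGraph turns this `while` into a tf.while_loop
  # instead of unrolling it during tracing.
  while n > 1:
    result *= n
    n -= 1
  return result

print(factorial(tf.constant(5)).numpy())  # 120
print(factorial(tf.constant(6)).numpy())  # 720
print(factorial.experimental_get_tracing_count())  # 1
```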

Known Issues

If your Function is not evaluating correctly, the error may be explained by these known issues, which are planned to be fixed in the future.

Depending on Python global and free variables

Function creates a new ConcreteFunction when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that Function. If their value changes in between calls to the Function, the Function will still use the values they had when it was traced. This is different from how regular Python functions work.

For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)
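The retracing rule described at the start of this section is easy to observe with the Function method experimental_get_tracing_count. Every new Python value of foo produces a new trace, while scalar tensors of the same dtype share one. In this minimal sketch, add_one is a hypothetical name:

```python
import tensorflow as tf

@tf.function
def add_one(foo):
  return 1 + foo

# Python ints are part of the cache key: each new value triggers a new trace.
add_one(1)
add_one(100)
print(add_one.experimental_get_tracing_count())  # 2

# Scalar int32 tensors share one input signature, so they reuse a single trace.
add_one(tf.constant(1))
add_one(tf.constant(100))
print(add_one.experimental_get_tracing_count())  # 3
```

Passing tensors rather than Python scalars therefore also avoids retracing when the argument's value changes often.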

Another way to update a global value is to make it a tf.Variable and use the Variable.assign method instead.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

Depending on Python objects

The recommendation to pass Python objects as arguments into tf.function has a number of known issues that are expected to be fixed in the future. In general, you can rely on consistent tracing if you use a Python primitive or a tf.nest-compatible structure as an argument, or if you pass a different instance of an object into a Function. However, Function will not create a new trace when you pass the same object and only change its attributes.

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
tf.Tensor(20.0, shape=(), dtype=float32)

Using the same Function to evaluate the updated instance of the model will be buggy since the updated model has the same cache key as the original model.

For that reason, write your Function so that it does not depend on mutable object attributes, or create new objects instead.

If that is not possible, one workaround is to make new Functions each time you modify your object to force retracing:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
tf.Tensor(25.0, shape=(), dtype=float32)

Since retracing can be expensive, you can instead use tf.Variables as object attributes. Their values can be mutated in place (but the attributes must not be rebound to new objects, careful!), which gives a similar effect without needing a retrace.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
tf.Tensor(25.0, shape=(), dtype=float32)
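To confirm that the variable-based model avoids retracing, you can check the Function's trace count after the update. This minimal sketch repeats the BetterModel class and the evaluate function from above so that it runs on its own:

```python
import tensorflow as tf

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

model = BetterModel()
x = tf.constant(10.)
before = evaluate(model, x)
model.bias.assign_add(5.0)  # Mutate the variable in place; do not rebind model.bias.
after = evaluate(model, x)
print(before.numpy(), after.numpy())  # 20.0 25.0
print(evaluate.experimental_get_tracing_count())  # 1
```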

Creating tf.Variables

Function only supports singleton tf.Variables created once on the first call, and reused across subsequent function calls. The code snippet below would create a new tf.Variable in every function call, which results in a ValueError exception.

Example:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
f(1.0)
ValueError: in user code:

File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
v = tf.Variable(1.0)

ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

A common pattern used to work around this limitation is to start with a Python None value, then conditionally create the tf.Variable if the value is None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
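Where possible, a simpler alternative is to create the variable eagerly in __init__, outside the tf.function, so no trace ever has to create it. This is a minimal sketch; EagerCount is a hypothetical name:

```python
import tensorflow as tf

class EagerCount(tf.Module):

  def __init__(self):
    super().__init__()
    # Created once, eagerly, before any tracing happens.
    self.count = tf.Variable(0)

  @tf.function
  def __call__(self):
    return self.count.assign_add(1)

c = EagerCount()
print(c().numpy())  # 1
print(c().numpy())  # 2
```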

Using with multiple Keras optimizers

You may encounter "ValueError: tf.function only supports singleton tf.Variables created on the first call." when using more than one Keras optimizer with a tf.function. This error occurs because optimizers internally create tf.Variables when they apply gradients for the first time.

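opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)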
Calling `train_step` with different optimizer...
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
yield
File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
train_step(w, x, y, opt2)
ValueError: in user code:

File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
self._create_all_weights(var_list)
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
_ = self.iterations
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
return super(OptimizerV2, self).__getattribute__(name)
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
aggregation=aggregation)
File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
shape=variable_shape if variable_shape else None)

ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

If you need to change the optimizer during training, a workaround is to create a new Function for each optimizer, calling the ConcreteFunction directly.

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter.
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.
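Another way to get one Function per optimizer is a small factory that builds a fresh tf.function closing over the optimizer. Each optimizer's variables are then created during the first call of its own Function. This minimal sketch assumes plain SGD optimizers with illustrative learning rates; make_train_step is a hypothetical helper:

```python
import tensorflow as tf

def make_train_step(optimizer):
  # Each call returns a brand-new Function, so any variables the optimizer
  # needs are created on that Function's first call, which tf.function allows.
  @tf.function
  def train_step(w, x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.square(w * x - y))
    gradients = tape.gradient(loss, [w])
    optimizer.apply_gradients(zip(gradients, [w]))
    return loss
  return train_step

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

step_fast = make_train_step(tf.keras.optimizers.SGD(learning_rate=0.1))
step_slow = make_train_step(tf.keras.optimizers.SGD(learning_rate=0.01))

loss_1 = step_fast(w, x, y)  # loss 16.0; gradient 8.0, so w: 2.0 - 0.1 * 8.0 = 1.2
loss_2 = step_slow(w, x, y)  # gradient 6.4, so w: 1.2 - 0.01 * 6.4 = 1.136
print(loss_1.numpy(), w.numpy())
```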

Using with multiple Keras models

You may also encounter "ValueError: tf.function only supports singleton tf.Variables created on the first call." when passing different model instances to the same Function.

This error occurs because Keras models (which do not have their input shape defined) and Keras layers create tf.Variables when they are first called. You may be attempting to initialize those variables inside a Function that has already been called. To avoid this error, try calling model.build(input_shape) to initialize all the weights before training the model.
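This minimal sketch applies that advice using the Keras Sequential API. The names make_model and predict are hypothetical. Each model is built eagerly with model.build, so no tf.Variable has to be created inside the Function when both models are later passed to it:

```python
import tensorflow as tf

def make_model():
  return tf.keras.Sequential([
      tf.keras.layers.Dense(4, activation="relu"),
      tf.keras.layers.Dense(1),
  ])

model_a = make_model()
model_b = make_model()
# Create all weights eagerly, before any tf.function sees the models.
model_a.build(input_shape=(None, 3))
model_b.build(input_shape=(None, 3))

@tf.function
def predict(model, x):
  return model(x)

x = tf.ones([2, 3])
print(predict(model_a, x).shape)  # (2, 1)
print(predict(model_b, x).shape)  # (2, 1)
```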

To learn how to export and load a Function, see the SavedModel guide. To learn more about graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.
