Introduction to graphs and tf.functions


This guide goes beneath the surface of TensorFlow and Keras to see how TensorFlow works. If you instead want to immediately get started with Keras, please see our collection of Keras guides.

In this guide, you'll see how TensorFlow lets you make simple changes to your code to get graphs, how graphs are stored and represented, and how you can use them to accelerate and export your models.

This is a short-form introduction; for a full introduction to these concepts, see the tf.function guide.

What are graphs?

In the previous three guides, you ran TensorFlow eagerly. This means TensorFlow operations are executed by Python, operation by operation, with results returned to Python. Eager TensorFlow takes advantage of GPUs, allowing you to place variables, tensors, and even operations on GPUs and TPUs. It is also easy to debug.

Some users may never need or want to leave Python.

However, running TensorFlow op-by-op in Python prevents a host of accelerations otherwise available. If you can extract tensor computations from Python, you can make them into a graph.

Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation; and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context. Since these graphs are data structures, they can be saved, run, and restored all without the original Python code.
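To make the "graphs are data structures" point concrete, here is a minimal sketch that builds a tiny graph by hand inside a tf.Graph context and inspects it. The op names `a`, `b`, and `c` are illustrative only; the calls used (`tf.Graph`, `as_default`, `get_operations`, `as_graph_def`) are standard TensorFlow APIs.

```python
import tensorflow as tf

# Build a tiny graph explicitly inside a tf.Graph context.
g = tf.Graph()
with g.as_default():
    a = tf.constant(3.0, name="a")
    b = tf.constant(4.0, name="b")
    c = tf.multiply(a, b, name="c")  # c is a symbolic tf.Tensor, not a value

# The graph's nodes are tf.Operation objects.
op_names = [op.name for op in g.get_operations()]
print(op_names)

# The tf.Tensor `c` is the output of the "Mul" operation named "c".
print(c.op.type, c.dtype)

# A GraphDef is the serializable (protocol buffer) form of the graph,
# which is what lets a graph be saved and restored without Python code.
graph_def = g.as_graph_def()
print(len(graph_def.node))
```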

This is what a simple two-layer graph looks like when visualized in TensorBoard.

[Figure: a two-layer TensorFlow graph]

The benefits of graphs

With a graph, you have a great deal of flexibility. You can use your TensorFlow graph in environments that don't have a Python interpreter, like mobile applications, embedded devices, and backend servers. TensorFlow uses graphs as the format for saved models when it exports them from Python.
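As a minimal sketch of exporting a graph, the snippet below saves a `tf.Module` with one traced method and loads it back. The class `Adder` and method `add_one` are hypothetical names chosen for illustration; `tf.saved_model.save` and `tf.saved_model.load` are the standard export and import functions.

```python
import tempfile

import tensorflow as tf

class Adder(tf.Module):  # hypothetical example module
    # The input_signature tells TensorFlow which graph to trace and export.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def add_one(self, x):
        return x + 1.0

export_dir = tempfile.mkdtemp()
tf.saved_model.save(Adder(), export_dir)  # writes the traced graph to disk

# The restored object runs the saved graph; Adder's Python source is not needed.
restored = tf.saved_model.load(export_dir)
result = restored.add_one(tf.constant([1.0, 2.0]))
print(result.numpy())
```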

Graphs are also easily optimized, allowing the compiler to do transformations like:

  • Statically infer the value of tensors by folding constant nodes in your computation ("constant folding").
  • Separate sub-parts of a computation that are independent and split them between threads or devices.
  • Simplify arithmetic operations by eliminating common subexpressions.

There is an entire optimization system, Grappler, that performs these and other speedups.
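Grappler's passes can be toggled from Python through `tf.config.optimizer`. The sketch below assumes that the `constant_folding` and `arithmetic_optimization` passes correspond roughly to the first and third bullets above; both are normally on by default, so setting them here only makes the knob visible.

```python
import tensorflow as tf

# Explicitly enable two Grappler passes (they are usually on by default).
tf.config.optimizer.set_experimental_options({
    "constant_folding": True,         # fold constant subgraphs at optimization time
    "arithmetic_optimization": True,  # simplify arithmetic, dedupe common subexpressions
})

# Read back the options that have been explicitly set.
options = tf.config.optimizer.get_experimental_options()
print(options)
```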

In short, graphs are extremely useful and let your TensorFlow run fast, run in parallel, and run efficiently on multiple devices.

However, you still want to define your machine learning models (or other computations) in Python for convenience, and then automatically construct graphs when you need them.

Tracing graphs

The way you create a graph in TensorFlow is to use tf.function, either as a direct call or as a decorator.

import tensorflow as tf
import timeit
from datetime import datetime
# Define a Python function
def function_to_get_faster(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Create a `Function` object that contains a graph
a_function_that_uses_a_graph = tf.function(function_to_get_faster)

# Make some tensors
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

# It just works!
a_function_that_uses_a_graph(x1, y1, b1).numpy()
array([[12.]], dtype=float32)
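The example above uses tf.function as a direct call. The same function can be written with the decorator form; this is a minimal sketch, and the name `decorated_function` is illustrative.

```python
import tensorflow as tf

# Decorator form: the name now refers to a tf.function-wrapped callable.
@tf.function
def decorated_function(x, y, b):
    return tf.matmul(x, y) + b

out = decorated_function(
    tf.constant([[1.0, 2.0]]),    # shape (1, 2)
    tf.constant([[2.0], [3.0]]),  # shape (2, 1)
    tf.constant(4.0),
)
print(out.numpy())  # 1*2 + 2*3 + 4 = 12
```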

Functions wrapped with tf.function are Python callables that work the same as their Python equivalents. They have a particular class (python.eager.def_function.Function), but to you they act just like the non-traced version.
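A quick way to check the "works the same" claim is to call the plain function and its wrapped version on the same input. This sketch uses an illustrative function `double_sum`; it also reads the wrapper's `python_function` property, which holds the original Python function.

```python
import tensorflow as tf

def double_sum(x):
    return tf.reduce_sum(x * 2.0)

graph_double_sum = tf.function(double_sum)
x = tf.constant([1.0, 2.0, 3.0])

eager_result = double_sum(x).numpy()        # plain eager Python call
graph_result = graph_double_sum(x).numpy()  # same computation, run as a graph
print(eager_result, graph_result)

# The wrapper keeps a reference to the original Python function.
print(graph_double_sum.python_function is double_sum)
```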

tf.function recursively traces any Python function it calls.

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes inner_function() as well as outer_function()
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)
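One way to confirm that the undecorated inner function was traced into the outer graph is to list the operation types in the outer function's concrete graph. This is a minimal sketch with illustrative names `inner` and `outer`.

```python
import tensorflow as tf

def inner(x, y):  # not decorated
    return tf.matmul(x, y)

@tf.function
def outer(x):
    y = tf.constant([[2.0], [3.0]])
    return inner(x, y) + 4.0

# Trace for a (1, 2) float32 input and look inside the resulting graph.
concrete = outer.get_concrete_function(tf.TensorSpec(shape=[1, 2], dtype=tf.float32))
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)  # includes the MatMul op that came from inner()
```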

If you have used TensorFlow 1.x, you will notice that at no time did you need to define a Placeholder or tf.Session.
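For contrast, here is a sketch of the TensorFlow 1.x pattern, run through the `tf.compat.v1` module, next to the equivalent tf.function. The function name `double` is illustrative.

```python
import tensorflow as tf

# TF1 style: declare a placeholder in a graph, then feed it in a session.
g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None])
    y = x * 2.0
with tf.compat.v1.Session(graph=g) as sess:
    tf1_result = sess.run(y, feed_dict={x: [1.0, 2.0]})

# TF2 style: no placeholder or session; just call the function.
@tf.function
def double(v):
    return v * 2.0

tf2_result = double(tf.constant([1.0, 2.0])).numpy()
print(tf1_result, tf2_result)
```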

Flow control and side effects

Flow control and loops are converted to TensorFlow via tf.autograph by default. Autograph uses a combination of methods, including standardizing loop constructs, unrolling, and AST manipulation.

def my_function(x):
  if tf.reduce_sum(x) <= 1:
    return x * x
  else:
    return x-1

a_function = tf.function(my_function)

print("First branch, with graph:", a_function(tf.constant(1.0)).numpy())
print("Second branch, with graph:", a_function(tf.constant([5.0, 5.0])).numpy())
First branch, with graph: 1.0
Second branch, with graph: [4. 4.]

You can directly call the Autograph conversion to see how Python is converted into TensorFlow ops. This is, mostly, unreadable, but you can see the transformation.

# Don't read the output too carefully.
print(tf.autograph.to_code(my_function))
def tf__my_function(x):
    with ag__.FunctionScope('my_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal do_return, retval_
            (do_return, retval_) = vars_

        def if_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) * ag__.ld(x))
            except:
                do_return = False
                raise

        def else_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) - 1)
            except:
                do_return = False
                raise
        ag__.if_stmt((ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) <= 1), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)

Autograph automatically converts if-then clauses, loops, break, return, continue, and more.
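As a sketch of loop conversion, the illustrative function below iterates over `tf.range` and uses `continue` and `break` on tensor conditions. Autograph turns the loop into a graph loop and the conditions into graph conditionals.

```python
import tensorflow as tf

@tf.function
def sum_even_until(n, limit):
    total = tf.constant(0)
    for i in tf.range(n):      # tensor range: converted to a graph loop
        if i % 2 == 1:         # tensor condition: converted to a graph conditional
            continue
        if total > limit:
            break
        total += i
    return total

# Adds 0, 2, 4 (total 6), then stops at i = 6 because 6 > 5.
result = sum_even_until(tf.constant(10), tf.constant(5))
print(result)
```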

Most of the time, Autograph will work without special considerations. However, there are some caveats, and the tf.function guide can help here, as well as the complete Autograph reference.
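One common caveat is that Autograph only converts conditions on tensors. An `if` on a plain Python value runs once, at trace time, and only the branch taken ends up in the graph. This sketch (with illustrative function names) compares the operation types in the two traced graphs.

```python
import tensorflow as tf

@tf.function
def python_condition(x, use_square):
    # `use_square` is a Python bool, so this `if` is evaluated during tracing.
    if use_square:
        return x * x
    else:
        return x - 1

@tf.function
def tensor_condition(x):
    # `x > 0` is a tensor, so Autograph converts this `if` into a graph conditional.
    if x > 0:
        return x * x
    else:
        return x - 1

spec = tf.TensorSpec(shape=[], dtype=tf.float32)
python_ops = {op.type for op in python_condition.get_concrete_function(spec, True).graph.get_operations()}
tensor_ops = {op.type for op in tensor_condition.get_concrete_function(spec).graph.get_operations()}
print(sorted(python_ops))
print(sorted(tensor_ops))
```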

Seeing the speed up

Just wrapping a tensor-using function in tf.function does not automatically speed up your code. For small functions called a few times on a single machine, the overhead of calling a graph or graph fragment may dominate runtime. Also, if most of the computation was already happening on an accelerator, such as stacks of GPU-heavy convolutions, the graph speedup won't be large.

For complicated computations, graphs can provide a significant speedup. This is because graphs reduce Python-to-device communication and allow optimizations such as the ones Grappler performs.

The speedup is most obvious when running many small layers, as in the example below:

# Create an override model to classify pictures
class SequentialModel(tf.keras.Model):
  def __init__(self, **kwargs):
    super(SequentialModel, self).__init__(**kwargs)
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    # Add a lot of small layers
    num_layers = 100
    self.my_layers = [tf.keras.layers.Dense(64, activation="relu")
                      for n in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)

  def call(self, x):
    x = self.flatten(x)
    for layer in self.my_layers:
      x = layer(x)
    x = self.dropout(x)
    x = self.dense_2(x)
    return x

input_data = tf.random.uniform([20, 28, 28])
eager_model = SequentialModel()

# Don't count the time for the initial build.
eager_model(input_data)
print("Eager time:", timeit.timeit(lambda: eager_model(input_data), number=100))
Eager time: 2.185799148000001

# Wrap the call method in a `tf.function`
graph_model = SequentialModel()
graph_model.call = tf.function(graph_model.call)

# Don't count the time for the initial build and trace.
graph_model(input_data)
print("Graph time:", timeit.timeit(lambda: graph_model(input_data), number=100))
Graph time: 0.30396231500003523

Polymorphic functions

When you trace a function, you create a Function object that is polymorphic. A polymorphic function is a Python callable that encapsulates several concrete function graphs behind one API.

You can use this Function on all different kinds of dtypes and shapes. Each time you invoke it with a new argument signature, the original function gets re-traced with the new arguments. The Function then stores the tf.Graph corresponding to that trace in a concrete_function. If the function has already been traced with that kind of argument, you just get your pre-traced graph.
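A simple way to watch retracing happen is a Python side effect, which runs only while the function is being traced. In this sketch, `traces` and `square` are illustrative names; the list records the input dtype each time a new graph is traced.

```python
import tensorflow as tf

traces = []  # hypothetical bookkeeping list, appended to only during tracing

@tf.function
def square(x):
    traces.append(x.dtype.name)  # Python side effect: runs at trace time only
    return x * x

square(tf.constant(2))       # first call: traces for int32, shape ()
square(tf.constant(3))       # same dtype and shape: reuses the existing graph
square(tf.constant(2.0))     # new dtype (float32): retraces
square(tf.constant([1, 2]))  # new shape (2,): retraces
print(traces)
```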

Conceptually, then:

  • A tf.Graph is the raw, portable data structure describing a computation
  • A Function is a caching, tracing, dispatcher over ConcreteFunctions
  • A ConcreteFunction is an eager-compatible wrapper around a graph that lets you execute the graph from Python
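The three layers can be walked in code: from a Function, get a ConcreteFunction for one input signature, then read its `graph`. This is a minimal sketch with an illustrative function `add_one`.

```python
import tensorflow as tf

@tf.function
def add_one(x):
    return x + 1

# Function -> ConcreteFunction for one specific input signature.
concrete = add_one.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.int32))

# ConcreteFunction -> the underlying graph (a subclass of tf.Graph).
graph = concrete.graph
print(isinstance(graph, tf.Graph))
print([op.type for op in graph.get_operations()])

# The ConcreteFunction itself is callable from Python like an eager function.
print(concrete(tf.constant(41)))
```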

Inspecting polymorphic functions

You can inspect a_function, which is the result of calling tf.function on the Python function my_function. In this example, calling a_function with three kinds of arguments results in three different concrete functions.


print(a_function)

print("Calling a `Function`:")
print("Int:", a_function(tf.constant(2)))
print("Float:", a_function(tf.constant(2.0)))
print("Rank-1 tensor of floats", a_function(tf.constant([2.0, 2.0, 2.0])))
<tensorflow.python.eager.def_function.Function object at 0x7f90cc28e4a8>
Calling a `Function`:
Int: tf.Tensor(1, shape=(), dtype=int32)
Float: tf.Tensor(1.0, shape=(), dtype=float32)
Rank-1 tensor of floats tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)

# Get the concrete function that works on floats
print("Inspecting concrete functions")
print("Concrete function for float:")
print(a_function.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32)))
print("Concrete function for tensor of floats:")
print(a_function.get_concrete_function(tf.constant([2.0, 2.0, 2.0])))
Inspecting concrete functions
Concrete function for float:
ConcreteFunction my_function(x)
    x: float32 Tensor, shape=()
    float32 Tensor, shape=()
Concrete function for tensor of floats:
ConcreteFunction my_function(x)
    x: float32 Tensor, shape=(3,)
    float32 Tensor, shape=(3,)

# Concrete functions are callable
# Note: You won't normally do this, but instead just call the containing `Function`
cf = a_function.get_concrete_function(tf.constant(2))
print("Directly calling a concrete function:", cf(tf.constant(2)))
Directly calling a concrete function: tf.Tensor(1, shape=(), dtype=int32)

In this example, you are seeing pretty far into the stack. Unless you are specifically managing tracing, you will not normally need to call concrete functions directly as shown here.

Reverting to eager execution

You may find yourself looking at long stack traces, especially ones that refer to tf.Graph or with tf.Graph().as_default(). This means you are likely running in a graph context. Core functions in TensorFlow use graph contexts, such as Keras's model.fit().

It is often much easier to debug eager execution. Stack traces should be relatively short and easy to comprehend.

In situations where the graph makes debugging tricky, you can revert to using eager execution to debug.

Here are ways you can make sure you are running eagerly:

  • Call models and layers directly as callables

  • When using Keras compile/fit, at compile time use model.compile(run_eagerly=True)

  • Set global execution mode via tf.config.run_functions_eagerly(True)
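You can check which mode a function body runs in with `tf.executing_eagerly()`. In this sketch, the illustrative list `modes` records the result each time the Python body runs: once while tracing (graph mode), and once more on a call made with `tf.config.run_functions_eagerly(True)`.

```python
import tensorflow as tf

modes = []  # hypothetical list that records the execution mode

@tf.function
def record_mode(x):
    modes.append(tf.executing_eagerly())  # runs whenever the Python body runs
    return x + 1

record_mode(tf.constant(1))  # traced into a graph: executing_eagerly() is False

tf.config.run_functions_eagerly(True)
try:
    record_mode(tf.constant(1))  # the Python body now runs eagerly
finally:
    tf.config.run_functions_eagerly(False)  # restore the default

print(modes)
```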

Using run_eagerly=True

# Define an identity layer with an eager side effect
class EagerLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(EagerLayer, self).__init__(**kwargs)
    # Do some kind of initialization here

  def call(self, inputs):
    print("\nCurrently running eagerly", str(datetime.now()))
    return inputs

# Create an override model to classify pictures, adding the custom layer
class SequentialModel(tf.keras.Model):
  def __init__(self):
    super(SequentialModel, self).__init__()
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)
    self.eager = EagerLayer()

  def call(self, x):
    x = self.flatten(x)
    x = self.dense_1(x)
    x = self.dropout(x)
    x = self.dense_2(x)
    return self.eager(x)

# Create an instance of this model
model = SequentialModel()

# Generate some nonsense pictures and labels
input_data = tf.random.uniform([60, 28, 28])
labels = tf.random.uniform([60])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

First, compile the model without eager. Note that the model is not traced; despite its name, compile only sets up loss functions, optimization, and other training parameters.

model.compile(run_eagerly=False, loss=loss_fn)

Now, call fit and see that the function is traced (twice) and then the eager effect never runs again.

model.fit(input_data, labels, epochs=3)
Epoch 1/3

Currently running eagerly 2021-01-13 02:25:36.809205

Currently running eagerly 2021-01-13 02:25:36.941800
2/2 [==============================] - 0s 3ms/step - loss: 2.0352
Epoch 2/3
2/2 [==============================] - 0s 3ms/step - loss: 0.0045
Epoch 3/3
2/2 [==============================] - 0s 2ms/step - loss: 0.0026

<tensorflow.python.keras.callbacks.History at 0x7f90102f2550>

If you run even a single epoch eagerly, however, you can see the eager side effect twice.

print("Running eagerly")
# When compiling the model, set it to run eagerly
model.compile(run_eagerly=True, loss=loss_fn)

model.fit(input_data, labels, epochs=1)
Running eagerly

Currently running eagerly 2021-01-13 02:25:37.173159
1/2 [==============>...............] - ETA: 0s - loss: 0.0023
Currently running eagerly 2021-01-13 02:25:37.195392
2/2 [==============================] - 0s 13ms/step - loss: 0.0016

<tensorflow.python.keras.callbacks.History at 0x7f90101981d0>

Using run_functions_eagerly

You can also globally set everything to run eagerly. This is a switch that bypasses the polymorphic function's traced functions and calls the original function directly. You can use this for debugging.

# Now, globally set everything to run eagerly
tf.config.run_functions_eagerly(True)
print("Run all functions eagerly.")

# Create a polymorphic function
polymorphic_function = tf.function(model)

# This does, in fact, trace the function
print(polymorphic_function.get_concrete_function(input_data))

print("\nCalling twice eagerly")
# When you run the function again, you will see the side effect
# twice, as the function is running eagerly.
result = polymorphic_function(input_data)
result = polymorphic_function(input_data)
Run all functions eagerly.

Currently running eagerly 2021-01-13 02:25:37.594444
ConcreteFunction function(self)
    self: float32 Tensor, shape=(60, 28, 28)
    float32 Tensor, shape=(60, 10)

Calling twice eagerly

Currently running eagerly 2021-01-13 02:25:37.600183

Currently running eagerly 2021-01-13 02:25:37.602196

# Don't forget to set it back when you are done
tf.config.run_functions_eagerly(False)

Tracing and performance

Tracing costs some overhead. Although tracing small functions is quick, large models can take noticeable wall-clock time to trace. This investment is usually quickly paid back with a performance boost, but it's important to be aware that the first few epochs of any large model training can be slower due to tracing.

No matter how large your model, you want to avoid tracing frequently. The tf.function guide has a section on how to set input specifications and use tensor arguments to avoid retracing. If you find you are getting unusually poor performance, check whether you are retracing accidentally.
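As a sketch of the input-specification approach, an `input_signature` whose shape has a `None` dimension lets one trace serve inputs of any length. The names `traces` and `normalize` are illustrative.

```python
import tensorflow as tf

traces = []  # hypothetical list, appended to only while tracing

# Fixed dtype, any length: inputs of different lengths share one graph.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def normalize(x):
    traces.append(1)
    return x / tf.reduce_sum(x)

first = normalize(tf.constant([1.0, 3.0]))
second = normalize(tf.constant([1.0, 1.0, 2.0]))  # new length, no retrace
print(first.numpy(), second.numpy(), len(traces))
```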

You can add an eager-only side effect (such as printing a Python argument) so you can see when the function is being traced. Here, you see extra retracing because new Python arguments always trigger retracing.

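# Use @tf.function decorator
@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!")  # An eager-only side effect
  return x * x + tf.constant(2)

# This is traced the first time
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect
print(a_function_with_python_side_effect(tf.constant(3)))

# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)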
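When a value such as an epoch count changes on every call, passing it as a tensor instead of a Python number avoids the retracing shown above. This sketch counts traces for both cases; `traces` and `scale` are illustrative names.

```python
import tensorflow as tf

traces = []  # hypothetical list, appended to only while tracing

@tf.function
def scale(x, epoch):
    traces.append(1)
    return x * tf.cast(epoch, tf.float32)

# Python ints: each new value is a new signature, so every epoch retraces.
for epoch in range(3):
    scale(tf.constant(2.0), epoch)
python_arg_traces = len(traces)

# Tensors: same dtype and shape every epoch, so a single trace is reused.
traces.clear()
for epoch in range(3):
    scale(tf.constant(2.0), tf.constant(epoch))
tensor_arg_traces = len(traces)

print(python_arg_traces, tensor_arg_traces)
```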

Next steps

You can read a more in-depth discussion on the tf.function API reference page and in the tf.function guide.