
# Introduction to Graphs and `tf.function`

This guide goes beneath the surface of TensorFlow and Keras to see how TensorFlow works. If you instead want to immediately get started with Keras, please see our collection of Keras guides.

In this guide you'll see the core of how TensorFlow lets you make simple changes to your code to get graphs, how graphs are stored and represented, and how you can use them to accelerate and export your models.

This is a short-form introduction; for a full introduction to these concepts, see the `tf.function` guide.

## What are graphs?

In the previous three guides, you have seen TensorFlow running **eagerly**. This means TensorFlow operations are executed by Python, operation by operation, with each result returned to Python. Eager TensorFlow takes advantage of GPUs, allowing you to place variables, tensors, and even operations on GPUs and TPUs. It is also easy to debug.

Some users may never need or want to leave Python.

However, running TensorFlow op-by-op in Python prevents a host of accelerations otherwise available. If you can extract tensor computations from Python, you can make them into a *graph*.

**Graphs are data structures that contain a set of `tf.Operation` objects, which represent units of computation; and `tf.Tensor` objects, which represent the units of data that flow between operations.** They are defined in a `tf.Graph` context. Since these graphs are data structures, they can be saved, run, and restored all without the original Python code.

(Figure: a simple two-layer graph, as visualized in TensorBoard.)
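To make the definition concrete, here is a minimal sketch (assuming TensorFlow 2.x) that builds a `tf.Graph` by hand. Inside `g.as_default()`, ops are recorded into the graph instead of being executed, so `c` is a symbolic `tf.Tensor` rather than a value. The names `a`, `b`, and `c` are arbitrary; in everyday code `tf.function` builds graphs for you, so this explicit style is mainly useful for seeing what a graph contains.

```python
import tensorflow as tf

# Ops created inside `as_default()` are added to `g` instead of running eagerly.
g = tf.Graph()
with g.as_default():
    a = tf.constant([[1.0, 2.0]], name="a")
    b = tf.constant([[2.0], [3.0]], name="b")
    c = tf.matmul(a, b, name="c")

# The graph is a container of tf.Operation objects; `c` is the symbolic
# tf.Tensor produced by the "MatMul" operation.
op_names = [op.name for op in g.get_operations()]
print(op_names)           # ['a', 'b', 'c']
print(c.op.type)          # MatMul
print(c.graph is g)       # True
```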

## The benefits of graphs

With a graph, you have a great deal of flexibility. You can use your TensorFlow graph in environments that don't have a Python interpreter, like mobile applications, embedded devices, and backend servers. TensorFlow uses graphs as the format for saved models when it exports them from Python.
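As a sketch of graphs serving as the export format, the snippet below saves a `tf.Module` whose method is a `tf.function` and then loads it back. The class `Adder` and method `add_one` are illustrative names only. The loaded object runs the serialized graph, and the Python source of `Adder` is not needed at load time.

```python
import tempfile

import tensorflow as tf


class Adder(tf.Module):
    # input_signature fixes which traced graph is exported.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def add_one(self, x):
        return x + 1.0


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Adder(), export_dir)

# The restored object executes the saved graph, not the Python method body.
restored = tf.saved_model.load(export_dir)
result = restored.add_one(tf.constant(41.0))
print(result.numpy())  # 42.0
```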

Graphs are also easily optimized, allowing the compiler to do transformations like:

- Statically infer the value of tensors by folding constant nodes in your computation *("constant folding")*.
- Separate sub-parts of a computation that are independent and split them between threads or devices.
- Simplify arithmetic operations by eliminating common subexpressions.

There is an entire optimization system, Grappler, to perform this and other speedups.
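You normally don't configure Grappler at all, since it runs whenever a graph executes. As a small sketch of where its passes are controlled, `tf.config.optimizer.set_experimental_options` takes a dictionary of pass names. This example only makes constant folding explicitly enabled (it is on by default). The full set of option keys varies between TensorFlow versions, so treat the key list as an assumption to check against your installed version.

```python
import tensorflow as tf

# Explicitly enable Grappler's constant-folding pass for graphs run from now on.
tf.config.optimizer.set_experimental_options({"constant_folding": True})

# Read the settings back to confirm the pass is switched on.
options = tf.config.optimizer.get_experimental_options()
print(options["constant_folding"])  # True
```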

In short, graphs are extremely useful and let your TensorFlow run **fast**, run **in parallel**, and run efficiently **on multiple devices**.

However, you still want to define your machine learning models (or other computations) in Python for convenience, and then automatically construct graphs when you need them.

# Tracing graphs

The way you create a graph in TensorFlow is to use `tf.function`, either as a direct call or as a decorator.

```
import tensorflow as tf
import timeit
from datetime import datetime
```

```
# Define a Python function
def function_to_get_faster(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

# Create a `Function` object that contains a graph
a_function_that_uses_a_graph = tf.function(function_to_get_faster)

# Make some tensors
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

# It just works!
a_function_that_uses_a_graph(x1, y1, b1).numpy()
```

array([[12.]], dtype=float32)

`tf.function`-ized functions are Python callables that work the same as their Python equivalents. They have a particular class (`python.eager.def_function.Function`), but to you they act just like the non-traced version.
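A quick way to convince yourself of this is to call a plain Python function and its `tf.function` wrapper on the same tensors and compare the results. The function `add_and_square` is a made-up example. Note that the module path of the `Function` class has moved between TensorFlow releases, so this sketch compares values rather than class paths.

```python
import tensorflow as tf


def add_and_square(x, y):
    s = x + y
    return s * s


# Wrap the same Python function; the wrapper traces it into a graph on first call.
graph_add_and_square = tf.function(add_and_square)

x = tf.constant([1.0, 2.0])
y = tf.constant([3.0, 4.0])

eager_result = add_and_square(x, y)        # runs op by op
graph_result = graph_add_and_square(x, y)  # runs the traced graph
print(eager_result.numpy(), graph_result.numpy())  # [16. 36.] [16. 36.]
```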

`tf.function` recursively traces any Python function it calls.

```
def inner_function(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

# Use the decorator
@tf.function
def outer_function(x):
    y = tf.constant([[2.0], [3.0]])
    b = tf.constant(4.0)
    return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes inner_function() as well as outer_function()
outer_function(tf.constant([[1.0, 2.0]])).numpy()
```

array([[12.]], dtype=float32)

If you have used TensorFlow 1.x, you will notice that at no time did you need to define a `tf.placeholder` or a `tf.Session`.
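For comparison, here is a sketch of the same matmul-plus-bias computation written in the TF1 style through the `tf.compat.v1` API, next to the TF2 `tf.function` version. It assumes a TensorFlow 2.x install where `tf.compat.v1` is available. The placeholder only works because it is created inside an explicit graph context rather than eagerly.

```python
import tensorflow as tf

# TF1 style: declare a placeholder in an explicit graph, then feed it via a Session.
g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[1, 2], name="x")
    y = tf.constant([[2.0], [3.0]])
    out = tf.matmul(x, y) + 4.0

with tf.compat.v1.Session(graph=g) as sess:
    tf1_result = sess.run(out, feed_dict={x: [[1.0, 2.0]]})


# TF2 style: the same computation with no placeholder and no session.
@tf.function
def tf2_style(x):
    return tf.matmul(x, tf.constant([[2.0], [3.0]])) + 4.0


tf2_result = tf2_style(tf.constant([[1.0, 2.0]])).numpy()
print(tf1_result, tf2_result)  # [[12.]] [[12.]]
```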

## Flow control and side effects

Flow control and loops are converted to TensorFlow via `tf.autograph` by default. Autograph uses a combination of methods, including standardizing loop constructs, unrolling, and AST manipulation.

```
def my_function(x):
    if tf.reduce_sum(x) <= 1:
        return x * x
    else:
        return x - 1

a_function = tf.function(my_function)

print("First branch, with graph:", a_function(tf.constant(1.0)).numpy())
print("Second branch, with graph:", a_function(tf.constant([5.0, 5.0])).numpy())
```

First branch, with graph: 1.0 Second branch, with graph: [4. 4.]

You can directly call the Autograph conversion to see how Python is converted into TensorFlow ops. The output is mostly unreadable, but you can see the transformation.

```
# Don't read the output too carefully.
print(tf.autograph.to_code(my_function))
```

def tf__my_function(x): with ag__.FunctionScope('my_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (retval_, do_return) def set_state(vars_): nonlocal retval_, do_return (retval_, do_return) = vars_ def if_body(): nonlocal retval_, do_return try: do_return = True retval_ = (ag__.ld(x) * ag__.ld(x)) except: do_return = False raise def else_body(): nonlocal retval_, do_return try: do_return = True retval_ = (ag__.ld(x) - 1) except: do_return = False raise ag__.if_stmt((ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) <= 1), if_body, else_body, get_state, set_state, ('retval_', 'do_return'), 2) return fscope.ret(retval_, do_return)

Autograph automatically converts `if-then` clauses, loops, `break`, `return`, `continue`, and more.
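As a sketch of loop conversion, the function below iterates over a tensor and uses a tensor-dependent `if` with `break`. Because `values` is a tensor, AutoGraph turns the Python `for` into a graph loop and the `if`/`break` into graph control flow rather than unrolling it. The function name `sum_until_limit` is hypothetical.

```python
import tensorflow as tf


@tf.function
def sum_until_limit(values, limit):
    total = tf.constant(0.0)
    # `values` is a tensor, so AutoGraph converts this loop (and the
    # tensor-valued condition and `break` inside it) into graph ops.
    for v in values:
        if total + v > limit:
            break
        total += v
    return total


result = sum_until_limit(tf.constant([1.0, 2.0, 3.0, 4.0]), tf.constant(5.0))
print(result.numpy())  # 3.0  (1 + 2; adding 3 would exceed 5)
```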

Most of the time, Autograph works without special considerations. However, there are some caveats; the `tf.function` guide and the complete Autograph reference can help here.

## Seeing the speed up

Just wrapping a tensor-using function in `tf.function` does not automatically speed up your code. For small functions called a few times on a single machine, the overhead of calling a graph or graph fragment may dominate runtime. Also, if most of the computation was already happening on an accelerator, such as stacks of GPU-heavy convolutions, the graph speedup won't be large.

For complicated computations, graphs can provide a significant speedup. This is because graphs reduce Python-to-device communication and allow whole-graph optimizations such as the ones described above.

This code times a few runs on some small dense layers.

```
# Create an override model to classify pictures
class SequentialModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(SequentialModel, self).__init__(**kwargs)
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense_2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dropout(x)
        x = self.dense_2(x)
        return x

input_data = tf.random.uniform([60, 28, 28])

eager_model = SequentialModel()
graph_model = tf.function(eager_model)

print("Eager time:", timeit.timeit(lambda: eager_model(input_data), number=10000))
print("Graph time:", timeit.timeit(lambda: graph_model(input_data), number=10000))
```

Eager time: 5.302399797999897 Graph time: 2.3688509589999285
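The graph timing above includes the one-time cost of tracing on the first call. A sketch of a fairer comparison makes one warm-up call before timing, so only graph execution is measured. The function `many_small_ops` is a made-up workload of many tiny ops, the case where eager per-op overhead tends to matter most. Actual numbers depend entirely on your hardware and TensorFlow version, so no particular speedup is implied.

```python
import timeit

import numpy as np
import tensorflow as tf


def many_small_ops(x):
    # A Python `for` over `range` is unrolled during tracing, so the graph
    # ends up holding 50 sin/add pairs.
    for _ in range(50):
        x = tf.sin(x) + 1.0
    return x


graph_fn = tf.function(many_small_ops)
x = tf.random.uniform([10])

# Warm-up call: trace once so tracing is excluded from the timing below.
graph_fn(x)

eager_time = timeit.timeit(lambda: many_small_ops(x), number=100)
graph_time = timeit.timeit(lambda: graph_fn(x), number=100)
print(f"Eager: {eager_time:.3f}s  Graph: {graph_time:.3f}s")

# Both paths compute the same values.
same = np.allclose(many_small_ops(x).numpy(), graph_fn(x).numpy())
print(same)  # True
```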

### Polymorphic functions

When you trace a function, you create a `Function` object that is **polymorphic**. A polymorphic function is a Python callable that encapsulates several concrete function graphs behind one API.

You can use this `Function` with many different `dtypes` and shapes. Each time you invoke it with a new argument signature, the original function is re-traced with the new arguments. The `Function` then stores the `tf.Graph` corresponding to that trace in a `ConcreteFunction`. If the function has already been traced with that kind of argument, you just get your pre-traced graph.
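One way to watch this dispatching is a Python side effect that runs only while tracing, here appending the traced dtype and shape to a list. The names `square` and `trace_log` are illustrative. This sketch assumes the default setting where shapes are not relaxed, so a new shape triggers a new trace.

```python
import tensorflow as tf

trace_log = []  # filled by a Python side effect that runs only during tracing


@tf.function
def square(x):
    trace_log.append((x.dtype.name, x.shape.as_list()))
    return x * x


square(tf.constant(2))           # int32 scalar: first trace
square(tf.constant(3))           # same signature: reuses the existing graph
square(tf.constant(2.0))         # new dtype: retrace
square(tf.constant([2.0, 3.0]))  # new shape: retrace
print(trace_log)  # [('int32', []), ('float32', []), ('float32', [2])]
```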

Conceptually, then:

- A `tf.Graph` is the raw, portable data structure describing a computation.
- A `Function` is a caching, tracing dispatcher over `ConcreteFunction`s.
- A `ConcreteFunction` is an eager-compatible wrapper around a graph that lets you execute the graph from Python.
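The three layers can be reached from one another in code. In this sketch, the `Function` named `double` (a hypothetical example) hands back a `ConcreteFunction` for a given signature, whose `.graph` attribute is the underlying graph. The exact list of op types in that graph can vary across versions, so only the presence of the multiply op is relied on.

```python
import tensorflow as tf


@tf.function
def double(x):
    return x * 2.0


# Function -> ConcreteFunction: trace (if needed) for this signature.
# shape=[None] accepts float32 vectors of any length.
concrete = double.get_concrete_function(tf.TensorSpec(shape=[None], dtype=tf.float32))

# ConcreteFunction -> tf.Graph (a FuncGraph, which subclasses tf.Graph).
graph = concrete.graph
op_types = [op.type for op in graph.get_operations()]
print(isinstance(graph, tf.Graph))  # True
print(op_types)

# The ConcreteFunction can also be called directly from eager Python.
result = concrete(tf.constant([1.0, 2.0]))
print(result.numpy())  # [2. 4.]
```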

### Inspecting polymorphic functions

You can inspect `a_function`, which is the result of calling `tf.function` on the Python function `my_function`. In this example, calling `a_function` with three kinds of arguments results in three different concrete functions.

```
print(a_function)
print("Calling a `Function`:")
print("Int:", a_function(tf.constant(2)))
print("Float:", a_function(tf.constant(2.0)))
print("Rank-1 tensor of floats", a_function(tf.constant([2.0, 2.0, 2.0])))
```

<tensorflow.python.eager.def_function.Function object at 0x7fe28ce18cf8> Calling a `Function`: Int: tf.Tensor(1, shape=(), dtype=int32) Float: tf.Tensor(1.0, shape=(), dtype=float32) Rank-1 tensor of floats tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)

```
# Get the concrete function that works on floats
print("Inspecting concrete functions")
print("Concrete function for float:")
print(a_function.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32)))
print("Concrete function for tensor of floats:")
print(a_function.get_concrete_function(tf.constant([2.0, 2.0, 2.0])))
```

Inspecting concrete functions Concrete function for float: ConcreteFunction my_function(x) Args: x: float32 Tensor, shape=() Returns: float32 Tensor, shape=() Concrete function for tensor of floats: ConcreteFunction my_function(x) Args: x: float32 Tensor, shape=(3,) Returns: float32 Tensor, shape=(3,)

```
# Concrete functions are callable
# Note: You won't normally do this, but instead just call the containing `Function`
cf = a_function.get_concrete_function(tf.constant(2))
print("Directly calling a concrete function:", cf(tf.constant(2)))
```

Directly calling a concrete function: tf.Tensor(1, shape=(), dtype=int32)

In this example, you are seeing pretty far into the stack. Unless you are specifically managing tracing, you will not normally need to call concrete functions directly as shown here.

# Reverting to eager execution

You may find yourself looking at long stack traces, especially ones that refer to `tf.Graph` or `with tf.Graph().as_default()`. This means you are likely running in a graph context. Some core TensorFlow functions, such as Keras's `model.fit()`, use graph contexts.

It is often much easier to debug eager execution. Stack traces should be relatively short and easy to comprehend.

In situations where the graph makes debugging tricky, you can revert to using eager execution to debug.

Here are ways you can make sure you are running eagerly:

- Call models and layers directly as callables.
- When using Keras `compile`/`fit`, at compile time use `model.compile(run_eagerly=True)`.
- Set the global execution mode via `tf.config.run_functions_eagerly(True)`.
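A sketch of how to check which mode you are in: `tf.executing_eagerly()` returns `False` while a `tf.function` body is being traced, and `True` when `tf.config.run_functions_eagerly(True)` makes the same function run as ordinary Python. The names `record_mode` and `modes` are hypothetical; the list is filled by a Python side effect.

```python
import tensorflow as tf

modes = []  # records tf.executing_eagerly() each time the Python body runs


@tf.function
def record_mode(x):
    modes.append(tf.executing_eagerly())
    return x + 1


tf.config.run_functions_eagerly(False)
record_mode(tf.constant(1))  # first call traces: body runs in graph mode

tf.config.run_functions_eagerly(True)
record_mode(tf.constant(1))  # body runs eagerly, like plain Python

tf.config.run_functions_eagerly(False)  # restore the default
print(modes)  # [False, True]
```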

### Using `run_eagerly=True`

```
# Define an identity layer with an eager side effect
class EagerLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(EagerLayer, self).__init__(**kwargs)
        # Do some kind of initialization here

    def call(self, inputs):
        print("\nCurrently running eagerly", str(datetime.now()))
        return inputs
```

```
# Create an override model to classify pictures, adding the custom layer
class SequentialModel(tf.keras.Model):
    def __init__(self):
        super(SequentialModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense_2 = tf.keras.layers.Dense(10)
        self.eager = EagerLayer()

    def call(self, x):
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dropout(x)
        x = self.dense_2(x)
        return self.eager(x)

# Create an instance of this model
model = SequentialModel()

# Generate some nonsense pictures and labels
input_data = tf.random.uniform([60, 28, 28])
labels = tf.random.uniform([60])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
```

First, compile the model without eager. Note that the model is not traced; despite its name, `compile` only sets up loss functions, optimization, and other training parameters.

```
model.compile(run_eagerly=False, loss=loss_fn)
```

Now, call `fit` and see that the function is traced (twice) and then the eager effect never runs again.

```
model.fit(input_data, labels, epochs=3)
```

Epoch 1/3 Currently running eagerly 2020-09-30 01:22:01.861126 Currently running eagerly 2020-09-30 01:22:01.982367 2/2 [==============================] - 0s 2ms/step - loss: 1.5409 Epoch 2/3 2/2 [==============================] - 0s 1ms/step - loss: 0.0021 Epoch 3/3 2/2 [==============================] - 0s 1ms/step - loss: 0.0013 <tensorflow.python.keras.callbacks.History at 0x7fe20c051940>

If you run even a single epoch eagerly, however, you see the eager side effect once per batch: twice here, since 60 samples at the default batch size of 32 make two batches.

```
print("Running eagerly")
# When compiling the model, set it to run eagerly
model.compile(run_eagerly=True, loss=loss_fn)
model.fit(input_data, labels, epochs=1)
```

Running eagerly Currently running eagerly 2020-09-30 01:22:02.196963 1/2 [==============>...............] - ETA: 0s - loss: 9.7674e-04 Currently running eagerly 2020-09-30 01:22:02.220018 2/2 [==============================] - 0s 6ms/step - loss: 5.2117e-04 <tensorflow.python.keras.callbacks.History at 0x7fe200201278>

### Using `run_functions_eagerly`

You can also globally set everything to run eagerly. This is a switch that bypasses the polymorphic function's traced functions and calls the original function directly. You can use this for debugging.

```
# Now, globally set everything to run eagerly
tf.config.run_functions_eagerly(True)
print("Run all functions eagerly.")
# Create a polymorphic function
polymorphic_function = tf.function(model)
print("Tracing")
# This does, in fact, trace the function
print(polymorphic_function.get_concrete_function(input_data))
print("\nCalling twice eagerly")
# When you run the function again, you will see the side effect
# twice, as the function is running eagerly.
result = polymorphic_function(input_data)
result = polymorphic_function(input_data)
```

Run all functions eagerly. Tracing Currently running eagerly 2020-09-30 01:22:02.249703 ConcreteFunction function(self) Args: self: float32 Tensor, shape=(60, 28, 28) Returns: float32 Tensor, shape=(60, 10) Calling twice eagerly Currently running eagerly 2020-09-30 01:22:02.254759 Currently running eagerly 2020-09-30 01:22:02.256224

```
# Don't forget to set it back when you are done
tf.config.run_functions_eagerly(False)
```

# Tracing and performance

Tracing costs some overhead. Although tracing small functions is quick, large models can take noticeable wall-clock time to trace. This investment is usually quickly paid back with a performance boost, but it's important to be aware that the first few epochs of any large model training can be slower due to tracing.

No matter how large your model, you want to avoid tracing frequently. A section of the `tf.function` guide discusses how to set input specifications and use tensor arguments to avoid retracing. If you find you are getting unusually poor performance, it's good to check whether you are retracing accidentally.
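As a sketch of one such input specification, `input_signature` pins a `tf.function` to a single traced graph. With `shape=[None]`, vectors of any length share that graph, so the tracing side effect fires only once. The names `scaled_sum` and `traces` are illustrative.

```python
import tensorflow as tf

traces = []  # appended to only while tracing


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def scaled_sum(x):
    traces.append("traced")
    return tf.reduce_sum(x) * 2.0


results = [
    scaled_sum(tf.constant([1.0, 2.0])),
    scaled_sum(tf.constant([1.0, 2.0, 3.0])),  # different length, same signature
    scaled_sum(tf.constant([4.0])),
]
print(len(traces))                          # 1
print([float(r.numpy()) for r in results])  # [6.0, 12.0, 8.0]
```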

You can add an eager-only side effect (such as printing a Python argument) so you can see when the function is being traced. Here, you see extra retracing because new Python arguments always trigger retracing.

```
# Use @tf.function decorator
@tf.function
def a_function_with_python_side_effect(x):
    print("Tracing!")  # Python side effect: runs only while tracing
    return x * x + tf.constant(2)

# This is traced the first time
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect
print(a_function_with_python_side_effect(tf.constant(3)))

# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
```

Tracing! tf.Tensor(6, shape=(), dtype=int32) tf.Tensor(11, shape=(), dtype=int32) Tracing! tf.Tensor(6, shape=(), dtype=int32) Tracing! tf.Tensor(11, shape=(), dtype=int32)

# Next steps

You can read a more in-depth discussion on the `tf.function` API reference page and in the `tf.function` guide.