This guide goes beneath the surface of TensorFlow and Keras to see how TensorFlow works. If you instead want to immediately get started with Keras, please see our collection of Keras guides.
In this guide, you'll see how TensorFlow lets you make simple changes to your code to get graphs, how graphs are stored and represented, and how you can use them to accelerate and export your models.
This is a short-form introduction; for a full introduction to these concepts, see the tf.function guide.
What are graphs?
In the previous three guides, you ran TensorFlow eagerly. This means TensorFlow operations are executed by Python, operation by operation, and return results back to Python. Eager TensorFlow takes advantage of GPUs, allowing you to place variables, tensors, and even operations on GPUs and TPUs. It is also easy to debug.
Some users may never need or want to leave Python.
However, running TensorFlow op-by-op in Python prevents a host of accelerations otherwise available. If you can extract tensor computations from Python, you can make them into a graph.
Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context. Since these graphs are data structures, they can be saved, run, and restored all without the original Python code.
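The following minimal sketch makes these pieces concrete. It builds a tf.Graph by hand inside a with tf.Graph().as_default(): block, in the TensorFlow 1.x style. In TensorFlow 2 you normally let tf.function build graphs for you. The op names a, b, and c are arbitrary choices for this illustration.

```python
import tensorflow as tf

# Ops created inside this context are added to `graph` instead of
# being executed eagerly.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(2.0, name="a")
    b = tf.constant(3.0, name="b")
    c = tf.multiply(a, b, name="c")

# `c` is a symbolic tf.Tensor: it has no value yet, only the op that produces it.
print(c.op.type)  # Mul
# The graph itself is plain data: a list of tf.Operation objects...
print([(op.name, op.type) for op in graph.get_operations()])
# ...which can be serialized to a GraphDef protocol buffer, one node per op.
print(len(graph.as_graph_def().node))  # 3
```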
This is what a simple two-layer graph looks like when visualized in TensorBoard.
The benefits of graphs
With a graph, you have a great deal of flexibility. You can use your TensorFlow graph in environments that don't have a Python interpreter, like mobile applications, embedded devices, and backend servers. TensorFlow uses graphs as the format for saved models when it exports them from Python.
Graphs are also easily optimized, allowing the compiler to do transformations like:
- Statically infer the value of tensors by folding constant nodes in your computation ("constant folding").
- Separate sub-parts of a computation that are independent and split them between threads or devices.
- Simplify arithmetic operations by eliminating common subexpressions.
There is an entire optimization system, Grappler, to perform this and other speedups.
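Grappler operates on real TensorFlow graphs. The sketch below is only a toy model of two of the ideas listed above: constant folding and spotting common subexpressions. It applies them to a hypothetical expression tree built from Python tuples. None of the helper names (fold_constants, count_ops, unique_ops) are TensorFlow APIs.

```python
# Toy expression graph (hypothetical, not TensorFlow's representation):
#   ("const", value) | ("input", name) | (op, left, right) with op in OPS
OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def is_leaf(node):
    return node[0] in ("const", "input")


def fold_constants(node):
    """Pre-compute every subtree whose inputs are all constants."""
    if is_leaf(node):
        return node
    op, left, right = node
    left, right = fold_constants(left), fold_constants(right)
    if left[0] == "const" and right[0] == "const":
        return ("const", OPS[op](left[1], right[1]))
    return (op, left, right)


def count_ops(node):
    """Number of op nodes if every subtree is computed separately."""
    if is_leaf(node):
        return 0
    return 1 + count_ops(node[1]) + count_ops(node[2])


def unique_ops(node, seen=None):
    """Distinct op subtrees; identical tuples collapse into one set entry."""
    if seen is None:
        seen = set()
    if not is_leaf(node):
        seen.add(node)
        unique_ops(node[1], seen)
        unique_ops(node[2], seen)
    return seen


x, y = ("input", "x"), ("input", "y")
xy = ("mul", x, y)
# (2 * 3) + (x*y + x*y)
expr = ("add", ("mul", ("const", 2.0), ("const", 3.0)), ("add", xy, xy))

folded = fold_constants(expr)
print(count_ops(expr), count_ops(folded), len(unique_ops(folded)))  # 5 4 3
```

Folding removes the constant multiplication. Because the two x * y subtrees are identical, an optimizer only needs to compute that product once.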
In short, graphs are extremely useful and let your TensorFlow run fast, run in parallel, and run efficiently on multiple devices.
However, you still want to define your machine learning models (or other computations) in Python for convenience, and then automatically construct graphs when you need them.
The way you create a graph in TensorFlow is to use tf.function, either as a direct call or as a decorator.
import tensorflow as tf
import timeit
from datetime import datetime

# Define a Python function
def function_to_get_faster(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Create a `Function` object that contains a graph
a_function_that_uses_a_graph = tf.function(function_to_get_faster)

# Make some tensors
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

# It just works!
a_function_that_uses_a_graph(x1, y1, b1).numpy()
tf.function-ized functions are Python callables that work the same as their Python equivalents. They have a particular class (python.eager.def_function.Function), but to you they act just like the non-traced version.
tf.function recursively traces any Python function it calls.
def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)
  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes inner_function() as well as outer_function()
outer_function(tf.constant([[1.0, 2.0]])).numpy()
If you have used TensorFlow 1.x, you will notice that at no time did you need to define a
Flow control and side effects
def my_function(x):
  if tf.reduce_sum(x) <= 1:
    return x * x
  else:
    return x - 1

a_function = tf.function(my_function)

print("First branch, with graph:", a_function(tf.constant(1.0)).numpy())
print("Second branch, with graph:", a_function(tf.constant([5.0, 5.0])).numpy())
First branch, with graph: 1.0
Second branch, with graph: [4. 4.]
You can directly call the Autograph conversion to see how Python is converted into TensorFlow ops. The output is mostly unreadable, but you can see the transformation.
# Don't read the output too carefully.
print(tf.autograph.to_code(my_function))
def tf__my_function(x):
    with ag__.FunctionScope('my_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        def get_state():
            return (retval_, do_return)
        def set_state(vars_):
            nonlocal retval_, do_return
            (retval_, do_return) = vars_
        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = (ag__.ld(x) * ag__.ld(x))
            except:
                do_return = False
                raise
        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = (ag__.ld(x) - 1)
            except:
                do_return = False
                raise
        ag__.if_stmt((ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) <= 1), if_body, else_body, get_state, set_state, ('retval_', 'do_return'), 2)
        return fscope.ret(retval_, do_return)
Autograph automatically converts if-then clauses, loops, continue, and more.
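The following small sketch illustrates loop conversion. The function name is made up for this example. The while condition depends on a tensor, so when tf.function traces the function, AutoGraph converts the Python loop into a TensorFlow while loop rather than running it in Python.

```python
import tensorflow as tf

@tf.function
def smallest_n_with_sum_over(limit):
    # Find the smallest n such that 1 + 2 + ... + n > limit.
    n = tf.constant(0)
    total = tf.constant(0)
    # `total <= limit` is a tensor, so AutoGraph emits a graph while loop here.
    while total <= limit:
        n += 1
        total += n
    return n, total

n, total = smallest_n_with_sum_over(tf.constant(10))
print(n.numpy(), total.numpy())  # 5 15
```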
Seeing the speedup
Just wrapping a tensor-using function in tf.function does not automatically speed up your code. For small functions called a few times on a single machine, the overhead of calling a graph or graph fragment may dominate runtime. Also, if most of the computation was already happening on an accelerator, such as stacks of GPU-heavy convolutions, the graph speedup won't be large.
For complicated computations, graphs can provide a significant speedup. This is because graphs reduce Python-to-device communication and let TensorFlow apply optimizations such as the ones Grappler performs.
This code times 10,000 runs of a small model built from dense layers, first eagerly and then as a graph.
# Create an override model to classify pictures
class SequentialModel(tf.keras.Model):
  def __init__(self, **kwargs):
    super(SequentialModel, self).__init__(**kwargs)
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)

  def call(self, x):
    x = self.flatten(x)
    x = self.dense_1(x)
    x = self.dropout(x)
    x = self.dense_2(x)
    return x

input_data = tf.random.uniform([60, 28, 28])

eager_model = SequentialModel()
graph_model = tf.function(eager_model)

print("Eager time:", timeit.timeit(lambda: eager_model(input_data), number=10000))
print("Graph time:", timeit.timeit(lambda: graph_model(input_data), number=10000))
Eager time: 5.330957799000316
Graph time: 3.0375442899999143
When you trace a function, you create a Function object that is polymorphic. A polymorphic function is a Python callable that encapsulates several concrete function graphs behind one API.
You can use this Function on all different kinds of dtypes and shapes. Each time you invoke it with a new argument signature, the original function gets re-traced with the new arguments. The Function then stores the tf.Graph corresponding to that trace in a concrete_function. If the function has already been traced with that kind of argument, you just get your pre-traced graph.
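You can watch this caching happen by appending to a Python list inside the function. Plain Python code runs only while tracing, so each entry marks one new trace. This is a sketch; the double function and the traces list are illustrative names.

```python
import tensorflow as tf

traces = []  # plain Python list: appended to only while tracing

@tf.function
def double(x):
    traces.append((x.dtype.name, tuple(x.shape)))
    return x * 2

double(tf.constant(1))           # new signature (int32, scalar): traces
double(tf.constant(2))           # same signature: reuses that graph
double(tf.constant(1.0))         # new dtype (float32, scalar): traces again
double(tf.constant([1.0, 2.0]))  # new shape (float32, (2,)): traces again
print(traces)  # [('int32', ()), ('float32', ()), ('float32', (2,))]
```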
- tf.Graph is the raw, portable data structure describing a computation.
- Function is a caching, tracing dispatcher over ConcreteFunctions.
- ConcreteFunction is an eager-compatible wrapper around a graph that lets you execute the graph from Python.
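The following minimal sketch walks down those three layers for a made-up add_one function. It asks the Function for one ConcreteFunction, calls it, and then looks at the tf.Graph it wraps. The exact list of operation types printed may vary between TensorFlow versions.

```python
import tensorflow as tf

# Function: the polymorphic, caching dispatcher.
@tf.function
def add_one(x):
    return x + 1.0

# ConcreteFunction: one traced specialization, callable from Python.
concrete = add_one.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
print(concrete(tf.constant(2.0)).numpy())  # 3.0

# tf.Graph: the raw data structure underneath the ConcreteFunction.
graph = concrete.graph
print(isinstance(graph, tf.Graph))  # True
print([op.type for op in graph.get_operations()])
```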
Inspecting polymorphic functions
You can inspect a_function, which is the result of calling tf.function on the Python function my_function. In this example, calling a_function with three kinds of arguments results in three different concrete functions.
print(a_function)

print("Calling a `Function`:")
print("Int:", a_function(tf.constant(2)))
print("Float:", a_function(tf.constant(2.0)))
print("Rank-1 tensor of floats", a_function(tf.constant([2.0, 2.0, 2.0])))
<tensorflow.python.eager.def_function.Function object at 0x7f7d342602b0>
Calling a `Function`:
Int: tf.Tensor(1, shape=(), dtype=int32)
Float: tf.Tensor(1.0, shape=(), dtype=float32)
Rank-1 tensor of floats tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)
# Get the concrete function that works on floats
print("Inspecting concrete functions")
print("Concrete function for float:")
print(a_function.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32)))
print("Concrete function for tensor of floats:")
print(a_function.get_concrete_function(tf.constant([2.0, 2.0, 2.0])))
Inspecting concrete functions
Concrete function for float:
ConcreteFunction my_function(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()
Concrete function for tensor of floats:
ConcreteFunction my_function(x)
  Args:
    x: float32 Tensor, shape=(3,)
  Returns:
    float32 Tensor, shape=(3,)
# Concrete functions are callable
# Note: You won't normally do this, but instead just call the containing `Function`
cf = a_function.get_concrete_function(tf.constant(2))
print("Directly calling a concrete function:", cf(tf.constant(2)))
Directly calling a concrete function: tf.Tensor(1, shape=(), dtype=int32)
In this example, you are seeing pretty far into the stack. Unless you are specifically managing tracing, you will not normally need to call concrete functions directly as shown here.
Reverting to eager execution
You may find yourself looking at long stack traces, especially ones that refer to with tf.Graph().as_default(). This means you are likely running in a graph context. Core functions in TensorFlow use graph contexts, such as Keras's model.fit().
It is often much easier to debug eager execution. Stack traces should be relatively short and easy to comprehend.
In situations where the graph makes debugging tricky, you can revert to using eager execution to debug.
Here are ways you can make sure you are running eagerly:
- Call models and layers directly as callables.
- When using Keras compile/fit, at compile time use model.compile(run_eagerly=True).
- Set global execution mode via tf.config.run_functions_eagerly(True).
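One way to check which mode you are in is tf.executing_eagerly(). It returns False inside a tf.function while the function is being traced, and True in normal eager code. The sketch below (add_one and modes are illustrative names) records its value in three cases: a plain call, a tf.function call, and a tf.function call made with tf.config.run_functions_eagerly(True).

```python
import tensorflow as tf

modes = []  # records tf.executing_eagerly() each time the Python body runs

def add_one(x):
    modes.append(tf.executing_eagerly())
    return x + 1

add_one(tf.constant(1))               # plain Python call: eager
tf.function(add_one)(tf.constant(1))  # body runs while tracing: graph mode

tf.config.run_functions_eagerly(True)
try:
    tf.function(add_one)(tf.constant(1))  # tf.function bypassed: eager again
finally:
    tf.config.run_functions_eagerly(False)  # restore the default

print(modes)  # [True, False, True]
```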
# Define an identity layer with an eager side effect
class EagerLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(EagerLayer, self).__init__(**kwargs)
    # Do some kind of initialization here

  def call(self, inputs):
    print("\nCurrently running eagerly", str(datetime.now()))
    return inputs
# Create an override model to classify pictures, adding the custom layer
class SequentialModel(tf.keras.Model):
  def __init__(self):
    super(SequentialModel, self).__init__()
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)
    self.eager = EagerLayer()

  def call(self, x):
    x = self.flatten(x)
    x = self.dense_1(x)
    x = self.dropout(x)
    x = self.dense_2(x)
    return self.eager(x)

# Create an instance of this model
model = SequentialModel()

# Generate some nonsense pictures and labels (one label per picture)
input_data = tf.random.uniform([60, 28, 28])
labels = tf.random.uniform([60])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
First, compile the model without eager. Note that the model is not traced; despite its name, compile only sets up loss functions, optimization, and other training parameters.
model.compile(run_eagerly=False, loss=loss_fn)
Next, call fit and see that the function is traced (twice) and then the eager effect never runs again.
model.fit(input_data, labels, epochs=3)
Epoch 1/3
Currently running eagerly 2020-11-04 02:22:28.114630
Currently running eagerly 2020-11-04 02:22:28.233822
2/2 [==============================] - 0s 2ms/step - loss: 0.9890
Epoch 2/3
2/2 [==============================] - 0s 1ms/step - loss: 0.0017
Epoch 3/3
2/2 [==============================] - 0s 1ms/step - loss: 5.4478e-04
<tensorflow.python.keras.callbacks.History at 0x7f7cb02b5e80>
If you run even a single epoch eagerly, however, you can see the eager side effect twice.
print("Running eagerly")
# When compiling the model, set it to run eagerly
model.compile(run_eagerly=True, loss=loss_fn)

model.fit(input_data, labels, epochs=1)
Running eagerly
Currently running eagerly 2020-11-04 02:22:28.449835
1/2 [==============>...............] - ETA: 0s - loss: 7.8712e-04
Currently running eagerly 2020-11-04 02:22:28.472645
2/2 [==============================] - 0s 6ms/step - loss: 4.1988e-04
<tensorflow.python.keras.callbacks.History at 0x7f7cb0230780>
You can also globally set everything to run eagerly. This is a switch that bypasses the polymorphic function's traced functions and calls the original function directly. You can use this for debugging.
# Now, globally set everything to run eagerly
tf.config.run_functions_eagerly(True)
print("Run all functions eagerly.")

# Create a polymorphic function
polymorphic_function = tf.function(model)

print("Tracing")
# This does, in fact, trace the function
print(polymorphic_function.get_concrete_function(input_data))

print("\nCalling twice eagerly")
# When you run the function again, you will see the side effect
# twice, as the function is running eagerly.
result = polymorphic_function(input_data)
result = polymorphic_function(input_data)
Run all functions eagerly.
Tracing
Currently running eagerly 2020-11-04 02:22:28.502694
ConcreteFunction function(self)
  Args:
    self: float32 Tensor, shape=(60, 28, 28)
  Returns:
    float32 Tensor, shape=(60, 10)
Calling twice eagerly
Currently running eagerly 2020-11-04 02:22:28.507036
Currently running eagerly 2020-11-04 02:22:28.508331
# Don't forget to set it back when you are done
tf.config.run_functions_eagerly(False)
Tracing and performance
Tracing costs some overhead. Although tracing small functions is quick, large models can take noticeable wall-clock time to trace. This investment is usually quickly paid back with a performance boost, but it's important to be aware that the first few epochs of any large model training can be slower due to tracing.
No matter how large your model, you want to avoid tracing frequently. This section of the tf.function guide discusses how to set input specifications and use tensor arguments to avoid retracing. If you find you are getting unusually poor performance, check whether you are retracing accidentally.
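Here is one example of setting an input specification. This sketch passes input_signature to tf.function with a shape of [None], meaning a rank-1 float32 tensor of any length. Tensors of three different lengths then share a single trace. The total function and the traces list are illustrative names.

```python
import tensorflow as tf

traces = []  # Python side effect: appended to only during tracing

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def total(x):
    traces.append(x.shape)
    return tf.reduce_sum(x)

for values in ([1.0, 2.0], [1.0, 2.0, 3.0], [4.0]):
    total(tf.constant(values))

print(len(traces), traces[0])  # 1 (None,)
```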
You can add an eager-only side effect (such as printing a Python argument) so you can see when the function is being traced. Here, you see extra retracing because new Python arguments always trigger retracing.
# Use @tf.function decorator
@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!")  # An eager-only side effect
  return x * x + tf.constant(2)

# This is traced the first time
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect
print(a_function_with_python_side_effect(tf.constant(3)))

# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)
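The flip side is that Python side effects inside a graph only run when the function is traced, not every time the graph executes. In this sketch (names are illustrative), appending to a Python list happens only during tracing. In contrast, assign_add on a tf.Variable is a TensorFlow op, so it runs on every call. Passing Python ints instead of tensors retraces once per new value, as in the example above.

```python
import tensorflow as tf

trace_events = []            # plain Python: appended only while tracing
call_count = tf.Variable(0)  # TF state: updated every time the graph runs

@tf.function
def square_plus_two(x):
    trace_events.append("traced")
    call_count.assign_add(1)
    return x * x + 2

for value in (2, 3, 4):
    square_plus_two(tf.constant(value))  # one trace, three executions
square_plus_two(5)                       # Python ints are keyed by value...
square_plus_two(6)                       # ...so each new one retraces

print(len(trace_events), call_count.numpy())  # 3 5
```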