TF 2.0 brings together the ease of eager execution and the power of TF 1.0. At the center of this merger is tf.function, which allows you to transform a subset of Python syntax into portable, high-performance TensorFlow graphs.
A cool new feature of tf.function is AutoGraph, which lets you write graph code using natural Python syntax. For a list of the Python features that you can use with AutoGraph, see AutoGraph Capabilities and Limitations. For more details about tf.function, see the RFC TF 2.0: Functions, not Sessions. For more details about AutoGraph, see tf.autograph.
This tutorial will walk you through the basic features of tf.function and AutoGraph.
Setup
Import TensorFlow 2.0:
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import tensorflow as tf
The tf.function decorator
When you annotate a function with tf.function, you can still call it like any other function. But it will be compiled into a graph, which means you get the benefits of faster execution, running on GPU or TPU, or exporting to SavedModel.
@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))
x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))
simple_nn_layer(x, y)
<tf.Tensor: id=23, shape=(3, 3), dtype=float32, numpy= array([[1.0051662 , 0.7925216 , 0.897657 ], [0.50381726, 0.6784547 , 0.75661516], [1.3040552 , 1.1737359 , 1.3159252 ]], dtype=float32)>
If we examine the result of the annotation, we can see that it's a special callable that handles all interactions with the TensorFlow runtime.
simple_nn_layer
<tensorflow.python.eager.def_function.Function at 0x7f03ea818470>
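As an aside, this callable also exposes a few graph-level utilities. For example, you can ask it for the concrete graph function it traced for particular inputs (a minimal sketch reusing the x and y defined above; the name concrete_fn is just illustrative):

# Ask the Function object for the ConcreteFunction traced for these inputs.
concrete_fn = simple_nn_layer.get_concrete_function(x, y)
print(concrete_fn)  # a ConcreteFunction wrapping a tf.Graph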
If your code uses multiple functions, you don't need to annotate them all - any functions called from an annotated function will also run in graph mode.
def linear_layer(x):
  return 2 * x + 1

@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))
deep_net(tf.constant((1, 2, 3)))
<tf.Tensor: id=35, shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>
Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much of a speedup.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
  return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.004500312000118356
Function conv: 0.0033003070000177104
Note how there's not much difference in performance for convolutions
lstm_cell = tf.keras.layers.LSTMCell(10)
@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)
input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))
eager lstm: 0.007719941999994262
function lstm: 0.0039423630000783305
Use Python control flow
When using data-dependent control flow inside tf.function, you can use Python control flow statements and AutoGraph will convert them into appropriate TensorFlow ops. For example, if statements will be converted into tf.cond() if they depend on a Tensor.
In the example below, x is a Tensor but the if statement works as expected:
@tf.function
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x
print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))
print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))
square_if_positive(2) = 4
square_if_positive(-2) = 0
AutoGraph supports common Python statements like while, for, if, break, continue and return, with support for nesting. That means you can use Tensor expressions in the condition of while and if statements, or iterate over a Tensor in a for loop.
@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s
sum_even(tf.constant([10, 12, 15, 20]))
<tf.Tensor: id=602, shape=(), dtype=int32, numpy=42>
AutoGraph also provides a low-level API for advanced users. For example, we can use it to look at the generated code.
print(tf.autograph.to_code(sum_even.python_function))
def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('sum_even', 'sum_even_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as sum_even_scope:
    s = 0

    def get_state_2():
      return ()

    def set_state_2(_):
      pass

    def loop_body(iterates, s):
      c = iterates
      continue_ = False

      def get_state():
        return ()

      def set_state(_):
        pass

      def if_true():
        continue_ = True
        return continue_

      def if_false():
        return continue_
      cond = c % 2 > 0
      continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state, ('continue_',), ())

      def get_state_1():
        return ()

      def set_state_1(_):
        pass

      def if_true_1():
        s_1, = s,
        s_1 += c
        return s_1

      def if_false_1():
        return s
      cond_1 = ag__.not_(continue_)
      s = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1, ('s',), ())
      return s,
    s, = ag__.for_stmt(items, None, loop_body, get_state_2, set_state_2, (s,), ('s',), ())
    do_return = True
    retval_ = sum_even_scope.mark_return_value(s)
  do_return,
  return ag__.retval(retval_)
Here's an example of more complicated control flow:
@tf.function
def fizzbuzz(n):
  for i in tf.range(n):
    if i % 3 == 0:
      tf.print('Fizz')
    elif i % 5 == 0:
      tf.print('Buzz')
    else:
      tf.print(i)
fizzbuzz(tf.constant(15))
Fizz
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
Keras and AutoGraph
AutoGraph is available by default in non-dynamic Keras models. For more information, see tf.keras.
class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      return input_data // 2
model = CustomModel()
model(tf.constant([-2, -4]))
<tf.Tensor: id=710, shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>
Side effects
Just like in eager mode, you can use operations with side effects, like tf.Variable.assign or tf.print, normally inside tf.function, and it will insert the necessary control dependencies to ensure they execute in order.
v = tf.Variable(5)
@tf.function
def find_next_odd():
  v.assign(v + 1)
  if v % 2 == 0:
    v.assign(v + 1)
find_next_odd()
v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>
Debugging
tf.function and AutoGraph work by generating code and tracing it into TensorFlow graphs. This mechanism does not yet support step-by-step debuggers like pdb. However, you can call tf.config.experimental_run_functions_eagerly(True) to temporarily enable eager execution inside the tf.function and use your favorite debugger:
@tf.function
def f(x):
  if x > 0:
    # Try setting a breakpoint here!
    # Example:
    #   import pdb
    #   pdb.set_trace()
    x = x + 1
  return x
tf.config.experimental_run_functions_eagerly(True)
# You can now set breakpoints and run the code in a debugger.
f(tf.constant(1))
tf.config.experimental_run_functions_eagerly(False)
Advanced example: An in-graph training loop
The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.
This example shows how to train a simple Keras model on MNIST, where the entire training process (loading batches, calculating gradients, updating parameters, calculating accuracy, and repeating until convergence) is performed in-graph.
Download data
def prepare_mnist_features_and_labels(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def mnist_dataset():
  (x, y), _ = tf.keras.datasets.mnist.load_data()
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  ds = ds.map(prepare_mnist_features_and_labels)
  ds = ds.take(20000).shuffle(20000).batch(100)
  return ds
train_dataset = mnist_dataset()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Define the model
model = tf.keras.Sequential((
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10)))
model.build()
optimizer = tf.keras.optimizers.Adam()
Define the training loop
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss
@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if step % 10 == 0:
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy
step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
Step 10 : loss 1.92404079 ; accuracy 0.335
Step 20 : loss 1.12404513 ; accuracy 0.5175
Step 30 : loss 0.757334828 ; accuracy 0.614
Step 40 : loss 0.54961431 ; accuracy 0.668
Step 50 : loss 0.523098886 ; accuracy 0.7046
Step 60 : loss 0.544653177 ; accuracy 0.732166648
Step 70 : loss 0.260321438 ; accuracy 0.754857123
Step 80 : loss 0.293551743 ; accuracy 0.76975
Step 90 : loss 0.326000869 ; accuracy 0.780888915
Step 100 : loss 0.38042298 ; accuracy 0.7913
Step 110 : loss 0.193499714 ; accuracy 0.802909076
Step 120 : loss 0.407712549 ; accuracy 0.8115
Step 130 : loss 0.40341872 ; accuracy 0.818384588
Step 140 : loss 0.263050735 ; accuracy 0.823928595
Step 150 : loss 0.330198884 ; accuracy 0.8296
Step 160 : loss 0.331770599 ; accuracy 0.83456248
Step 170 : loss 0.198889062 ; accuracy 0.839941204
Step 180 : loss 0.219250724 ; accuracy 0.844
Step 190 : loss 0.161010623 ; accuracy 0.848157883
Step 200 : loss 0.153816774 ; accuracy 0.8522
Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.15381677, shape=(), dtype=float32) ; accuracy tf.Tensor(0.8522, shape=(), dtype=float32)
Batching
In real applications batching is essential for performance. The best code to convert to AutoGraph is code where the control flow is decided at the batch level. If you make decisions at the individual example level, try to use batch APIs to maintain performance.
For example, if you have the following code in Python:
def square_if_positive(x):
  return [i ** 2 if i > 0 else i for i in x]
square_if_positive(range(-5, 5))
[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]
You may be tempted to write it in TensorFlow like this (and it would work!):
@tf.function
def square_if_positive_naive(x):
  result = tf.TensorArray(tf.int32, size=x.shape[0])
  for i in tf.range(x.shape[0]):
    if x[i] > 0:
      result = result.write(i, x[i] ** 2)
    else:
      result = result.write(i, x[i])
  return result.stack()
square_if_positive_naive(tf.range(-5, 5))
<tf.Tensor: id=1660, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1, 0, 1, 4, 9, 16], dtype=int32)>
But in this case, it turns out you can write the following:
def square_if_positive_vectorized(x):
  return tf.where(x > 0, x ** 2, x)
square_if_positive_vectorized(tf.range(-5, 5))
<tf.Tensor: id=1669, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1, 0, 1, 4, 9, 16], dtype=int32)>
Re-tracing
Key points:
- Exercise caution when calling functions with non-tensor arguments, or with arguments that change shapes.
- Decorate module-level functions and methods of module-level classes, and avoid decorating local functions or methods.
tf.function can give you a significant speedup over eager execution, at the cost of a slower first-time execution. This is because, when executed for the first time, the function is also traced into a TensorFlow graph. Constructing and optimizing a graph is usually much slower than actually executing it:
import timeit
@tf.function
def f(x, y):
  return tf.matmul(x, y)

print(
    "First invocation:",
    timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))

print(
    "Second invocation:",
    timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))
First invocation: 0.05051725399994211
Second invocation: 0.0011011389999566745
You can easily tell when a function is traced by adding a print statement at the top of the function. Because any Python code is only executed at trace time, you will only see the output of print when the function is traced:
@tf.function
def f():
  print('Tracing!')
  tf.print('Executing')
print('First invocation:')
f()
print('Second invocation:')
f()
First invocation:
Tracing!
Executing
Second invocation:
Executing
tf.function may also re-trace when called with different non-tensor arguments:
@tf.function
def f(n):
  print(n, 'Tracing!')
  tf.print(n, 'Executing')
f(1)
f(1)
f(2)
f(2)
1 Tracing!
1 Executing
1 Executing
2 Tracing!
2 Executing
2 Executing
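One way to avoid a re-trace for every new Python value is to pass the value as a tensor instead: tensor arguments with the same shape and dtype reuse the existing trace. A minimal sketch (the function g below is only for illustration):

@tf.function
def g(n):
  print('Tracing!')
  tf.print(n, 'Executing')

g(tf.constant(1))  # Tracing! then: 1 Executing
g(tf.constant(2))  # 2 Executing (same shape and dtype, so no re-trace)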
A re-trace can also happen when tensor arguments change shape, unless you specified an input_signature:
@tf.function
def f(x):
  print(x.shape, 'Tracing!')
  tf.print(x, 'Executing')
f(tf.constant([1]))
f(tf.constant([2]))
f(tf.constant([1, 2]))
f(tf.constant([3, 4]))
(1,) Tracing!
[1] Executing
[2] Executing
(2,) Tracing!
[1 2] Executing
[3 4] Executing
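If you want a single trace that accepts vectors of any length, you can pin the signature down yourself with tf.TensorSpec and a None dimension. A minimal sketch (the name f_fixed is just illustrative):

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def f_fixed(x):
  print(x.shape, 'Tracing!')
  tf.print(x, 'Executing')

f_fixed(tf.constant([1]))     # (None,) Tracing! then: [1] Executing
f_fixed(tf.constant([1, 2]))  # [1 2] Executing (no re-trace)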
In addition, each call to tf.function creates a new graph function with its own set of traces:
def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing
This can lead to surprising behavior when using the @tf.function decorator in a nested function:
def outer():
  @tf.function
  def f():
    print('Tracing!')
    tf.print('Executing')
  f()

outer()
outer()
Tracing!
Executing
Tracing!
Executing
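Following the key points above, one way to avoid this is to decorate a module-level function and call it from the nested scope, so the same Function object (and its traces) is reused across calls. A minimal sketch (the names traced_f and outer_reusing are just illustrative):

@tf.function
def traced_f():
  print('Tracing!')
  tf.print('Executing')

def outer_reusing():
  traced_f()

outer_reusing()  # Tracing! then: Executing
outer_reusing()  # Executing (the existing trace is reused)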