tf.function

Compiles a function into a callable TensorFlow graph.

tf.function constructs a callable that executes a TensorFlow graph (tf.Graph) created by trace-compiling the TensorFlow operations in func, effectively executing func as a TensorFlow graph.

Example usage:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

Features

func may use data-dependent control flow, including if, for, while, break, continue, and return statements:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>

func's closure may include tf.Tensor and tf.Variable objects:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func may also use ops with side effects, such as tf.print, tf.Variable assignments, and others:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

Instead, use TensorFlow collections like tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

tf.function is polymorphic

Internally, tf.function can build more than one graph to support arguments with different data types or shapes, since TensorFlow can build more efficient graphs that are specialized to particular shapes and dtypes. tf.function also treats pure Python values as opaque objects and builds a separate graph for each set of Python argument values that it encounters.
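
As a rough sketch (not an example from this reference), this retracing behavior can be observed by adding an ordinary Python print call, which runs only while a new graph is being traced:

@tf.function
def double(a):
  print("Tracing with", a)  # Python side effect: runs only during tracing
  return a + a
double(tf.constant(1))    # traces a graph for scalar int32 tensors
double(tf.constant(1.1))  # new dtype, so a second graph is traced
double(2)                 # Python value: gets its own graph
double(3)                 # another Python value: traced yet again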

To obtain an individual graph, use the get_concrete_function method of the callable created by tf.function. It can be called with the same arguments as func and returns a ConcreteFunction, whose graph attribute is the underlying tf.Graph:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - builds new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Python numerical arguments should only be used when they take few distinct values, such as hyperparameters like the number of layers in a neural network.
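
For instance (a hypothetical sketch; repeat_add and num_repeats are illustrative names, not part of this API), a Python integer that only ever takes a handful of values is a reasonable argument, since each distinct value is simply baked into its own graph:

@tf.function
def repeat_add(x, num_repeats):
  # num_repeats is a Python int, so this loop is unrolled while tracing
  for _ in range(num_repeats):
    x = x + 1
  return x
repeat_add(tf.constant(0), 3)  # traces a graph specialized to num_repeats=3
repeat_add(tf.constant(5), 3)  # reuses that graph
repeat_add(tf.constant(0), 4)  # new Python value, so a new graph is traced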

Input signatures

For Tensor arguments, tf.function instantiates a separate graph for every unique set of input shapes and datatypes. The example below creates two separate graphs, each specialized to a different shape:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

An "input signature" can be optionally provided to tf.function to control the graphs traced. The input signature specifies the shape and type of each Tensor argument to the function using a tf.TensorSpec object. More general shapes can be used. This is useful to avoid creating multiple graphs when Tensors have dynamic shapes. It also restricts the shape and datatype of Tensors that can be used: