tf.function

Compiles a function into a callable TensorFlow graph. (deprecated arguments)

tf.function constructs a tf.types.experimental.GenericFunction that executes a TensorFlow graph (tf.Graph) created by trace-compiling the TensorFlow operations in func.

Example usage:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

The trace-compilation allows non-TensorFlow operations to execute, but under special conditions. In general, only TensorFlow operations are guaranteed to run and create fresh results whenever the GenericFunction is called.
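As a minimal sketch of this distinction, the snippet below mixes a plain Python side effect with a TensorFlow op. The names `trace_count` and `add_one` are hypothetical and chosen only for illustration. The Python counter is updated while the function is being traced, whereas the addition is part of the graph and runs on every call.

```python
import tensorflow as tf

trace_count = 0  # plain Python state (hypothetical name), touched only during tracing


@tf.function
def add_one(x):
    global trace_count
    trace_count += 1   # non-TensorFlow operation: executes while tracing
    return x + 1       # TensorFlow operation: executes on every call


first = add_one(tf.constant(1))
second = add_one(tf.constant(2))  # same dtype and shape, so the traced graph is reused

print(trace_count)                    # number of traces, not number of calls
print(first.numpy(), second.numpy())  # fresh results from the TensorFlow op
```

Both calls return a fresh result, but the counter reflects only how many times the function was traced, not how many times it was called.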

Features

func may use data-dependent control flow, including if, for, while, break, continue and return statements:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>

func's closure may include tf.Tensor and tf.Variable objects:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func may also use ops with side effects, such as tf.print, tf.Variable and others:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>

Python side effects, such as appending to a list, happen only once, when func is traced:

l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

Instead, use TensorFlow collections like tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

tf.function creates polymorphic callables

Internally, a tf.types.experimental.GenericFunction may contain multiple tf.types.experimental.ConcreteFunctions, each specialized to arguments with different data types or shapes. This is because TensorFlow can perform more optimizations on graphs specialized to particular shapes, dtypes and values of constant arguments. tf.function treats any pure Python values as opaque objects (best thought of as compile-time constants), and builds a separate tf.Graph for each set of Python arguments that it encounters. For more information, see the tf.function guide.

Executing a GenericFunction will select and execute the appropriate ConcreteFunction based on the argument types and values.
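The selection behaviour described above can be observed with a small sketch. The `traces` list and the `scale` function are hypothetical names used only for illustration, and the example assumes the default tf.function options. Because Python code inside the function runs only while tracing, each entry appended to the list corresponds to one newly built graph.

```python
import tensorflow as tf

traces = []  # one entry per trace (hypothetical helper for illustration)


@tf.function
def scale(x, factor):
    # Python code: runs only when a new graph is traced.
    traces.append((x.dtype.name, factor))
    return x * factor


scale(tf.constant([1, 2]), 2)      # first call: traces for an int32 tensor, factor=2
scale(tf.constant([3, 4]), 2)      # same dtype, shape and Python value: graph reused
scale(tf.constant([1.0, 2.0]), 2)  # different dtype: traced again
scale(tf.constant([1, 2]), 3)      # different Python value: traced again

print(traces)
```

Four calls produce three traces here: the second call matches the signature of the first, while a new tensor dtype or a new Python argument value each lead to a separate graph.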

To obtain an individual ConcreteFunction, use the GenericFunction.get_concrete_function method. It can be called with the same arguments as func and returns a tf.types.experimental.ConcreteFunction. ConcreteFunctions are backed by a single tf.Graph: