tf.function

Compiles a function into a callable TensorFlow graph.


tf.function constructs a callable that executes a TensorFlow graph (tf.Graph) created by trace-compiling the TensorFlow operations in func, effectively executing func as a TensorFlow graph.

Example usage:

import tensorflow as tf

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>
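
To see the graph that tf.function builds, the sketch below wraps an undecorated function by calling tf.function directly (equivalent to using the decorator) and inspects the traced graph through get_concrete_function. The names g, g_graph and op_types are illustrative, and the exact list of op types printed may differ between TensorFlow versions.

```python
import tensorflow as tf

def g(x, y):
    return x ** 2 + y

# Calling tf.function directly is equivalent to decorating g with @tf.function.
g_graph = tf.function(g)

x = tf.constant([2, 3])
y = tf.constant([3, -2])

eager_result = g(x, y)        # runs op by op in eager mode
graph_result = g_graph(x, y)  # runs the traced TensorFlow graph

# Get the graph specialized to these input dtypes and shapes.
concrete = g_graph.get_concrete_function(x, y)
op_types = [op.type for op in concrete.graph.get_operations()]
print(graph_result.numpy(), op_types)
```

Both calls compute the same values; the difference is that g_graph executes a tf.Graph built once by tracing g.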

Features

func may use data-dependent control flow, including if, for, while, break, continue and return statements:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>
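
AutoGraph, the converter that tf.function applies to func, rewrites an if whose condition is a tensor into a tf.cond call. A hand-written version of the previous example, offered as an approximation of what the conversion produces rather than its literal output, might look like this (f_cond is a hypothetical name):

```python
import tensorflow as tf

@tf.function
def f_cond(x):
    # Explicit form of the data-dependent `if` above: both branches are traced
    # into the graph, and tf.cond selects one at run time from the predicate.
    return tf.cond(tf.reduce_sum(x) > 0,
                   lambda: x * x,     # taken when the sum is positive
                   lambda: -x // 2)   # taken otherwise

neg = f_cond(tf.constant(-2))  # sum is -2, so the second branch runs
pos = f_cond(tf.constant(3))   # sum is 3, so the first branch runs
print(neg.numpy(), pos.numpy())
```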

func's closure may include tf.Tensor and tf.Variable objects:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>
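
Because a tf.Variable in the closure is captured by reference, the traced graph reads its current value on each call, so assigning a new value changes the result without retracing. A minimal sketch reusing the values from the example above:

```python
import tensorflow as tf

x = tf.constant([-2, -3])
y = tf.Variable([3, -2])

@tf.function
def f():
    return x ** 2 + y

before = f()      # [7, 7]
y.assign([0, 0])  # update the captured variable in place
after = f()       # same graph, but it reads y's current value: [4, 9]
```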

func may also use ops with side effects, such as tf.print, tf.Variable updates, and others:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
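
These ops are part of the graph, so they run on every call rather than only while func is traced. A short sketch extending the example above by calling f twice:

```python
import tensorflow as tf

v = tf.Variable(1)  # created outside the traced function

@tf.function
def f(x):
    # tf.range makes this a graph loop, so each assign_add is a graph op.
    for i in tf.range(x):
        v.assign_add(i)

f(3)  # v: 1 + 0 + 1 + 2 = 4
f(3)  # the ops run again, so v: 4 + 0 + 1 + 2 = 7
print(v.numpy())
```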
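
Python side effects, however, such as appending to a Python list, run only while func is being traced, not on every call:

l = []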
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]
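
Calling f again with an argument of the same dtype and shape reuses the existing graph, so the list does not grow further; it keeps the single symbolic tensor appended during the one trace. A quick check of that behavior, under the same setup as above:

```python
import tensorflow as tf

l = []

@tf.function
def f(x):
    for i in x:
        l.append(i + 1)  # Python side effect: runs only while tracing

f(tf.constant([1, 2, 3]))
f(tf.constant([4, 5, 6]))  # same dtype and shape, so the graph is reused
print(len(l))
```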

Instead, use TensorFlow collections like tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
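
To let one traced graph handle inputs of any length, one option is to declare an input_signature with an unknown (None) dimension and loop with tf.range over the run-time length. The sketch below is written under that assumption; it is one way to structure the loop, not the only one.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def f(x):
    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    # tf.range over a run-time length becomes a graph loop, so one graph
    # serves inputs of any length (the signature leaves the length as None).
    for i in tf.range(tf.shape(x)[0]):
        ta = ta.write(i, x[i] + 1)
    return ta.stack()

short = f(tf.constant([5]))
longer = f(tf.constant([1, 2, 3]))
print(short.numpy(), longer.numpy())  # [6] [2 3 4]
```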

tf.function is polymorphic

Internally,