
Compiles a function into a callable TensorFlow graph.

tf.function constructs a callable that executes a TensorFlow graph (tf.Graph) created by trace-compiling the TensorFlow operations in func, effectively executing func as a TensorFlow graph.

Example usage:

import tensorflow as tf

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>
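tf.function can be applied either as a decorator or by calling it on an existing Python function. The following is a minimal sketch of the call form, comparing the eager result with the graph result. The names square_plus and graph_square_plus are illustrative and do not come from the API.

```python
import tensorflow as tf

def square_plus(x, y):
    # Plain Python function; it runs eagerly when called directly.
    return x ** 2 + y

# Wrapping the function returns a callable that traces and runs a graph.
graph_square_plus = tf.function(square_plus)

x = tf.constant([2, 3])
y = tf.constant([3, -2])

eager_result = square_plus(x, y)        # executed op by op
graph_result = graph_square_plus(x, y)  # executed as a traced graph

# Both paths are expected to compute the same values.
print(eager_result.numpy(), graph_result.numpy())
```

Keeping the undecorated function around in this way can be convenient for debugging, since the eager version can be stepped through with ordinary Python tools.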


func may use data-dependent control flow, including if, for, while, break, continue and return statements:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>
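Loops whose condition depends on a tensor value follow the same pattern. The sketch below assumes a hypothetical helper, count_halvings, that loops until a tensor drops below 1. Because the condition is a tensor, the loop is staged into the graph rather than unrolled in Python.

```python
import tensorflow as tf

@tf.function
def count_halvings(x):
    # Counts how many times x must be halved before it falls below 1.
    steps = tf.constant(0)
    while x >= 1.0:  # tensor-dependent condition, evaluated at run time
        x = x / 2.0
        steps += 1
    return steps

halvings_of_eight = count_halvings(tf.constant(8.0))
halvings_of_half = count_halvings(tf.constant(0.5))
print(halvings_of_eight.numpy(), halvings_of_half.numpy())
```

Both calls pass a float32 scalar, so they should share a single traced graph even though the loop runs a different number of iterations each time.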

func's closure may include tf.Tensor and tf.Variable objects:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func may also use ops with side effects, such as tf.print, tf.Variable and others:

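v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>

Python side effects (appending to a list, printing with print, and so on) happen only once, when func is traced. To have side effects executed on every call, write them as TensorFlow ops:

l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]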

Instead, use TensorFlow collections like tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
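The difference between Python side effects and TensorFlow side effects can be observed directly. This is a minimal sketch with illustrative names (python_calls, graph_calls, step): a Python list records a marker during tracing, while a tf.Variable is updated by an op inside the graph.

```python
import tensorflow as tf

python_calls = []             # mutated by a Python side effect
graph_calls = tf.Variable(0)  # mutated by a TensorFlow op

@tf.function
def step(x):
    python_calls.append(1)     # runs only while the function is traced
    graph_calls.assign_add(1)  # part of the graph, runs on every call
    return x + 1

# Three calls with the same dtype and shape reuse one traced graph.
for value in [1, 2, 3]:
    step(tf.constant(value))

print(len(python_calls), graph_calls.numpy())
```

Under these assumptions the list should contain a single entry, while the variable reflects all three calls.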

tf.function is polymorphic

Internally, tf.function can build more than one graph to support arguments with different data types or shapes, since TensorFlow can build more efficient graphs that are specialized on shapes and dtypes. tf.function also treats every pure Python value as an opaque object, and builds a separate graph for each set of Python arguments that it encounters.
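One way to see this behavior is to record a marker each time the function body runs in Python, which happens only when a new graph is traced. The sketch below uses illustrative names (trace_log, double) and assumes the default retracing behavior.

```python
import tensorflow as tf

trace_log = []  # gains one entry per trace, since appending is a Python side effect

@tf.function
def double(x):
    trace_log.append("traced")
    return x * 2

double(tf.constant(1))    # new graph: int32 scalar tensor
double(tf.constant(2))    # same dtype and shape: graph is reused
double(tf.constant(1.0))  # new graph: float32 scalar tensor
double(3)                 # new graph: Python value 3
double(3)                 # same Python value: graph is reused
double(4)                 # new graph: Python value 4

print(len(trace_log))
```

Because each distinct Python value can trigger a new trace, passing tensors rather than Python numbers is usually the cheaper choice for arguments that change often.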

To obtain an individual graph, use the get_concrete_function method of the callable created by tf.function. It can be called with the same arguments as func and returns a concrete function whose graph attribute is a tf.Graph object:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
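get_concrete_function also accepts a tf.TensorSpec in place of a concrete value, which yields a graph for a whole family of inputs. The following is a hedged sketch; the names add_one, concrete, and op_types are illustrative, and the exact list of operations in the graph may vary between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def add_one(x):
    return x + 1

# Trace one graph for int32 vectors of any length.
spec = tf.TensorSpec(shape=[None], dtype=tf.int32)
concrete = add_one.get_concrete_function(spec)

# The concrete function is callable and exposes its traced graph.
result = concrete(tf.constant([1, 2, 3]))
op_types = [op.type for op in concrete.graph.get_operations()]

print(result.numpy())
print(op_types)  # operation types recorded in the traced graph
```

Inspecting the graph in this way can help confirm which operations were actually staged.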