Defined in tensorflow/python/eager/

Compiles a Python function into a callable TensorFlow graph.

defun (short for "define function") trace-compiles a Python function composed of TensorFlow operations into a callable that executes a tf.Graph containing those operations. The callable produced by defun contains only the subgraph of TensorFlow operations that were executed when the Python function was called with a particular input signature, defined as a list of the shapes and dtypes of the Python function's Tensor-valued arguments and the values of its non-Tensor Python objects. In particular, defun is not a compiler for arbitrary Python code.
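Here is a minimal pure-Python sketch of what an input signature contains. It is not defun's actual implementation: `FakeTensor`, `_key`, and `input_signature` are hypothetical names. Tensor-valued arguments contribute only their shape and dtype. Every other Python argument contributes its value. As a result, two calls with different tensor contents but identical shapes produce the same signature.

```python
from dataclasses import dataclass, field
from typing import Any, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor; only shape and dtype are inspected."""
    shape: Tuple[int, ...]
    dtype: str
    data: Any = field(default=None, compare=False)  # contents never affect the signature


def _key(value):
    if isinstance(value, FakeTensor):
        return ("tensor", value.shape, value.dtype)
    if isinstance(value, (list, tuple)):
        # Lists of arguments are described element by element.
        return ("seq", tuple(_key(v) for v in value))
    # Non-Tensor Python objects participate by value, so they must be hashable.
    return ("python", value)


def input_signature(args, kwargs):
    """Return a hashable key describing one call's inputs (args and kwargs)."""
    return (tuple(_key(a) for a in args),
            tuple(sorted((name, _key(v)) for name, v in kwargs.items())))


x1 = FakeTensor((1, 2), "float32", data=[2.0, 3.0])
x2 = FakeTensor((1, 2), "float32", data=[7.0, -1.0])
print(input_signature((x1,), {"training": True}) == input_signature((x2,), {"training": True}))  # True
print(input_signature((x1,), {"training": True}) == input_signature((x1,), {"training": False}))  # False
```

Because the key is a tuple, it can index a dictionary of traced graphs. This is also why non-Tensor arguments need to be hashable.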

When eager execution is enabled, the ability to create graphs from Python functions makes it possible to incrementally trade off debuggability and interactivity for performance. Functions compiled with defun cannot be inspected with pdb and print statements. However, executing a graph generated by defun sometimes takes less time and memory than eagerly executing the corresponding Python function, because specifying computations as graphs allows for optimizations like automatic buffer reuse and parallelization among ops. Note that executing a defun-compiled function incurs a small constant overhead, so eagerly executing sufficiently small Python functions might take less time than executing their corresponding defun-generated graphs.

For a Python function to be compatible with defun, all of its arguments must be hashable Python objects or lists thereof. Additionally, it must return zero or more tf.Tensor objects.

Executing a graph generated by defun respects device annotations (i.e., all with tf.device directives present in a Python function will also be present in its corresponding graph), but it is not yet possible to execute the generated graphs across multiple machines.

Example Usage

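import tensorflow as tf

# These examples call .numpy(), which requires eager execution (TensorFlow 1.x).
tf.enable_eager_execution()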


# A simple example.
def f(x, y):
  return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

g = tf.contrib.eager.defun(f)

x = tf.constant([[2.0, 3.0]])
y = tf.constant([[3.0, -2.0]])

# `f` and `g` will return the same value, but `g` will be executed as a
# TensorFlow graph.
assert f(x, y).numpy() == g(x, y).numpy()

# `defun` is capable of compiling Python functions that close over Python
# objects, including Tensors and Variables.
@tf.contrib.eager.defun
def h():
  return f(x, y)

assert (h().numpy() == f(x, y).numpy()).all()

# `defun` automatically lifts variables out of the graphs it creates,
# allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
# `tf.keras.Model` objects.
class MyModel(tf.keras.Model):

  def __init__(self, keep_probability=0.2):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.keep_probability = keep_probability

  def call(self, inputs, training=True):
    x = self.dense2(self.dense1(inputs))
    if training:
      return tf.nn.dropout(x, self.keep_probability)
    return x

model = MyModel()
model.call = tf.contrib.eager.defun(model.call)
model(x, training=True)  # executes a graph, with dropout
model(x, training=False) # executes a graph, without dropout

# `defun`-compiled functions are differentiable.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
with tf.GradientTape() as tape:
  outputs = model(x)
gradient = tape.gradient(outputs, model.trainable_variables)
optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
                          model.trainable_variables))

When using defun, there are subtleties regarding inputs, Python control flow, and variable creation that one should be aware of. For concreteness, let f be a Python function that returns zero or more tf.Tensor objects and let F = defun(f). F builds a graph for each unique input signature it sees, Python control flow is baked into graphs, and operations related to variable initialization are automatically lifted out of the graphs that F generates and placed in the eager context if executing eagerly or into an outer graph otherwise.

Tracing and Input Signatures. The signature of inputs supplied to F is defined to be a tuple of the shapes and dtypes of Tensor-typed arguments and the values of non-Tensor arguments, where "arguments" includes both args and kwargs. Every time F is invoked, the signature of its inputs is inferred. The first time F(*args, **kwargs) is invoked with a particular signature, f(*args, **kwargs) is executed. All the TensorFlow operations that f executes, along with the Tensors that flow between them, are recorded in a TensorFlow graph. F caches this graph and binds it to the inputs' signature. Every subsequent invocation of F with inputs conforming to this signature immediately retrieves the cached graph and passes it to the TensorFlow runtime for execution.
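The caching behavior described above can be modeled in a few lines of plain Python. This is a hypothetical sketch, not TensorFlow code. `Sym`, `toy_defun`, and the rule that `float` arguments stand in for Tensors are all assumptions made for illustration. Tracing runs the Python body once with symbolic placeholders. Later calls with the same signature only evaluate the recorded arithmetic.

```python
import operator


class Sym:
    """Hypothetical symbolic value: arithmetic records an expression instead of computing."""

    def __init__(self, fn):
        self.fn = fn  # fn(inputs) -> concrete value; evaluated only when the "graph" runs

    def _op(self, other, op):
        # Non-symbolic operands are captured as constants at trace time.
        rhs = other.fn if isinstance(other, Sym) else (lambda inputs, c=other: c)
        return Sym(lambda inputs: op(self.fn(inputs), rhs(inputs)))

    def __add__(self, other):
        return self._op(other, operator.add)

    def __mul__(self, other):
        return self._op(other, operator.mul)


def toy_defun(f):
    """Trace f once per input signature, cache the trace, and replay it afterwards.

    Assumption for this sketch: float arguments play the role of Tensors (only their
    type enters the signature); every other argument enters the signature by value.
    """
    graphs = {}

    def F(*args):
        sig = tuple(("tensor", type(a).__name__) if isinstance(a, float) else ("python", a)
                    for a in args)
        if sig not in graphs:
            F.trace_count += 1
            placeholders = [Sym(lambda inputs, i=i: inputs[i]) if isinstance(a, float) else a
                            for i, a in enumerate(args)]
            graphs[sig] = f(*placeholders)  # runs f's Python body: this is the trace
        return graphs[sig].fn(args)  # replays only the recorded arithmetic

    F.trace_count = 0
    return F


def f(x, scale):
    return x * scale + 1.0


F = toy_defun(f)
print(F(2.0, 3), F(5.0, 3))  # 7.0 16.0 (one trace serves both calls)
print(F.trace_count)         # 1
print(F(2.0, 4))             # 9.0 (new Python value, so a new trace)
print(F.trace_count)         # 2
```

Note that `scale` is baked into each trace as a constant, matching the rule that non-Tensor arguments are part of the signature by value.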

Be aware that because F only logs TensorFlow operations, all the other Python code that f executes will only shape the construction of the graphs that F executes: the Python code won't be executed when the graphs themselves are executed, though it will be executed every time the Python function is traced (and a given Python function might be traced multiple times, once for each input signature it is invoked with). For example, whereas the Python function

import tensorflow as tf
import numpy as np


def add_noise():
  return tf.eye(5) + np.random.randn(5, 5)

will return a different output every time it is invoked, the function compiled = tf.contrib.eager.defun(add_noise) will return the same value every time it is called. This happens because the particular random offset that NumPy generates during tracing is inserted into the graph as a TensorFlow constant. The solution is to replace the call to np.random.randn with tf.random_normal((5, 5)).

Python Side-Effects. A corollary of the previous discussion on tracing is the following: if a Python function f has Python side-effects, then executing f multiple times will not necessarily be semantically equivalent to executing F = tf.contrib.eager.defun(f) multiple times. This difference arises because defun only captures the subgraph of TensorFlow operations that is constructed when f is called in a graph-building context.
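Both effects can be mimicked with ordinary Python closures. In this hypothetical sketch (not how defun is implemented), calling a builder function plays the role of tracing: code before the `return` runs once. The returned function plays the role of the graph that runs on every call. `random.gauss` stands in for np.random.randn, and sampling inside the returned function stands in for tf.random_normal. The builder names are illustrative only.

```python
import random

trace_log = []


def build_add_noise_numpy_style():
    """Hypothetical stand-in for tracing add_noise: the body runs once, and the
    returned function plays the role of the graph that runs on every call."""
    trace_log.append("traced")      # Python side effect: happens only at trace time
    noise = random.gauss(0.0, 1.0)  # like np.random.randn: sampled once, at trace time
    return lambda x: x + noise      # the "graph" keeps the addition; noise is a constant


def build_add_noise_graph_style():
    """Analogue of using tf.random_normal: the sampling is itself part of the graph."""
    return lambda x: x + random.gauss(0.0, 1.0)


random.seed(0)
frozen = build_add_noise_numpy_style()
fresh = build_add_noise_graph_style()
print(frozen(0.0) == frozen(0.0))  # True: the same constant on every call
print(fresh(0.0) == fresh(0.0))    # False: a new sample on every call
print(trace_log)                   # ['traced']: the side effect did not repeat
```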

Python Control Flow. The structure of many machine learning computations depends on whether one is training or validating, and it is common to nest specialized logic under if training: blocks. Because defun maps each input signature to a unique graph, users can transparently compile such code, as the following code snippet demonstrates:

import tensorflow as tf


@tf.contrib.eager.defun
def lossy_matmul(W, x, training=True):
  outputs = tf.matmul(W, x)
  if training:
    outputs = tf.nn.dropout(outputs, keep_prob=0.2)
  return outputs

W = tf.random_normal((3, 5))
x = tf.random_normal((5, 1))

# Executes a graph that applies dropout.
lossy_outputs = lossy_matmul(W, x, training=True)

# Executes a graph that does not apply dropout.
exact_outputs = lossy_matmul(W, x, training=False)

On the other hand, because defun generates graphs by tracing and not by source code analysis, it fully unrolls Python for and while loops, potentially creating large graphs. If your Python function has native loops that run for many iterations, consider replacing them with tf.while_loop operations.
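The following hypothetical sketch shows why a native loop inflates the graph. It is not TensorFlow code: `trace` simply runs a function once and collects the op names it records. The `("While", ...)` entry is an assumed, simplified analogue of a tf.while_loop node, whose body is described once rather than unrolled.

```python
def trace(fn, *python_args):
    """Hypothetical tracer: runs fn once and returns every op it recorded."""
    ops = []
    fn(ops.append, *python_args)
    return ops


def repeated_square_python_loop(record, n):
    # Native Python loop: it runs during tracing, so every iteration records
    # another op and the graph grows linearly with n.
    for _ in range(n):
        record("Square")


def repeated_square_while_loop(record, n):
    # Analogue of tf.while_loop: a single loop op whose body is described once;
    # the trip count is data for the loop op instead of being unrolled.
    record(("While", n, ("Square",)))


print(len(trace(repeated_square_python_loop, 1000)))  # 1000
print(len(trace(repeated_square_while_loop, 1000)))   # 1
```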

When constructing graphs, tf.Tensor objects cannot be used as Python bool objects. This means, for example, that you should replace code in f resembling

if tensor < 10:

with tf.cond(tensor < 10, true_fn, false_fn), where true_fn and false_fn are Python callables that build the operations for each branch.
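This toy model shows why a traced value cannot drive a Python `if`. It also shows what a graph-level conditional does instead. `SymbolicBool` and `toy_cond` are hypothetical names for illustration. In TensorFlow, the error comes from tf.Tensor itself, and tf.cond builds both branches into the graph.

```python
class SymbolicBool:
    """Hypothetical traced comparison result: its value is unknown while tracing."""

    def __init__(self, description):
        self.description = description

    def __bool__(self):
        # Python `if` needs a concrete True/False, which a traced value cannot supply.
        raise TypeError(f"cannot use traced value {self.description!r} as a Python bool")


def toy_cond(pred, true_fn, false_fn):
    """Sketch of graph-level branching: both branch callables are traced, and the
    choice between them is recorded as a single node decided at run time."""
    return ("Cond", pred.description, true_fn(), false_fn())


pred = SymbolicBool("tensor < 10")
try:
    if pred:
        pass
except TypeError as err:
    print(err)  # cannot use traced value 'tensor < 10' as a Python bool

print(toy_cond(pred, lambda: "AddOne", lambda: "SubOne"))
# ('Cond', 'tensor < 10', 'AddOne', 'SubOne')
```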

Variables. TensorFlow operations related to variable creation and initialization are automatically lifted out of the graphs generated by defun. In practice, this means that variable creation and initialization only happen the first time F is called, and that variables are reused every time thereafter. Many TensorFlow APIs, like tf.keras.layers.Layer objects, create variables the first time they are called and reuse them thereafter. Automatic variable lifting makes it possible to compile these APIs without extra effort. The cost is a discrepancy between the semantics of executing Python functions and the semantics of their corresponding compiled functions. For example:

import tensorflow as tf


def fn():
  x = tf.contrib.eager.Variable(0.0)
  x.assign_add(1.0)
  return x.read_value()

# `fn` is a Python function, so x is created, initialized, and destroyed upon
# every invocation
assert fn().numpy() == fn().numpy() == 1.0

compiled = tf.contrib.eager.defun(fn)

# Compiling `fn` with `defun` hoists all variables outside of the generated
# graph, so initialization happens exactly once.
assert compiled().numpy() == 1.0
assert compiled().numpy() == 2.0

Finally, because each input signature is bound to a unique graph, if your Python function constructs tf.contrib.eager.Variable objects, then each graph constructed for that Python function will reference a unique set of variables. To circumvent this problem, we recommend against compiling Python functions that create tf.contrib.eager.Variable objects. Instead, Python functions should either lexically close over tf.contrib.eager.Variable objects or accept them as arguments, preferably encapsulated in an object-oriented container. If you must create variables inside your Python function and you want each graph generated for it to reference the same set of variables, add logic to your Python function that ensures that variables are only created the first time it is called and are reused for every subsequent invocation; note that this is precisely what tf.keras.layers.Layer objects do, so we recommend using them to represent variable-bearing computations whenever possible.
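The create-once pattern recommended above can be sketched in plain Python. `Variable` and `Accumulator` are hypothetical stand-ins, not TensorFlow classes. Under defun, every graph traced for `acc` would call into the same object and therefore reference the same variable, whatever the graph's input signature.

```python
class Variable:
    """Hypothetical mutable cell standing in for tf.contrib.eager.Variable."""

    instances = 0  # counts how many variables were ever created

    def __init__(self, initial_value):
        Variable.instances += 1
        self.value = initial_value

    def assign_add(self, delta):
        self.value += delta
        return self.value


class Accumulator:
    """Creates its variable on the first call only and reuses it afterwards,
    mirroring the create-once behavior described for tf.keras.layers.Layer."""

    def __init__(self):
        self.total = None

    def __call__(self, increment):
        if self.total is None:  # first call: create state
            self.total = Variable(0.0)
        return self.total.assign_add(increment)  # every call: reuse state


acc = Accumulator()
print(acc(1.0))            # 1.0
print(acc(2))              # 3.0 (a different Python value, but the same variable)
print(Variable.instances)  # 1
```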

Args:

  • func: function to be compiled. If func is None, returns a decorator that can be invoked with a single argument, func. The end result is equivalent to providing all the arguments up front. In other words, defun(compiled=True)(func) is equivalent to defun(func, compiled=True). The former allows the following use case:

      @tf.contrib.eager.defun(compiled=True)
      def foo(...):
        ...

  • compiled: If True, an attempt to compile func with XLA will be made. If compilation fails, func will be run normally. Experimental; currently supported only for execution on TPUs. For the vast majority of users, this argument should be False.

Returns:

If func is not None, returns a callable that will execute the compiled function (and return zero or more tf.Tensor objects). If func is None, returns a decorator that, when invoked with a single func argument, returns a callable equivalent to the case above.
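The func=None convention described for the func argument is a common Python idiom for decorators that accept optional keyword arguments. The sketch below shows only that calling convention. `toy_defun` is a hypothetical name, and its wrapper just forwards the call where a real defun would trace and cache a graph.

```python
import functools


def toy_defun(func=None, compiled=False):
    """Hypothetical sketch of the func=None calling convention.

    toy_defun(func, compiled=...) wraps func immediately;
    toy_defun(compiled=...) returns a decorator that receives func later.
    """
    def decorate(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)  # a real defun would trace and cache a graph here
        wrapper.compiled = compiled
        return wrapper

    if func is None:
        return decorate
    return decorate(func)


@toy_defun(compiled=True)
def foo(x):
    return x + 1


def bar(x):
    return x * 2


bar = toy_defun(bar, compiled=True)
print(foo(1), foo.compiled)  # 2 True
print(bar(3), bar.compiled)  # 6 True
```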