Concrete functions

In the guide to AutoGraph and tf.function, you saw how to use tf.function. This guide dives into the details of:

  • tf.function Tracing
  • tf.function Signatures
  • The Concrete functions generated by tracing:
    • How to access them
    • How to use them

These details only become important:

  • If you're experiencing performance issues due to undesired tracing of a tf.function.
  • When you need precise control over the TensorFlow graphs generated by tf.function. For example, when exporting a model to TensorFlow Lite using tf.lite.TFLiteConverter.from_concrete_functions (see the sketch below).
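
For reference, here is a minimal sketch of that TensorFlow Lite export path. It assumes a simple traced function; the example function and the variable names are illustrative only:

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def example(x):
  return x * x

concrete = example.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
tflite_model = converter.convert()  # serialized TFLite flatbuffer bytes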

Background

In TensorFlow 2, eager execution is on by default. TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs. Operations return values instead of constructing a computational graph to run later. Here is a detailed guide on eager execution.

Running imperatively makes development and debugging more interactive, but doesn't allow for easy exporting.

The tf.function API makes it possible to save models as graphs.

Terminology

The following terminology is used in this document:

  • Signature - A description of the inputs and outputs for a set of operations.
  • Polymorphic function - A Python callable that encapsulates several concrete function graphs behind one API.
  • Concrete function - A graph with a single signature.
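
As a quick illustration of these terms (a sketch; the mechanics are covered in detail under Tracing below):

@tf.function          # `double` is a polymorphic function
def double(x):
  return 2 * x

double(tf.constant(1))    # traces a concrete function for int32 scalars
double(tf.constant(1.0))  # traces a second concrete function for float32 scalars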

Setup

import traceback
import textwrap

!pip install -q tf_nightly


import tensorflow as tf

Create a tf.function

Annotating a function with tf.function generates a polymorphic function containing those operations. Operations that are not wrapped in a tf.function are evaluated with eager execution. The example below shows basic tf.function usage.

@tf.function
def square(x):
  return x*x
square(2).numpy()
4

Remember that the Python decorator syntax just calls the decorator with the decorated object as input:

def pow(x,y):
  return x ** y

pow = tf.function(pow)
pow(3,4).numpy()
81

Attach a tf.function method to a tf.Module

A tf.function can optionally be stored as part of a tf.Module object. The tf.Module class provides features for tracking variables and saving checkpoints and models.

Classes like keras.layers.Layer and keras.Model are subclasses of tf.Module.

class Pow(tf.Module):
  def __init__(self, exponent):
    self.exponent = tf.Variable(exponent, dtype = tf.float32, name='Pow/exponent')

  @tf.function
  def __call__(self, x):
    return x ** self.exponent
pow = Pow(3)
pow.variables
(<tf.Variable 'Pow/exponent:0' shape=() dtype=float32, numpy=3.0>,)
pow(tf.constant(2.0)).numpy()
8.0
pow.exponent.assign(4)
pow(tf.constant(2.0)).numpy()
16.0
tf.saved_model.save(pow, 'pow')
INFO:tensorflow:Assets written to: pow/assets

reloaded_pow = tf.saved_model.load('pow')
reloaded_pow(tf.constant(3.0)).numpy()
81.0
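
Because tf.Module tracks the variables, the same instance can also be checkpointed. Here is a minimal sketch using tf.train.Checkpoint (the 'pow_ckpt' prefix is illustrative):

ckpt = tf.train.Checkpoint(model=pow)
path = ckpt.save('pow_ckpt')  # writes the current variable values
ckpt.restore(path)            # reads them back into `pow`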

Assign a tf.function as an attribute

If you assign a tf.Module or a tf.function as an attribute of a module, it will be serialized as well:

mod = tf.Module()
mod.increment_by = tf.Variable(2.0)

@tf.function
def increment(x):
  return x+mod.increment_by

mod.inc = increment
mod.inc(tf.constant(1.0)).numpy()
3.0
mod.cube = Pow(3)
mod.cube(tf.constant(2.0)).numpy()
8.0
mod.variables
(<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
 <tf.Variable 'Pow/exponent:0' shape=() dtype=float32, numpy=3.0>)
tf.saved_model.save(mod, 'mod')
reloaded_mod = tf.saved_model.load('mod')
INFO:tensorflow:Assets written to: mod/assets

reloaded_mod.inc(4.0).numpy()
6.0
reloaded_mod.cube(4.0).numpy()
64.0

Interoperability with tf.keras

Keras classes like keras.Model and keras.layers.Layer are fully compatible with tf.function and tf.Module.

For example, build a simple model:

linear = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
linear.compile(optimizer='adam', loss='mean_squared_error')
linear.fit(x=[-1, 0, 1, 2, 3, 4], y=[-3, -1, 1, 3, 5, 7], epochs=50, verbose=0)
<tensorflow.python.keras.callbacks.History at 0x7feef4269e10>
linear(tf.constant([[1],[2]]))
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-0.08265691],
       [-0.21497664]], dtype=float32)>

Inspect its variables:

linear.variables
[<tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.13231973]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04966282], dtype=float32)>]

Now attach it to a tf.Module:

module = tf.Module()
module.linear = linear

The tf.Module also tracks the tf.Variables:

module.variables
(<tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.13231973]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04966282], dtype=float32)>)

The tf.Module will export the contents of the keras.Model as well:

tf.saved_model.save(module,'module')
INFO:tensorflow:Assets written to: module/assets

reloaded = tf.saved_model.load('module')
reloaded.linear([[1.0]])
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-0.08265691]], dtype=float32)>

Tracing

The objects returned from tf.function are polymorphic functions. They will accept Python objects, or tf.Tensors with any shape or tf.dtype, as input.

In the background, TensorFlow builds tf.Graphs representing the calculation. Each graph is wrapped in a Python callable: a concrete function. Each concrete function can only handle a single input signature.

tf.function traces the Python function each time it needs to create a concrete function. The easiest way to see when a function is traced is to add a call to print:

@tf.function
def mul(a, b):
  print('Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b))
  return a*b

Dtypes and shapes

If you call the polymorphic function with two different types of input, it will trace once for each:

# Trace with ints
mul(tf.constant(2), tf.constant(3)).numpy()
Tracing:
    Tensor("a:0", shape=(), dtype=int32)
    Tensor("b:0", shape=(), dtype=int32)


6
# Trace with floats
mul(tf.constant(2.0), tf.constant(3.0)).numpy()
Tracing:
    Tensor("a:0", shape=(), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


6.0

When you call it again with the same input types, it dispatches to an existing function instead of tracing:

# Call with ints again => no trace
mul(tf.constant(10), tf.constant(10))
<tf.Tensor: shape=(), dtype=int32, numpy=100>

Changing the sizes of the inputs also triggers a trace; setting tf.function(experimental_relax_shapes=True) may reduce this (see the sketch after the examples below):

# Trace with vectors
mul(tf.constant([1.0,3.0]), tf.constant(3.0)).numpy()
Tracing:
    Tensor("a:0", shape=(2,), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


array([3., 9.], dtype=float32)
# Trace with different-sized vectors
mul(tf.constant([1.0,2.0,3.0, 4.0]), tf.constant(3.0))
Tracing:
    Tensor("a:0", shape=(4,), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 3.,  6.,  9., 12.], dtype=float32)>
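
Here is a minimal sketch of the experimental_relax_shapes option. The exact retracing behavior is version-dependent; after seeing several related shapes, TensorFlow may generalize the traced shape and reuse one relaxed trace:

@tf.function(experimental_relax_shapes=True)
def relaxed_mul(a, b):
  print('Tracing "relaxed_mul"')
  return a * b

relaxed_mul(tf.constant([1.0, 2.0]), tf.constant(3.0))       # traces
relaxed_mul(tf.constant([1.0, 2.0, 3.0]), tf.constant(3.0))  # may reuse a relaxed trace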

Immutable Python objects

If you pass an immutable Python object, like an int, str, or tuple, to a tf.function, it runs a trace for each distinct value of those Python objects.

This is useful for controlling what gets included in the tf.Graph (see the AutoGraph guide for more details).

@tf.function
def mul(a, b):
  print('Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b))
  return a*b
# Trace for a=3.0
mul(3.0, tf.constant(3.0)).numpy()
Tracing:
    3.0
    Tensor("b:0", shape=(), dtype=float32)


9.0
# Don't trace for a=3.0 the second time:
mul(3.0, tf.constant(3.0)).numpy()
9.0

This loop traces the function for each unique int:

@tf.function
def power(a,b):
  print('Tracing "power": a={}'.format(a))
  return a**b
p = tf.constant(2)
for n in range(12):
  power(n,p)
Tracing "power": a=0
Tracing "power": a=1
Tracing "power": a=2
Tracing "power": a=3
Tracing "power": a=4
WARNING:tensorflow:5 out of the last 5 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=5
WARNING:tensorflow:6 out of the last 6 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=6
WARNING:tensorflow:7 out of the last 7 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=7
WARNING:tensorflow:8 out of the last 8 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=8
WARNING:tensorflow:9 out of the last 9 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=9
WARNING:tensorflow:10 out of the last 10 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=10
WARNING:tensorflow:11 out of the last 11 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Tracing "power": a=11
WARNING:tensorflow:11 out of the last 11 calls to <function power at 0x7feef4183e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

On the second run each int has been traced, so there's no tracing to do:

p = tf.constant(2)
for n in range(12):
  power(n,p)

To avoid excess retracing, be sure to pass a tf.Tensor instead of Python numbers or strings:

p = tf.constant(2)
for n in tf.range(12):
  power(n,p)
Tracing "power": a=Tensor("a:0", shape=(), dtype=int32)

To rule out retracing entirely, pass an input_signature to the tf.function decorator:

@tf.function(input_signature=(
    tf.TensorSpec(shape=[], dtype=tf.float32),
    tf.TensorSpec(shape=[], dtype=tf.float32),)
)
def power_with_sig(a,b):
  print('Tracing "power_with_sig"')
  return a**b
power_with_sig(3.0, 3.0).numpy()
Tracing "power_with_sig"

27.0
try:
  power_with_sig(tf.constant([1.0,2.0,3.0]),tf.constant(3.0))
  assert False
except ValueError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-45-344551274fb0>", line 2, in <module>
    power_with_sig(tf.constant([1.0,2.0,3.0]),tf.constant(3.0))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32),
    tf.Tensor(3.0, shape=(), dtype=float32))
  input_signature: (
    TensorSpec(shape=(), dtype=tf.float32, name=None),
    TensorSpec(shape=(), dtype=tf.float32, name=None))

Example: Dropout

Retracing for specific values gives you control over what code gets generated by the tf.function.

class Dropout(tf.Module):
  def __init__(self, rate, name=None):
    super(Dropout, self).__init__(name)
    self.rate = tf.Variable(rate, dtype = tf.float32, trainable=False)

  @tf.function
  def __call__(self, x, training=True):
    print(textwrap.dedent("""
                          Tracing "Dropout":
                              training = {}
                              x = {}
                              name = {:s}
                          """.format(training, x, self.name)))
    if training:
      print('    - Train branch\n')
      mask = tf.random.uniform(x.shape) > self.rate
      return x * tf.cast(mask, tf.float32)/self.rate
    else:
      print('    - Test branch\n')
      return x

Create an instance of this simple Dropout layer:

dropout = Dropout(0.5)

The first time you call it with the Python value training=True, it traces the training branch:

dropout(tf.range(10, dtype=tf.float32), training=True).numpy()
WARNING:tensorflow:AutoGraph could not transform <function dedent at 0x7fef8c39aea0> and will run it as-is.
Cause: for/else statement not yet supported
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function dedent at 0x7fef8c39aea0> and will run it as-is.
Cause: for/else statement not yet supported
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch


array([ 0.,  0.,  0.,  6.,  0., 10., 12., 14.,  0., 18.], dtype=float32)

The second time, it doesn't need to re-trace the branch:

dropout(tf.range(10, dtype=tf.float32), training=True).numpy()
array([ 0.,  2.,  0.,  6.,  0.,  0.,  0., 14., 16.,  0.], dtype=float32)

Passing training=False triggers a new trace on the first call, since it is a different Python value:

dropout(tf.range(10, dtype=tf.float32), training=False).numpy()

Tracing "Dropout":
    training = False
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Test branch


array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
dropout(tf.range(10, dtype=tf.float32), training=False).numpy()
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

If you pass a bool tensor, AutoGraph rewrites the if to a tf.cond and the trace captures both branches:

dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

This captures the control flow in a single concrete function.

dropout(tf.range(10, dtype=tf.float32), training=tf.constant(True)).numpy()
array([ 0.,  0.,  4.,  6.,  0.,  0., 12., 14., 16., 18.], dtype=float32)
dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
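
You can confirm that both branches live in a single graph by inspecting the traced operations. This is a sketch; depending on the TensorFlow version, the conditional may appear as an If or StatelessIf op:

concrete = dropout.__call__.get_concrete_function(
    tf.TensorSpec([10], tf.float32), training=tf.TensorSpec([], tf.bool))
print([op.type for op in concrete.graph.get_operations() if 'If' in op.type])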

Other Python objects

Since the generated tf.Graphs cannot contain complex Python objects, these are included by tracing and variable capture.

The tf.function runs a separate trace for each instance, so each trace includes its own variables and can set its behavior based on the instance.

The most common usage is on methods of a Module, Layer, or Model:

dropout_a = Dropout(0.5, name='dropout_a')
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout_a

    - Train branch

[ 0.  0.  4.  0.  8. 10. 12. 14. 16.  0.]
[ 0.  2.  0.  6.  8. 10.  0. 14. 16.  0.]

dropout_b = Dropout(0.5, name='dropout_b')
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout_b

    - Train branch

Warning:tensorflow:5 out of the last 10 calls to <function Dropout.__call__ at 0x7feee855a378> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please  define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
[ 0.  0.  4.  0.  8.  0.  0. 14.  0.  0.]
[ 0.  0.  4.  6.  0. 10.  0.  0.  0. 18.]

But the behavior is the same on a stand-alone tf.function.

@tf.function
def run(callable, x):
  print('Tracing "run":\n    callable = {}\n    x = {}\n'.format(callable, x))
  return callable(x)
def plus_1(x):
  return x+1

print(run(plus_1, tf.constant(2.0)).numpy())
print(run(plus_1, tf.constant(5.0)).numpy())
Tracing "run":
    callable = <function plus_1 at 0x7feee855f158>
    x = Tensor("x:0", shape=(), dtype=float32)

3.0
6.0

Tracing one tf.function can trigger tracing in another:

print(run(dropout, tf.range(10.0)).numpy())
print(run(dropout, tf.range(10.0)).numpy())
Tracing "run":
    callable = <__main__.Dropout object at 0x7feee85a3e10>
    x = Tensor("x:0", shape=(10,), dtype=float32)


Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

[ 0.  2.  0.  6.  0.  0. 12.  0.  0. 18.]
[ 0.  0.  0.  0.  8. 10.  0.  0. 16.  0.]

Weak references

A tf.function only keeps weak references to the variables it captures from the enclosing scope. For example, here's a tf.function that refers to var from the enclosing scope:

@tf.function
def plus_var(x):
  print('Tracing "plus_var":\n    x = {}\n    var = {}\n\n'.format(x, var.name))
  return x + var

Trace the function with one variable:

var = tf.Variable(1, name="IntVar")
plus_var(tf.constant([1,2])).numpy()
Tracing "plus_var":
    x = Tensor("x:0", shape=(2,), dtype=int32)
    var = IntVar:0



array([2, 3], dtype=int32)

And with another variable:

var = tf.Variable(2.0, name="FloatVar")
plus_var(tf.constant([2.0, 10.0])).numpy()
Tracing "plus_var":
    x = Tensor("x:0", shape=(2,), dtype=float32)
    var = FloatVar:0



array([ 4., 12.], dtype=float32)

That worked, but because you no longer have a reference to "IntVar", the variable has been deleted and the first trace is broken:

try:
  plus_var(tf.constant([1,2])).numpy()
  assert False
except tf.errors.FailedPreconditionError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-65-91a1d48c472f>", line 2, in <module>
    plus_var(tf.constant([1,2])).numpy()
tensorflow.python.framework.errors_impl.FailedPreconditionError:  Error while reading resource variable _AnonymousVar18 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar18/N10tensorflow3VarE does not exist.
     [[node add/ReadVariableOp (defined at <ipython-input-62-2bdc80b2c0ef>:4) ]] [Op:__inference_plus_var_2349]

Function call stack:
plus_var
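
To avoid this, keep a strong reference to any variable a trace captures, for example by attaching it to a tf.Module. A minimal sketch with a fresh function:

holder = tf.Module()
holder.var = tf.Variable(1, name="KeptVar")

@tf.function
def plus_kept(x):
  return x + holder.var

plus_kept(tf.constant([1, 2]))  # stays valid as long as `holder` keeps the variable alive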


Accessing concrete functions

In the previous section you saw the conditions that trigger a new trace of a polymorphic tf.function. Each trace generates a new concrete function.

When you save a tf.Module as a SavedModel, it's those concrete functions that define the tf.Graphs that are exported. You don't save the tf.function itself; you save the concrete functions created by tracing.
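
For example, when saving you can pick which concrete functions become serving signatures. A minimal sketch (the module, the lambda, and the 'sig_example' directory are illustrative):

to_save = tf.Module()
to_save.f = tf.function(lambda x: x * x)
sig = to_save.f.get_concrete_function(tf.TensorSpec([], tf.float32))
tf.saved_model.save(to_save, 'sig_example', signatures=sig)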

To get a concrete function from the polymorphic tf.function, you need to define the signature. Either:

  • Pass an input_signature to tf.function, and call the get_concrete_function() method.
  • Pass a list of tf.TensorSpecs to get_concrete_function: tf.TensorSpec(shape=[1], dtype=tf.float32).
  • Pass an example tensor of the correct shape and type to get_concrete_function: tf.constant(1., shape=[1]).

The following example shows how to define the input_signature parameter for tf.function.

Using input_signature

Specify input tensors in the call to tf.function as shown below. This tf.function can only execute on tensors that match the specified signature.

A None in the shape acts as a wildcard, so these tf.TensorSpecs say "a float32 vector of any length".

This pattern can be very important if your tf.function is expected to handle sequences of different lengths, or images of different sizes for each batch (see the Transformer and Deep Dream tutorials for examples).

@tf.function(input_signature=(
    tf.TensorSpec(shape=[None], dtype=tf.float32),
    tf.TensorSpec(shape=[None], dtype=tf.float32),)
)
def power_with_sig(a,b):
  print('Tracing "power_with_sig"\n')
  return a**b

Calling get_concrete_function will execute the trace (if necessary), and return a concrete function.

p = power_with_sig.get_concrete_function()
type(p)
Tracing "power_with_sig"


tensorflow.python.eager.function.ConcreteFunction
p(tf.constant([2.0,3.0,4.0]), tf.constant([5.0,4.0,3.0])).numpy()
array([32., 81., 64.], dtype=float32)

Using get_concrete_function

@tf.function
def power(a,b):
  print('Tracing "power"\n')
  return a**b
float_power = power.get_concrete_function(
  a = tf.TensorSpec(shape=[], dtype=tf.float32),
  b = tf.TensorSpec(shape=[], dtype=tf.float32))
Tracing "power"


float_power(tf.constant(3.0),tf.constant(3.0))
<tf.Tensor: shape=(), dtype=float32, numpy=27.0>

Remember that you can also pass tensors to get_concrete_function; in that case it returns the concrete function that would run for those inputs:

row = tf.range(10)
col = tf.constant([[1],[2],[3]])

concrete_power = power.get_concrete_function(a = row, b = col)
concrete_power(row, col).numpy()
Tracing "power"


array([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
       [  0,   1,   4,   9,  16,  25,  36,  49,  64,  81],
       [  0,   1,   8,  27,  64, 125, 216, 343, 512, 729]], dtype=int32)

Using a concrete function

A concrete function only accepts tensors as input:

float_power(tf.constant(2.0), tf.constant(3.0)).numpy()
8.0
try:
  float_power(2.0,3.0)
  assert False
except (ValueError, TypeError):
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    cancellation_manager)
TypeError: power(a, b): expected argument #0(zero-based) to be a Tensor; got float (2.0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-74-a37a0a54e52a>", line 2, in <module>
    float_power(2.0,3.0)
TypeError: power(a, b) expected a Tensor in a, but got float value 2.0

It also only accepts inputs of the correct dtype:

try:
  float_power(tf.constant(1),tf.constant(3))
  assert False
except tf.errors.InvalidArgumentError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-75-9b54dd5e8642>", line 2, in <module>
    float_power(tf.constant(1),tf.constant(3))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_power_2382 as input #0(zero-based) was expected to be a float tensor but is a int32 tensor [Op:__inference_power_2382]

But it will try to execute even if the input tensors do not match the expected shape:

float_power(tf.constant([1.,2.,3.,4.,5.]),tf.constant(3.)).numpy()
array([  1.,   8.,  27.,  64., 125.], dtype=float32)
try:
  float_power(tf.constant([1.,2.,3.]),tf.constant([4., 5.])).numpy()
  assert False
except tf.errors.InvalidArgumentError:  
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-77-c60087f7e9d3>", line 2, in <module>
    float_power(tf.constant([1.,2.,3.]),tf.constant([4., 5.])).numpy()
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [3] vs. [2]
     [[node pow (defined at <ipython-input-69-7cca90b55f0d>:4) ]] [Op:__inference_power_2382]

Errors may have originated from an input operation.
Input Source operations connected to node pow:
 b (defined at <ipython-input-70-63ea774dd467>:3)

Function call stack:
power


By inspecting the concrete function, you can see its inputs and outputs:

print(float_power.structured_input_signature)
print(float_power.structured_outputs)
((TensorSpec(shape=(), dtype=tf.float32, name='a'), TensorSpec(shape=(), dtype=tf.float32, name='b')), {})
Tensor("Identity:0", shape=(), dtype=float32)

Python objects in signatures

As you saw when tracing, each Python object value generates a new trace. Concrete functions represent a single tf.Graph; they don't do any retracing. When you call get_concrete_function with a Python object as one of the arguments, the object is bound to the function.

cube = power.get_concrete_function(
    a = tf.TensorSpec([], dtype=tf.float32),
    b = 3.0)
Tracing "power"


This cube function no longer has a b argument:

print(cube.structured_input_signature)
((TensorSpec(shape=(), dtype=tf.float32, name='a'), 3.0), {})

cube(tf.constant(10.0)).numpy()
999.99994

This is very similar to the way that standard Python classes bind methods, and it applies equally when you run get_concrete_function on a method:

class Greeter(object):
  def __init__(self, greeting):
    self.greeting = greeting

  def greet(self, who):
    return " ".join([self.greeting, who])

p = Greeter("Hello")
m = p.greet
print(m)
<bound method Greeter.greet of <__main__.Greeter object at 0x7feef400ac18>>

print(m("TensorFlow!"))
Hello TensorFlow!

When you have a tf.function decorating a method, similar rules apply:

class MyModel(tf.Module):
  def __init__(self, ins, outs):
    initializer = tf.initializers.GlorotNormal()
    self.W = tf.Variable(initializer([ins, outs]))
    self.B = tf.Variable(tf.zeros([outs], dtype = tf.float32))

  @tf.function
  def run(self, x):
    print('Tracing "MyModule":\n    x={}\n'.format(x))
    return tf.matmul(x, self.W)+self.B
mod = MyModel(ins=5, outs=3)
mod.run([[1.0,1.0,1.0, 1.0, 1.0]]).numpy()
Tracing "MyModule":
    x=[[1.0, 1.0, 1.0, 1.0, 1.0]]


array([[-0.7157905, -0.7796955, -1.6723082]], dtype=float32)

If you call the method's .get_concrete_function, self is automatically bound as the first argument:

concrete_run = mod.run.get_concrete_function(x = tf.TensorSpec([None, None]))
Tracing "MyModule":
    x=Tensor("x:0", shape=(None, None), dtype=float32)


concrete_run(tf.constant([[1.0,1.0,1.0, 1.0, 1.0],
                          [2.0,2.0,2.0, 2.0, 2.0]])).numpy()
array([[-0.7157905, -0.7796955, -1.6723082],
       [-1.431581 , -1.559391 , -3.3446164]], dtype=float32)

See how self is no longer part of the input signature:

print(concrete_run.structured_input_signature)
print(concrete_run.structured_outputs)
((TensorSpec(shape=(None, None), dtype=tf.float32, name='x'),), {})
Tensor("Identity:0", shape=(None, 3), dtype=float32)

Changes for TensorFlow 2.3

The following changes are currently available in TensorFlow nightly, and will be available in TensorFlow 2.3.

Inspecting concrete function signatures

Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type:

print(float_power)
ConcreteFunction power(a, b)
  Args:
    a: float32 Tensor, shape=()
    b: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

For polymorphic functions, the pretty_printed_concrete_signatures() method can be used to display a list of all signatures that have been traced so far:

print(power.pretty_printed_concrete_signatures())
power(a, b)
  Args:
    a: int32 Tensor, shape=(10,)
    b: int32 Tensor, shape=(3, 1)
  Returns:
    int32 Tensor, shape=(3, 10)

power(a, b)
  Args:
    a: float32 Tensor, shape=()
    b: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

power(a, b=3.0)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

Bound arguments

As discussed above, when a concrete function is created, any arguments that are set to non-Tensor values are bound to those values. Prior to TensorFlow 2.3, these bound arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, they remain present in the signature, but have their default value set to the bound value. For example, the argument b has its default value set to 2 when we create the following concrete function:

square = power.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
Tracing "power"

ConcreteFunction power(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

When calling square, you may either omit b, or explicitly set it to the value that was used during tracing:

print(square(tf.constant(5.0)))  # ok: use default value
print(square(tf.constant(5.0), b=2))  # ok: explicit value == bound value
tf.Tensor(25.0, shape=(), dtype=float32)
tf.Tensor(25.0, shape=(), dtype=float32)

Setting a bound argument to any value other than its default is an error:

try:
  print(square(tf.constant(10.0), b=4.0).numpy())  # error: explicit value != bound value
  assert False
except TypeError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    cancellation_manager)
TypeError: power(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-94-39e32e913430>", line 2, in <module>
    print(square(tf.constant(10.0), b=4.0).numpy())  # error: explicit value != bound value
TypeError: ConcreteFunction power(a, b) was constructed with int value 2 in b, but was called with float value 4.0

CompositeTensor arguments

Starting with TensorFlow 2.3, you can define concrete functions that accept CompositeTensors (such as tf.RaggedTensor, tf.data.Dataset iterators, and tf.SparseTensor). E.g.:

ragged_power = power.get_concrete_function(
    a = tf.RaggedTensorSpec([None, None], dtype=tf.float32),
    b = tf.TensorSpec([], dtype=tf.float32)
)
print(ragged_power(tf.ragged.constant([[1.0, 2.0], [3.0]]),
                   tf.constant(2.0)))
Tracing "power"

<tf.RaggedTensor [[1.0, 4.0], [9.0]]>

Nested arguments

Starting with TensorFlow 2.3, you can define concrete functions that accept nested dicts, lists, tuples, namedtuples, and attrs of tensors (or composite tensors). E.g.:

@tf.function
def sum_features(features):
  assert isinstance(features, dict)
  return sum(t for t in features.values())

features = {
    'a': tf.constant([[1.0], [2.0], [8.0]]),
    'b': tf.constant([[5.0], [6.0], [0.0]]),
    'c': tf.ragged.constant([[10.0], [20.0, 30.0, 40.0], [50.0]])
}
print(sum_features(features))

concrete_sum_features = sum_features.get_concrete_function(features)
print(concrete_sum_features(features))
<tf.RaggedTensor [[16.0], [28.0, 38.0, 48.0], [58.0]]>
<tf.RaggedTensor [[16.0], [28.0, 38.0, 48.0], [58.0]]>

The value passed to a concrete function must have the same nesting structure that was used when the concrete function was traced. If you pass in a different structure, then a TypeError is raised, with information about the structure that was accepted and the structure that was used:

try:
  concrete_sum_features({'a': tf.constant([[1.0]])})
  assert False
except TypeError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1756, in _call_with_flat_signature
    args.append(kwargs.pop(compat.as_str(keyword)))
KeyError: 'features_1'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    cancellation_manager)
TypeError: sum_features(features, features_1, features_2, features_3) missing required arguments: features_1, features_2, features_3

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-97-2ec957efcc80>", line 2, in <module>
    concrete_sum_features({'a': tf.constant([[1.0]])})
TypeError: sum_features(features): argument features had incorrect type
  expected: {'a': Tensor, 'b': Tensor, 'c': RaggedTensor}
       got: {'a': EagerTensor}