

Automagically generate CompositeTensor behavior for cls.

CompositeTensor objects can pass into and out of tf.function and tf.while_loop, or serve as part of the signature of a TF SavedModel.

The contract of auto_composite_tensor is that all init args and kwargs must have corresponding public or private attributes (or properties); for example, an init arg x is read from x or _x. Each of these attributes is inspected (recursively) to determine whether it is (or contains) Tensors or non-Tensor metadata. Attributes of type list and tuple are supported, but each must contain either only Tensors (or lists, etc., thereof) or no Tensors at all. For example:

  • object.attribute = [1., 2., 'abc'] # valid
  • object.attribute = [tf.constant(1.), [tf.constant(2.)]] # valid
  • object.attribute = ['abc', tf.constant(1.)] # invalid
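The all-or-nothing rule for lists and tuples can be sketched in plain Python. This is not the library's implementation: FakeTensor is a hypothetical stand-in for tf.Tensor, and classify is a hypothetical helper that labels a (possibly nested) attribute value as Tensor data or non-Tensor metadata, rejecting mixed containers like the invalid example above.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor (assumption: no TF dependency)."""

    def __init__(self, value: float) -> None:
        self.value = value


def classify(value: Any) -> str:
    """Return 'tensor' or 'metadata'; raise ValueError for mixed containers."""
    if isinstance(value, FakeTensor):
        return 'tensor'
    if isinstance(value, (list, tuple)):
        # Recursively classify every element; nested lists are allowed.
        kinds = {classify(v) for v in value}
        if len(kinds) > 1:
            raise ValueError(f'Mixed Tensor/non-Tensor container: {value!r}')
        # An empty container carries no Tensors, so treat it as metadata.
        return kinds.pop() if kinds else 'metadata'
    return 'metadata'


print(classify([1., 2., 'abc']))                     # metadata
print(classify([FakeTensor(1.), [FakeTensor(2.)]]))  # tensor
try:
    classify(['abc', FakeTensor(1.)])
except ValueError as e:
    print('invalid:', e)
```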

If the object has a _composite_tensor_shape_parameters field (presumed to hold a tuple of str), the flattening code uses tf.get_static_value to try to preserve the values of the fields named in that tuple as static metadata. Preserving static values can be important for correctly propagating shapes through a loop.
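The effect of _composite_tensor_shape_parameters can be illustrated with a plain-Python sketch. Assumptions: FakeTensor stands in for tf.Tensor and carries an optional static_value (playing the role of what tf.get_static_value would return, or None when the value is only known at run time); flatten and Zeros are hypothetical names, not library code. Only fields listed in _composite_tensor_shape_parameters are moved into static metadata, even if other fields also have known values.

```python
from typing import Any, Dict, Optional, Tuple


class FakeTensor:
    """Hypothetical tf.Tensor stand-in; static_value is None if unknown."""

    def __init__(self, static_value: Optional[Any] = None) -> None:
        self.static_value = static_value


def get_static_value(t: FakeTensor) -> Optional[Any]:
    # Plays the role of tf.get_static_value in this sketch.
    return t.static_value


def flatten(obj: Any, fields: Tuple[str, ...]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split obj's fields into (tensor components, static metadata)."""
    shape_params = getattr(obj, '_composite_tensor_shape_parameters', ())
    components: Dict[str, Any] = {}
    metadata: Dict[str, Any] = {}
    for name in fields:
        value = getattr(obj, name)
        if isinstance(value, FakeTensor):
            static = get_static_value(value) if name in shape_params else None
            if static is not None:
                metadata[name] = static   # known now: keep as static metadata
            else:
                components[name] = value  # stays a dynamic Tensor component
        else:
            metadata[name] = value        # non-Tensor values are metadata
    return components, metadata


class Zeros:
    """Hypothetical class that marks 'shape' as a shape parameter."""
    _composite_tensor_shape_parameters = ('shape',)

    def __init__(self, shape: FakeTensor, fill: FakeTensor) -> None:
        self.shape = shape
        self.fill = fill


components, metadata = flatten(
    Zeros(FakeTensor((2, 3)), FakeTensor(0.)), ('shape', 'fill'))
print(sorted(components), metadata)  # ['fill'] {'shape': (2, 3)}
```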

If the decorated class A does not subclass CompositeTensor, a new class will be generated, which mixes in A and CompositeTensor.

To avoid this extra class in the class hierarchy, we suggest inheriting from auto_composite_tensor.AutoCompositeTensor, which inherits from CompositeTensor and implants a trivial _type_spec @property. The @auto_composite_tensor decorator then overwrites this trivial _type_spec @property. The trivial property is necessary because _type_spec is an abstract property of CompositeTensor: ABCMeta records a class's abstract members when the class is created, before the decorator executes, so without the trivial _type_spec property the class would remain abstract and ABCMeta would raise an error on instantiation. The user may thus do any of the following:

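Recommended: inherit from AutoCompositeTensor

```python
@tfp.experimental.auto_composite_tensor
class MyClass(tfp.experimental.AutoCompositeTensor):
  ...

mc = MyClass()
type(mc)
# ==> MyClass
```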

No CompositeTensor base class (ok, but changes expected types)

```python
@tfp.experimental.auto_composite_tensor
class MyClass(object):
  ...

mc = MyClass()
type(mc)
# ==> MyClass_AutoCompositeTensor
```
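One way such a mixed-in class could be generated is with Python's three-argument type(). This is a sketch, not the library's code: CompositeTensor below is a hypothetical empty stand-in for TF's base class, and make_composite is a hypothetical decorator; only the _AutoCompositeTensor name suffix is taken from the example output above.

```python
class CompositeTensor:
    """Hypothetical stand-in for TF's CompositeTensor base class."""


def make_composite(cls: type) -> type:
    """Return cls unchanged if it already subclasses CompositeTensor,
    otherwise a new class that mixes in cls and CompositeTensor."""
    if issubclass(cls, CompositeTensor):
        return cls
    return type(f'{cls.__name__}_AutoCompositeTensor',
                (cls, CompositeTensor),
                {'__module__': cls.__module__})


@make_composite
class MyClass(object):
    pass


mc = MyClass()
print(type(mc).__name__)  # MyClass_AutoCompositeTensor
```

Because the decorator rebinds MyClass to the generated subclass, code that checks type(mc) against the original class name sees a different type, which is the "changes expected types" caveat in the heading.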

CompositeTensor base class, requiring trivial _type_spec

```python
from tensorflow.python.framework import composite_tensor

@tfp.experimental.auto_composite_tensor
class MyClass(composite_tensor.CompositeTensor):
  @property
  def _type_spec(self):  # will be overwritten by @auto_composite_tensor
    pass
  ...

mc = MyClass()
type(mc)
# ==> MyClass
```
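The ABCMeta error described above can be reproduced with the standard abc module. In this sketch Base is a hypothetical stand-in for CompositeTensor with an abstract _type_spec property, and overwrite_spec is a hypothetical decorator that replaces _type_spec after the class statement has run. Assigning a concrete property at that point does not clear the abstract flag ABCMeta already recorded, while defining the trivial property up front does.

```python
import abc


class Base(abc.ABC):
    """Hypothetical stand-in for CompositeTensor."""

    @property
    @abc.abstractmethod
    def _type_spec(self):
        ...


def overwrite_spec(cls):
    """Hypothetical decorator: replace _type_spec after class creation."""
    cls._type_spec = property(lambda self: f'spec for {type(self).__name__}')
    return cls


@overwrite_spec
class WithoutTrivial(Base):
    pass


@overwrite_spec
class WithTrivial(Base):
    @property
    def _type_spec(self):  # trivial; overwritten by @overwrite_spec
        pass


try:
    WithoutTrivial()  # still abstract: ABCMeta recorded _type_spec earlier
except TypeError as e:
    print('TypeError:', e)

print(WithTrivial()._type_spec)  # spec for WithTrivial
```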

Full usage example

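```python
import tensorflow as tf
import tensorflow_probability as tfp

# 'name' is a plain string; omit_kwargs keeps it out of the spec.
@tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
class Adder(tfp.experimental.AutoCompositeTensor):
  def __init__(self, x, y, name=None):
    with tf.name_scope(name or 'Adder') as name:
      self._x = tf.convert_to_tensor(x)
      self._y = tf.convert_to_tensor(y)
      self._name = name

  def xpy(self):
    return self._x + self._y

def body(obj):
  # Each iteration returns a new Adder whose x is the previous x + y.
  return Adder(obj.xpy(), 1.),

result, = tf.while_loop(
    cond=lambda _: True,
    body=body,
    loop_vars=(Adder(1., 1.),),
    maximum_iterations=3)

result.xpy()  # => 5.
```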
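To see where 5. comes from, the loop can be traced without TensorFlow. This sketch uses a hypothetical plain-Python PyAdder class and a hypothetical while_loop helper that mimics the control flow of tf.while_loop's cond, body, loop_vars and maximum_iterations arguments; it assumes the loop runs three iterations, as the expected result implies.

```python
from typing import Callable, Tuple


class PyAdder:
    """Hypothetical plain-Python analogue of Adder (floats, not Tensors)."""

    def __init__(self, x: float, y: float) -> None:
        self._x = x
        self._y = y

    def xpy(self) -> float:
        return self._x + self._y


def while_loop(cond: Callable, body: Callable, loop_vars: Tuple,
               maximum_iterations: int) -> Tuple:
    """Stop when cond is False or after maximum_iterations calls to body."""
    i = 0
    while i < maximum_iterations and cond(*loop_vars):
        loop_vars = body(*loop_vars)
        i += 1
    return loop_vars


def body(obj: PyAdder) -> Tuple[PyAdder]:
    return (PyAdder(obj.xpy(), 1.),)


# Iterations: (1, 1) -> (2, 1) -> (3, 1) -> (4, 1); then xpy() = 4 + 1.
result, = while_loop(cond=lambda _: True, body=body,
                     loop_vars=(PyAdder(1., 1.),), maximum_iterations=3)
print(result.xpy())  # 5.0
```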

Args:
  cls: The class for which to create a CompositeTensor subclass.
  omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.

Returns:
  composite_tensor_subclass: A subclass of cls and TF CompositeTensor.
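The contract that every init arg maps to a same-named public or private attribute, together with omit_kwargs, can be sketched with inspect.signature. extract_init_kwargs is a hypothetical helper, not the library's function: it looks up name first, then _name, and skips any names listed in omit_kwargs. The Adder here is a plain-Python stand-in without Tensors.

```python
import inspect
from typing import Any, Dict, Sequence


def extract_init_kwargs(obj: Any, omit_kwargs: Sequence[str] = ()) -> Dict[str, Any]:
    """Map each __init__ parameter to obj.<name> or obj._<name>."""
    params = inspect.signature(type(obj).__init__).parameters
    kwargs: Dict[str, Any] = {}
    for name in params:
        if name == 'self' or name in omit_kwargs:
            continue
        if hasattr(obj, name):
            kwargs[name] = getattr(obj, name)         # public attribute
        elif hasattr(obj, '_' + name):
            kwargs[name] = getattr(obj, '_' + name)   # private attribute
        else:
            raise ValueError(f'No attribute for init arg {name!r}')
    return kwargs


class Adder:
    """Plain-Python stand-in for the Adder example (no Tensors)."""

    def __init__(self, x, y, name=None):
        self._x, self._y, self._name = x, y, name or 'Adder'


a = Adder(1., 2.)
print(extract_init_kwargs(a))                         # {'x': 1.0, 'y': 2.0, 'name': 'Adder'}
print(extract_init_kwargs(a, omit_kwargs=('name',)))  # {'x': 1.0, 'y': 2.0}
```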