tfp.experimental.util.DeferredModule

Wrapper to defer initialization of a tf.Module instance.

DeferredModule is a general-purpose mechanism for creating objects that are 'tape safe', meaning that computation occurs only when an instance method is called, not at construction. This ensures that method calls made inside a tf.GradientTape context will produce gradients with respect to any underlying tf.Variables.

Examples

TFP's built-in Distributions and Bijectors are tape-safe by contract, but this does not extend to cases where computation is required to construct an object's parameters prior to initialization. For example, suppose we want to construct a Gamma distribution with a given mean and variance. In a naive implementation, we would convert these to the Gamma's native concentration and rate parameters when the distribution is constructed. Any future method calls would produce gradients with respect to concentration and rate, but not with respect to the underlying mean and variance:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
dist = tfd.Gamma(concentration=mean**2 / variance,
                 rate=mean / variance)

with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are `[None, None]` !! :-(

To preserve the gradients, we can defer the parameter transformation using DeferredModule. The resulting object behaves just like a tfd.Gamma instance; however, instead of running the Gamma constructor just once, it internally applies the parameter transformation and constructs a new, temporary instance of tfd.Gamma on every method invocation. This ensures that all operations needed to compute a method's return value from any underlying variables are performed every time the method is invoked. A surrounding GradientTape context will therefore be able to trace the full computation.

def gamma_from_mean_and_variance(mean, variance, **kwargs):
  rate = mean / variance
  # Note: mean * rate == mean**2 / variance, the Gamma concentration.
  return tfd.Gamma(concentration=mean * rate, rate=rate, **kwargs)

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
deferred_dist = tfp.experimental.util.DeferredModule(
  build_fn=gamma_from_mean_and_variance,
  mean=mean,  # May be passed by position or by name.
  variance=variance)

with tf.GradientTape() as tape:
  lp = deferred_dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined!

Note that we could have achieved a similar effect by using tfp.util.DeferredTensor to individually defer the concentration and rate parameters. However, this would have been significantly more verbose, and would not share any computation between the two parameter transformations. In general, DeferredTensor is often idiomatic for simple transformations of a single value, while DeferredModule may be preferred for transformations that operate on multiple values and/or contain multiple steps.
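
For comparison, the per-parameter approach might look like the following sketch, using tfp.util.DeferredTensor's also_track argument to track the closed-over variance variable:

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
concentration = tfp.util.DeferredTensor(
  mean, lambda m: m**2 / variance, also_track=variance)
rate = tfp.util.DeferredTensor(
  mean, lambda m: m / variance, also_track=variance)
dist = tfd.Gamma(concentration=concentration, rate=rate)
# Gradients flow to `mean` and `variance`, but the two transformations
# each recompute their division by `variance` separately.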

Caveats

Objects derived from a DeferredModule are no longer deferred, so they will not preserve gradients. For example, slicing into a deferred Distribution yields a new, concrete Distribution instance:

def normal_from_log_scale(scaled_loc, log_scale):
  return tfd.Normal(loc=5 * scaled_loc, scale=tf.exp(log_scale))

dist = tfp.experimental.util.DeferredModule(
  build_fn=normal_from_log_scale,
  scaled_loc=tf.Variable([1., 2., 3.]),
  log_scale=tf.Variable([1., 1., 1.]))
dist.batch_shape  # ==> [3]
len(dist.trainable_variables)  # ==> 2

sliced = dist[:2]  # Instantiates a new, non-deferred Distribution.
sliced.batch_shape  # ==> [2]
len(sliced.trainable_variables)  # ==> 0 (!)

# If needed, we could defer the slice with another layer of wrapping.
deferred_slice = tfp.experimental.util.DeferredModule(
  build_fn=lambda d: d[:2],
  d=dist)
len(deferred_slice.trainable_variables)  # ==> 2

Args

build_fn: Python callable specifying a deferred transformation of the provided arguments. This must have signature module = build_fn(*args, **kwargs). The return value module is an instance of tf.Module.
*args: Optional positional arguments to build_fn.
also_track: Optional instance or structure of instances of tf.Variable and/or tf.Module, containing any additional trainable variables that the build_fn may access beyond the given args and kwargs. This ensures that such variables will be correctly tracked in self.trainable_variables; see the sketch after this list. Default value: None.
**kwargs: Optional keyword arguments to build_fn.
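
To illustrate also_track, here is a minimal hypothetical sketch in which build_fn closes over a variable that is never passed as an argument:

extra_scale = tf.Variable(2.)  # Captured by closure, not passed to build_fn.

def build_normal(loc):
  return tfd.Normal(loc=loc, scale=extra_scale)

dist = tfp.experimental.util.DeferredModule(
  build_fn=build_normal,
  loc=tf.Variable(0.),
  also_track=extra_scale)
len(dist.trainable_variables)  # ==> 2 (both `loc` and `extra_scale`).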

Attributes

name: Returns the name of this module as passed or determined in the constructor.

name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]  # ==> True
list(b.submodules) == [c]  # ==> True
list(c.submodules) == []  # ==> True

trainable_variables: Sequence of trainable variables owned by this module and its submodules.
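
In the deferred Gamma example above, deferred_dist.trainable_variables contains both mean and variance: variables passed as args or kwargs to the constructor are tracked automatically, and additional variables can be tracked via also_track.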

variables: Sequence of variables owned by this module and its submodules.

Methods

with_name_scope

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
# ==> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
# ==> <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method: The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

__abs__

Return the absolute value of the argument.

__add__

Same as a + b.

__and__

Same as a & b.

__bool__

Same as bool(x). Returns True when the argument x is true, False otherwise.

__call__

Forwards the call to the underlying module instance, which is built on demand from the current variable values.
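
For example, calling a deferred Bijector forwards to the freshly built instance; a minimal sketch (scale_bijector is a hypothetical build function):

tfb = tfp.bijectors

def scale_bijector(log_scale):
  return tfb.Scale(tf.exp(log_scale))

deferred_bij = tfp.experimental.util.DeferredModule(
  build_fn=scale_bijector,
  log_scale=tf.Variable(0.))
deferred_bij(2.)  # ==> 2.0; forwards to the built bijector's __call__.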

__contains__

Same as b in a (note reversed operands).

__enter__

Forwards context-manager entry to the underlying module instance.

__eq__

Same as a == b.

__exit__

Forwards context-manager exit to the underlying module instance.

__floordiv__

Same as a // b.

__ge__

Same as a >= b.

__getitem__

Same as a[b].
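
Note that, as described in Caveats above, indexing into a deferred object yields a new, concrete (non-deferred) instance.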

__gt__

Same as a > b.

__invert__

Same as ~a.

__iter__

iter(iterable) -> iterator
iter(callable, sentinel) -> iterator

Get an iterator from an object. In the first form, the argument must supply its own iterator, or be a sequence. In the second form, the callable is called until it returns the sentinel.

__le__

Same as a <= b.

__len__

Return the number of items in a container.

__lshift__

Same as a << b.

__lt__

Same as a < b.

__matmul__

Same as a @ b.

__mod__

Same as a % b.

__mul__

Same as a * b.

__ne__

Same as a != b.

__neg__

Same as -a.

__or__

Same as a | b.

__pos__

Same as +a.

__pow__

Equivalent to base ** exp with 2 arguments, or base ** exp % mod with 3 arguments.

Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.

__radd__

Same as b + a (note reversed operands).

__rand__

Same as b & a (note reversed operands).

__rfloordiv__

Same as b // a (note reversed operands).

__rlshift__

Same as b << a (note reversed operands).

__rmatmul__

Same as b @ a (note reversed operands).

__rmod__

Same as b % a (note reversed operands).

__rmul__

Same as b * a (note reversed operands).

__ror__

Same as b | a (note reversed operands).

__rpow__

Equivalent to base ** exp with 2 arguments (note reversed operands), or base ** exp % mod with 3 arguments.

Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.

__rrshift__

Same as b >> a (note reversed operands).

__rshift__

Same as a >> b.

__rsub__

Same as b - a (note reversed operands).

__rtruediv__

Same as b / a (note reversed operands).

__rxor__

Same as b ^ a (note reversed operands).

__sub__

Same as a - b.

__truediv__

Same as a / b.

__xor__

Same as a ^ b.