tfp.experimental.nn.util.CallOnce


Function object which memoizes the result of create_value_fn().

This object is used to memoize the computation of some function. On the first call, the user-provided create_value_fn is called with the args/kwargs passed to this object's __call__. On subsequent calls, the previous result is returned regardless of the args/kwargs passed to __call__. To trigger a new evaluation, call this object's reset() method; to check whether the next call will evaluate create_value_fn (on demand), call is_unset(). For an example application of this object, see help(tfp.experimental.nn.util.RandomVariable) and/or help(tfp.util.DeferredTensor).
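The behavior described above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch that assumes only the semantics stated in that paragraph; CallOnceSketch and the _UNSET sentinel are hypothetical names, and this is not TFP's actual implementation (the real class also exposes module attributes such as name and variables).

```python
# A framework-free sketch of the memoization semantics described above.
# CallOnceSketch is a hypothetical name; it is NOT the TFP implementation.

_UNSET = object()  # sentinel distinguishing "no value yet" from a memoized None


class CallOnceSketch:
    """Memoizes the result of create_value_fn on the first call."""

    def __init__(self, create_value_fn):
        self._create_value_fn = create_value_fn
        self._value = _UNSET

    def __call__(self, *args, **kwargs):
        # Only the first call (or the first call after reset) evaluates the
        # callable; later calls ignore their arguments entirely.
        if self.is_unset():
            self._value = self._create_value_fn(*args, **kwargs)
        return self._value

    def is_unset(self):
        return self._value is _UNSET

    def reset(self):
        # Drop the memoized value so the next call re-evaluates.
        self._value = _UNSET


calls = []


def expensive(x, scale=1):
    calls.append((x, scale))  # record each real evaluation
    return x * scale


memo = CallOnceSketch(expensive)
print(memo.is_unset())   # True
print(memo(3, scale=2))  # 6   (evaluated)
print(memo(100))         # 6   (memoized; the new argument is ignored)
memo.reset()
print(memo.is_unset())   # True
print(memo(100))         # 100 (re-evaluated after reset)
print(len(calls))        # 2
```

Using a private sentinel rather than None means that a create_value_fn which legitimately returns None is still memoized.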

Args
create_value_fn Python callable which takes any input args/kwargs and returns a value to memoize. (The value is not presumed to be of any particular type.)

Attributes
create_value_fn The callable passed to the constructor.

name Returns the name of this module as passed or determined in the ctor.

name_scope Returns a tf.name_scope instance for this class.

submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

trainable_variables Sequence of trainable variables owned by this module and its submodules.

value

variables Sequence of variables owned by this module and its submodules.

Methods

is_unset


Returns True if there is no memoized value and False otherwise.

reset


Removes the memoized value, so the next call re-evaluates create_value_fn.
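The class description names tfp.experimental.nn.util.RandomVariable as one application. As an illustration in that spirit only (it does not describe RandomVariable's internals), the sketch below memoizes a random draw so that repeated reads agree until reset() is called. It uses only Python's random module; CallOnceSketch is a hypothetical stand-in that re-creates the documented memoize/reset/is_unset behavior.

```python
import random

_UNSET = object()  # sentinel meaning "nothing memoized yet"


class CallOnceSketch:
    """Hypothetical stand-in for CallOnce with memoize/reset/is_unset."""

    def __init__(self, create_value_fn):
        self._create_value_fn = create_value_fn
        self._value = _UNSET

    def __call__(self, *args, **kwargs):
        # Evaluate only when nothing is memoized; otherwise ignore the args.
        if self.is_unset():
            self._value = self._create_value_fn(*args, **kwargs)
        return self._value

    def is_unset(self):
        return self._value is _UNSET

    def reset(self):
        self._value = _UNSET


rng = random.Random(0)               # seeded so the run is reproducible
sample = CallOnceSketch(rng.random)  # memoize one uniform draw in [0, 1)

first = sample()
assert sample() == first  # repeated reads see the same draw
sample.reset()            # discard the draw; the next read samples again
assert sample.is_unset()
second = sample()
print(first, second)      # two different draws
```

The design choice being illustrated is that sampling happens lazily on read, and reset() is the single explicit point where a fresh sample is requested.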

with_name_scope

Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

__call__


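Returns the memoized value. If no value is memoized, create_value_fn is first called with the given args/kwargs and its result is memoized; otherwise the args/kwargs are ignored.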