tfp.bijectors.MaskedAutoregressiveFlow

Affine MaskedAutoregressiveFlow bijector.

Inherits From: Bijector

The affine autoregressive flow [(Papamakarios et al., 2017)][3] provides a relatively simple framework for user-specified (deep) architectures to learn a distribution over continuous events. Regarding terminology:

'Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density by an invertible transformation with tractable Jacobian.' [(Papamakarios et al., 2017)][3]

In other words, the 'autoregressive property' is equivalent to the decomposition p(x) = prod{ p(x[perm[i]] | x[perm[0:i]]) : i=0, ..., d }, where x has d + 1 elements and perm is some permutation of {0, ..., d}. When perm is the identity, this reduces to p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }.
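A minimal sketch of the decomposition, assuming a hypothetical toy model in which each conditional is Gaussian with unit variance and mean 0.5 * sum(previous elements). The joint log-density is the sum of the conditional log-densities, taken in the order given by perm (identity by default). Plain Python is used so the example runs without TensorFlow; `conditional_params` is illustrative and not part of TFP.

```python
import math


def normal_logpdf(v, mean, std):
    # Log-density of a univariate Gaussian N(mean, std**2) evaluated at v.
    z = (v - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def conditional_params(prefix):
    # Hypothetical conditional: x[perm[i]] | x[perm[0:i]] ~ N(0.5 * sum(prefix), 1).
    return 0.5 * sum(prefix), 1.0


def autoregressive_log_prob(x, perm=None):
    # log p(x) = sum_i log p(x[perm[i]] | x[perm[0:i]]).
    perm = list(range(len(x))) if perm is None else perm
    total = 0.0
    for i in range(len(x)):
        prefix = [x[j] for j in perm[:i]]
        mean, std = conditional_params(prefix)
        total += normal_logpdf(x[perm[i]], mean, std)
    return total


x = [0.3, -1.2, 0.7]
print(autoregressive_log_prob(x))                  # identity ordering
print(autoregressive_log_prob(x, perm=[2, 0, 1]))  # a different ordering (a different model)
```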

In TensorFlow Probability, 'normalizing flows' are implemented as tfp.bijectors.Bijectors. The forward 'autoregression' is implemented using a tf.while_loop and a deep neural network (DNN) with masked weights such that the autoregressive property is automatically met in the inverse.

A TransformedDistribution using MaskedAutoregressiveFlow(...) uses the (expensive) forward-mode calculation to draw samples and the (cheap) reverse-mode calculation to compute log-probabilities. Conversely, a TransformedDistribution using Invert(MaskedAutoregressiveFlow(...)) uses the (expensive) forward-mode calculation to compute log-probabilities and the (cheap) reverse-mode calculation to compute samples. See the Examples below for more details.
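The cost asymmetry comes from how many times the shift-and-log-scale network has to run. The plain-Python sketch below counts those calls; `toy_shift_and_log_scale` is a hypothetical stand-in for a masked network (entry i reads only y[0:i]), not TFP code. Sampling from a MaskedAutoregressiveFlow corresponds to forward (one call per event element), and log_prob corresponds to inverse (a single call).

```python
import math

calls = {"n": 0}


def toy_shift_and_log_scale(y):
    # Hypothetical autoregressive network: entry i depends only on y[0:i].
    calls["n"] += 1
    shift = [0.5 * sum(y[:i]) for i in range(len(y))]
    log_scale = [0.1 * sum(y[:i]) for i in range(len(y))]
    return shift, log_scale


def forward(x):
    # Sequential: one network call per event element.
    y = [0.0] * len(x)
    for _ in range(len(x)):
        shift, log_scale = toy_shift_and_log_scale(y)
        y = [xi * math.exp(ls) + s for xi, s, ls in zip(x, shift, log_scale)]
    return y


def inverse(y):
    # Parallel: a single network call on the fully known y.
    shift, log_scale = toy_shift_and_log_scale(y)
    return [(yi - s) / math.exp(ls) for yi, s, ls in zip(y, shift, log_scale)]


x = [0.3, -1.2, 0.7, 2.0]
calls["n"] = 0
y = forward(x)
print("forward calls:", calls["n"])  # prints: forward calls: 4
calls["n"] = 0
inverse(y)
print("inverse calls:", calls["n"])  # prints: inverse calls: 1
```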

Given a shift_and_log_scale_fn, the forward and inverse transformations are (a sequence of) affine transformations. A 'valid' shift_and_log_scale_fn must compute each shift (aka loc or 'mu' in [Germain et al. (2015)][1]) and log(scale) (aka 'alpha' in [Germain et al. (2015)][1]) such that each is broadcastable with the arguments to forward and inverse, i.e., such that the calculations in forward and inverse (below) are possible.

For convenience, tfp.bijectors.AutoregressiveNetwork is offered as a possible shift_and_log_scale_fn function. It implements the MADE architecture [(Germain et al., 2015)][1]. MADE is a feed-forward network that computes a shift and log(scale) using masked dense layers in a deep neural network. Weights are masked to ensure the autoregressive property. It is possible that this architecture is suboptimal for your task. To build alternative networks, either change the arguments to tfp.bijectors.AutoregressiveNetwork or use some other architecture, e.g., using tf.keras.layers.
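To make the masking idea concrete, here is a plain-Python sketch of the MADE degree rule from Germain et al. (2015), for a single hidden layer and the identity ordering: a hidden unit with degree m may read input d only if m >= d, and output k may read a hidden unit only if k > m. The helper names and the hand-picked hidden degrees are assumptions for illustration; tfb.AutoregressiveNetwork builds its own masks. The final check shows that output k depends only on inputs before k.

```python
def made_masks(num_inputs, hidden_degrees):
    # Inputs (and outputs) carry degrees 1..D under the identity ordering.
    degrees = list(range(1, num_inputs + 1))
    # A hidden unit with degree m may read input d only if m >= d.
    hidden_mask = [[1 if m >= d else 0 for d in degrees] for m in hidden_degrees]
    # Output k may read a hidden unit with degree m only if k > m (strict).
    output_mask = [[1 if k > m else 0 for m in hidden_degrees] for k in degrees]
    return hidden_mask, output_mask


def output_dependencies(hidden_mask, output_mask):
    # Output k depends on input d iff an unmasked path input -> hidden -> output exists.
    num_inputs = len(hidden_mask[0])
    deps = []
    for out_row in output_mask:
        reachable = set()
        for h, allowed in enumerate(out_row):
            if allowed:
                reachable.update(d for d in range(num_inputs) if hidden_mask[h][d])
        deps.append(sorted(reachable))
    return deps


hidden_mask, output_mask = made_masks(3, hidden_degrees=[1, 2, 1, 2])
# 0-based input indices that each output can see.
print(output_dependencies(hidden_mask, output_mask))  # [[], [0], [0, 1]]
```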

Assuming shift_and_log_scale_fn has valid shape and autoregressive semantics, the forward transformation is

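def forward(x):
  # Start from zeros; each pass fixes one more element of y.
  y = tf.zeros_like(x)
  event_size = x.shape[-event_ndims:].num_elements()
  for _ in range(event_size):
    # shift and log_scale for an element depend only on earlier elements of y.
    shift, log_scale = shift_and_log_scale_fn(y)
    y = x * tf.exp(log_scale) + shift
  return y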

and the inverse transformation is

def inverse(y):
  shift, log_scale = shift_and_log_scale_fn(y)
  return (y - shift) / tf.exp(log_scale)

Notice that the inverse does not need a for-loop. In the forward pass, each calculation of shift and log_scale is based on the y calculated so far (not on x). In the inverse, y is fully known, so a single call to shift_and_log_scale_fn returns the same shift and log_scale that forward used on its final (event_size-th) pass, i.e., the values computed from the 'last' y. (Roughly speaking, this also proves the transform is bijective.)
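The argument can be checked numerically. The sketch below, in plain Python with a hypothetical `shift_and_log_scale_fn` whose entry i reads only y[0:i], runs the sequential forward pass and undoes it with a single inverse pass. Because y[i] depends on x[i] only through the factor exp(log_scale[i]), and otherwise only on earlier elements, the Jacobian dy/dx is lower triangular and its log-determinant is sum(log_scale); the sketch computes it that way.

```python
import math


def shift_and_log_scale_fn(y):
    # Hypothetical autoregressive network: entry i depends only on y[0:i].
    shift = [math.sin(sum(y[:i])) for i in range(len(y))]
    log_scale = [0.3 * math.tanh(sum(y[:i])) for i in range(len(y))]
    return shift, log_scale


def forward(x):
    # event_size sequential passes; pass k fixes y[k].
    y = [0.0] * len(x)
    for _ in range(len(x)):
        shift, log_scale = shift_and_log_scale_fn(y)
        y = [xi * math.exp(ls) + s for xi, s, ls in zip(x, shift, log_scale)]
    return y


def inverse(y):
    # One pass: y is fully known, so shift/log_scale match forward's last pass.
    shift, log_scale = shift_and_log_scale_fn(y)
    return [(yi - s) / math.exp(ls) for yi, s, ls in zip(y, shift, log_scale)]


def forward_log_det_jacobian(x):
    # dy/dx is lower triangular with diagonal exp(log_scale).
    _, log_scale = shift_and_log_scale_fn(forward(x))
    return sum(log_scale)


x = [0.3, -1.2, 0.7]
y = forward(x)
print(y, inverse(y), forward_log_det_jacobian(x))
```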

The bijector_fn argument allows specifying a more general coupling relation, such as the LSTM-inspired activation from [4], or Neural Spline Flow [5]. It must logically operate on each element of the input individually, and still obey the 'autoregressive property' described above. The forward transformation is

def forward(x):
  y = tf.zeros_like(x)
  event_size = x.shape[-event_ndims:].num_elements()
  for _ in range(event_size):
    # The bijector for each element is parameterized by earlier elements of y.
    bijector = bijector_fn(y)
    y = bijector.forward(x)
  return y

and the inverse transformation is

def inverse(y):
  bijector = bijector_fn(y)
  return bijector.inverse(y)
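A plain-Python sketch of the same contract with a non-affine element-wise map. `ToyElementwiseBijector` (y = sinh((asinh(x) + skew) * tail), with tail > 0) and `bijector_fn` are hypothetical stand-ins chosen because the map has a closed-form inverse; the only requirement being illustrated is that the parameters for element i depend on y[0:i] alone.

```python
import math


class ToyElementwiseBijector:
    # Hypothetical element-wise bijector: y = sinh((asinh(x) + skew) * tail).
    def __init__(self, skew, tail):
        self.skew, self.tail = skew, tail

    def forward(self, x):
        return [math.sinh((math.asinh(xi) + s) * t)
                for xi, s, t in zip(x, self.skew, self.tail)]

    def inverse(self, y):
        return [math.sinh(math.asinh(yi) / t - s)
                for yi, s, t in zip(y, self.skew, self.tail)]


def bijector_fn(y):
    # Parameters for element i depend only on y[0:i] (autoregressive).
    skew = [0.2 * sum(y[:i]) for i in range(len(y))]
    tail = [math.exp(0.1 * math.tanh(sum(y[:i]))) for i in range(len(y))]
    return ToyElementwiseBijector(skew, tail)


def forward(x):
    y = [0.0] * len(x)
    for _ in range(len(x)):
        y = bijector_fn(y).forward(x)
    return y


def inverse(y):
    return bijector_fn(y).inverse(y)


x = [0.5, -0.4, 1.1]
print(inverse(forward(x)))  # recovers x up to floating-point error
```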

Examples

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

dims = 2

# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution.  (However, any continuous distribution would work.) Here, we
# use `tfd.Sample` to create a joint Gaussian distribution with diagonal
# covariance for the base distribution (note that in the Gaussian case,
# `tfd.MultivariateNormalDiag` could also be used.)
maf = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Normal(loc=0., scale=1.), sample_shape=[dims]),
    bijector=tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
            params=2, hidden_units=[512, 512])))

x = maf.sample()  # Expensive; uses `tf.while_loop`, no Bijector caching.
maf.log_prob(x)   # Almost free; uses Bijector caching.
# Cheap; no `tf.while_loop` despite no Bijector caching.
maf.log_prob(tf.zeros(dims))

# [Papamakarios et al. (2017)][3] also describe an Inverse Autoregressive
# Flow [(Kingma et al., 2016)][2]:
iaf = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Normal(loc=0., scale=1.), sample_shape=[dims]),
    bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
            params=2, hidden_units=[512, 512]))))

x = iaf.sample()  # Cheap; no `tf.while_loop` despite no Bijector caching.
iaf.log_prob(x)   # Almost free; uses Bijector caching.
# Expensive; uses `tf.while_loop`, no Bijector caching.
iaf.log_prob(tf.zeros(dims))

# In many (if not most) cases the default `shift_and_log_scale_fn` will be a
# poor choice.  Here's an example of using a 'shift only' version and with a
# different number/depth of hidden layers.
made = tfb.AutoregressiveNetwork(params=1, hidden_units=[32])
maf_no_scale_hidden2 = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Normal(loc=0., scale=1.), sample_shape=[dims]),
    bijector=tfb.MaskedAutoregressiveFlow(
        lambda y: (made(y)[..., 0], None),
        is_constant_jacobian=True))
maf_no_scale_hidden2._made = made  # Track the variables of `made`.
# NOTE: The last line ensures that maf_no_scale_hidden2.trainable_variables
# will include all variables from `made`.

Variable Tracking

A tfb.MaskedAutoregressiveFlow instance saves a reference to the values passed as shift_and_log_scale_fn and bijector_fn to its constructor. Thus, for most values passed as shift_and_log_scale_fn or bijector_fn, variables referenced by those values will be found and tracked by the tfb.MaskedAutoregressiveFlow instance. Please see the tf.Module documentation for further details.

However, if the value passed to shift_and_log_scale_fn or bijector_fn is a Python function, then tfb.MaskedAutoregressiveFlow cannot automatically track variables used inside shift_and_log_scale_fn or bijector_fn. To get tfb.MaskedAutoregressiveFlow to track such variables, either:

  1. Replace the Python function with a tf.Module, tf.keras.Layer, or other callable object through which tf.Module can find variables.

  2. Or, add a reference to the networks (or their variables) to the tfb.MaskedAutoregressiveFlow instance by setting an attribute, for example:

    ````
    made1 = tfb.AutoregressiveNetwork(params=1, hidden_units=[10, 10])
    made2 = tfb.AutoregressiveNetwork(params=1, hidden_units=[10, 10])
    # params=1 outputs have shape [..., event_size, 1]; drop the last axis.
    maf = tfb.MaskedAutoregressiveFlow(
        lambda y: (made1(y)[..., 0], made2(y)[..., 0] + 1.))
    # Reference the layers: their variables only exist after the first call.
    maf._made_networks = [made1, made2]
    ````

    References

    [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: Masked Autoencoder for Distribution Estimation. In International Conference on Machine Learning, 2015. https://arxiv.org/abs/1502.03509

    [2]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling. Improving Variational Inference with Inverse Autoregressive Flow. In Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.04934

    [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked Autoregressive Flow for Density Estimation. In Neural Information Processing Systems, 2017. https://arxiv.org/abs/1705.07057

    [4]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling. Improving Variational Inference with Inverse Autoregressive Flow. In Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.04934 (same paper as [2])

    [5]: Conor Durkan, Artur Bekasov, Iain Murray, and George Papamakarios. Neural Spline Flows. In Neural Information Processing Systems, 2019. https://arxiv.org/abs/1906.04032

    Args
    shift_and_log_scale_fn Python callable which computes shift and log_scale from the inverse domain (y). Calculation must respect the 'autoregressive property' (see class docstring). Suggested default: tfb.AutoregressiveNetwork(params=2, hidden_units=...). Typically the function contains tf.Variables. Returning None for either or both of shift and log_scale is equivalent to (but more efficient than) returning zero. If shift_and_log_scale_fn returns a single Tensor, the returned value will be unstacked to get the shift and log_scale: tf.unstack(shift_and_log_scale_fn(y), num=2, axis=-1).
    bijector_fn Python callable which returns a tfb.Bijector which transforms event tensors, with the signature (input, **condition_kwargs) -> bijector. The bijector must operate on scalar events and must not alter the rank of its input. The bijector_fn will be called with Tensors from the inverse domain (y). Calculation must respect the 'autoregressive property' (see class docstring).
    is_constant_jacobian Python bool. Default: False. When True, the implementation assumes log_scale does not depend on the forward domain (x) or inverse domain (y) values. (No validation is made; is_constant_jacobian=False is always safe but possibly computationally inefficient.)
    validate_args Python bool indicating whether arguments should be checked for correctness.
    unroll_loop Python bool indicating whether the tf.while_loop in _forward should be replaced with a static for loop. Requires that the final dimension of x be known at graph construction time. Defaults to False.
    event_ndims Python integer, the intrinsic dimensionality of this bijector. 1 corresponds to a simple vector autoregressive bijector as implemented by the tfp.bijectors.AutoregressiveNetwork, 2 might be useful for a 2D convolutional shift_and_log_scale_fn and so on.
    name Python str, name given to ops managed by this object.
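The return conventions of shift_and_log_scale_fn can be mimicked in plain Python. `normalize_shift_and_log_scale` below is a hypothetical helper, not a TFP function: it unstacks a single [event_size, 2] output along its last axis (the role tf.unstack(..., num=2, axis=-1) plays above) and substitutes zeros for None, which matches the documented result but not the efficiency gain.

```python
import math


def normalize_shift_and_log_scale(out, event_size):
    # Hypothetical helper mirroring the documented return conventions.
    if isinstance(out, tuple):
        shift, log_scale = out
    else:
        # A single "tensor" of shape [event_size, 2]: unstack the last axis.
        shift = [pair[0] for pair in out]
        log_scale = [pair[1] for pair in out]
    # None is documented as equivalent to zero.
    shift = [0.0] * event_size if shift is None else shift
    log_scale = [0.0] * event_size if log_scale is None else log_scale
    return shift, log_scale


def affine(x, shift, log_scale):
    # y = x * exp(log_scale) + shift, element-wise.
    return [xi * math.exp(ls) + s for xi, s, ls in zip(x, shift, log_scale)]


x = [1.0, 2.0]
stacked = [[0.5, 0.0], [-1.0, math.log(2.0)]]  # [shift, log_scale] per element
print(affine(x, *normalize_shift_and_log_scale(stacked, 2)))
print(affine(x, *normalize_shift_and_log_scale(([0.5, -1.0], None), 2)))
```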

    Raises
    ValueError If both or neither of shift_and_log_scale_fn and bijector_fn are specified.

    Attributes

    dtype

    forward_min_event_ndims Returns the minimal number of dimensions bijector.forward operates on.

    Multipart bijectors return structured ndims, which indicates the expected structure of their inputs. Some multipart bijectors, notably Composites, may return structures of None.

    graph_parents Returns this Bijector's graph_parents as a Python list.
    inverse_min_event_ndims Returns the minimal number of dimensions bijector.inverse operates on.

    Multipart bijectors return structured event_ndims, which indicates the expected structure of their outputs. Some multipart bijectors, notably Composites, may return structures of None.

    is_constant_jacobian Returns true iff the Jacobian matrix is not a function of x.

    name Returns the string name of this Bijector.
    name_scope Returns a tf.name_scope instance for this class.
    non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
    parameters Dictionary of parameters used to instantiate this Bijector.
    submodules Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    trainable_variables Sequence of trainable variables owned by this module and its submodules.

    validate_args Returns True if Tensor arguments will be validated.
    variables Sequence of variables owned by this module and its submodules.

    Methods

    copy

    Creates a copy of the bijector.

    Args
    **override_parameters_kwargs String/value dictionary of initialization arguments to override with new values.

    Returns
    bijector A new instance of type(self) initialized from the union of self.parameters and override_parameters_kwargs, i.e., dict(self.parameters, **override_parameters_kwargs).
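Since dict(self.parameters, **override_parameters_kwargs) lets later keys win, any override replaces the stored value and everything else is carried over. A minimal sketch with a hypothetical `ToyBijector` class (not the TFP implementation):

```python
class ToyBijector:
    # Hypothetical stand-in that records its constructor arguments.
    def __init__(self, scale=1.0, validate_args=False, name="toy"):
        self.parameters = dict(scale=scale, validate_args=validate_args, name=name)

    def copy(self, **override_parameters_kwargs):
        # Union of the stored parameters and the overrides; overrides win.
        merged = dict(self.parameters, **override_parameters_kwargs)
        return type(self)(**merged)


b = ToyBijector(scale=2.0)
c = b.copy(name="toy_copy")
print(c.parameters)  # {'scale': 2.0, 'validate_args': False, 'name': 'toy_copy'}
```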

    experimental_batch_shape

    Returns the batch shape of this bijector for inputs of the given rank.

    The batch shape of a bijector describes the set of distinct transformations it represents on events of a given size. For example: the bijector tfb.Scale([1., 2.]) has batch shape [2] for scalar events (event_ndims = 0), because applying it to a scalar event produces two scalar outputs, the result of two different scaling transformations. The same bijector has batch shape [] for vector events, because applying it to a vector produces (via elementwise multiplication) a single vector output.

    Bijectors that operate independently on multiple state parts, such as tfb.JointMap, must broadcast to a coherent batch shape. Some events may not be valid: for example, the bijector tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])]) does not produce a valid batch shape when event_ndims = [0, 0], since the batch shapes of the two parts are inconsistent. The same bijector does define valid batch shapes of [], [2], and [3] if event_ndims is [1, 1], [0, 1], or [1, 0], respectively.

    Since transforming a single event produces a scalar log-det-Jacobian, the batch shape of a bijector with non-constant Jacobian is expected to equal the shape of forward_log_det_jacobian(x, event_ndims=x_event_ndims) or inverse_log_det_jacobian(y, event_ndims=y_event_ndims), for x or y of the specified ndims.
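The Scale and JointMap numbers above can be reproduced with a small shape calculation. This sketch assumes, for illustration, that an element-wise bijector's batch shape for events of rank k is its parameter shape with the trailing k dimensions removed, and that the parts of a JointMap combine by NumPy-style broadcasting; `scale_batch_shape`, `broadcast_shapes`, and `joint_map_batch_shape` are hypothetical helpers, not TFP functions.

```python
def scale_batch_shape(param_shape, event_ndims):
    # Trailing event_ndims parameter dims are absorbed into each event.
    keep = max(len(param_shape) - event_ndims, 0)
    return list(param_shape[:keep])


def broadcast_shapes(a, b):
    # NumPy-style broadcasting; raises if the shapes are incompatible.
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible batch shapes {a} and {b}")
        result.append(db if da == 1 else da)
    return result[::-1]


def joint_map_batch_shape(param_shapes, event_ndims):
    # Broadcast the per-part batch shapes together.
    shape = []
    for p, k in zip(param_shapes, event_ndims):
        shape = broadcast_shapes(shape, scale_batch_shape(p, k))
    return shape


# Parameter shapes of tfb.Scale([1., 2.]) and tfb.Scale([1., 2., 3.]).
parts = [[2], [3]]
for nd in ([1, 1], [0, 1], [1, 0]):
    print(nd, joint_map_batch_shape(parts, nd))  # [], [2], [3]
try:
    joint_map_batch_shape(parts, [0, 0])
except ValueError as err:
    print([0, 0], "->", err)
```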

    Args
    x_event_ndims Optional Python int (structure) number of dimensions in a probabilistic event passed to forward; this must be greater than or equal to self.forward_min_event_ndims. If None, defaults to self.forward_min_event_ndims. Mutually exclusive with y_event_ndims. Default value: None.
    y_event_ndims Optional Python int (structure) number of dimensions in a probabilistic event passed to inverse; this must be greater than or equal to self.inverse_min_event_ndims. Mutually exclusive with x_event_ndims. Default value: None.

    Returns
    batch_shape TensorShape batch shape of this bijector for a value with the given event rank. May be unknown or partially defined.

    experimental_batch_shape_tensor

    Returns the batch shape of this bijector for inputs of the given rank.

    The batch shape of a bijector describes the set of distinct transformations it represents on events of a given size. For example: the bijector tfb.Scale([1., 2.]) has batch shape [2] for scalar events (event_ndims = 0), because applying it to a scalar event produces two scalar outputs, the result of two different scaling transformations. The same bijector has batch shape [] for vector events, because applying it to a vector produces (via elementwise multiplication) a single vector output.

    Bijectors that operate independently on multiple state parts, such as tfb.JointMap, must broadcast to a coherent batch shape. Some events may not be valid: for example, the bijector tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])]) does not produce a valid batch shape when event_ndims = [0, 0], since the batch shapes of the two parts are inconsistent. The same bijector does define valid batch shapes of [], [2], and [3] if event_ndims is [1, 1], [0, 1], or [1, 0], respectively.

    Since transforming a single event produces a scalar log-det-Jacobian, the batch shape of a bijector with non-constant Jacobian is expected to equal the shape of forward_log_det_jacobian(x, event_ndims=x_event_ndims) or inverse_log_det_jacobian(y, event_ndims=y_event_ndims), for x or y of the specified ndims.

    Args
    x_event_ndims Optional Python int (structure) number of dimensions in a probabilistic event passed to forward; this must be greater than or equal to self.forward_min_event_ndims. If None, defaults to self.forward_min_event_ndims. Mutually exclusive with y_event_ndims. Default value: None.
    y_event_ndims Optional Python int (structure) number of dimensions in a probabilistic event passed to inverse; this must be greater than or equal to self.inverse_min_event_ndims. Mutually exclusive with x_event_ndims. Default value: None.

    Returns
    batch_shape_tensor integer Tensor batch shape of this bijector for a value with the given event rank.

    experimental_compute_density_correction

    Density correction for this transformation wrt the tangent space, at x.

    Subclasses of Bijector may call the most specific applicable method of TangentSpace, based on whether the transformation is dimension-preserving, coordinate-wise, a projection, or something more general. The backward-compatible assumption is that the transformation is dimension-preserving (goes from R^n to R^n).

    Args
    x Tensor (structure). The point at which to calculate the density.
    tangent_space TangentSpace or one of its subclasses. The tangent to the support manifold at x.
    backward_compat bool specifying whether to assume that the Bijector is dimension-preserving.
    **kwargs Optional keyword arguments forwarded to tangent space methods.

    Returns
    density_correction Tensor representing the density correction (in log space) under the transformation that this Bijector denotes.

    Raises
    TypeError if backward_compat is False but no method of TangentSpace has been called explicitly.

    forward

    Returns the forward Bijector evaluation, i.e., Y = g(X).

    Args
    x Tensor (structure). The input to the 'forward' evaluation.
    name The name to give this op.
    **kwargs Named arguments forwarded to subclass implementation.

    Returns
    Tensor (structure).

    Raises
    TypeError if self.dtype is specified and x.dtype is not self.dtype.
    NotImplementedError if _forward is not implemented.

    forward_dtype

    Returns the dtype returned by forward for the provided input.

    forward_event_ndims

    Returns the number of event dimensions produced by forward.

    Args
    event_ndims Structure of Python and/or Tensor ints, and/or None values. The structure should match that of self.forward_min_event_ndims, and all non-None values must be greater than or equal to the corresponding value in self.forward_min_event_ndims.
    **kwargs Optional keyword arguments forwarded to nested bijectors.

    Returns
    forward_event_ndims Structure of integers and/or None values matching self.inverse_min_event_ndims. These are computed using 'prefer static' semantics: if any inputs are None, some or all of the outputs may be None, indicating that the output dimension could not be inferred (conversely, if all inputs are non-None, all outputs will be non-None). If all input event_ndims are Python ints, all of the (non-None) outputs will be Python ints; otherwise, some or all of the outputs may be Tensor ints.

    forward_event_shape

    Shape of a single sample from a single batch as a TensorShape.

    Same meaning as forward_event_shape_tensor. May be only partially defined.

    Args
    input_shape TensorShape (structure) indicating event-portion shape passed into forward function.

    Returns
    forward_event_shape TensorShape (structure) indicating event-portion shape after applying forward. Possibly unknown.

    forward_event_shape_tensor

    Shape of a single sample from a single batch as an int32 1D Tensor.

    Args
    input_shape Tensor, int32 vector (structure) indicating event-portion shape passed into forward function.
    name name to give to the op

    Returns
    forward_event_shape_tensor Tensor, int32 vector (structure) indicating event-portion shape after applying forward.

    forward_log_det_jacobian

    Returns the (log o det o Jacobian o forward)(x), i.e., log(det(dY/dX))(X), where Y = g(X).

    Args
    x Tensor (structure). The input to the 'forward' Jacobian determinant evaluation.
    event_ndims Optional number of dimensions in the probabilistic events being transformed; this must be greater than or equal to self.forward_min_event_ndims. If event_ndims is specified, the log Jacobian determinant is summed to produce a scalar log-determinant for each event. Otherwise (if event_ndims is None), no reduction is performed. Multipart bijectors require structured event_ndims, such that the batch rank rank(x[i]) - event_ndims[i] is the same for all elements i of the structured input. In most cases (with the exception of tfb.JointMap) they further require that event_ndims[i] - self.forward_min_event_ndims[i] is the same for all elements i of the structured input. Default value: None (equivalent to self.forward_min_event_ndims).
    name The name to give this op.
    **kwargs Named arguments forwarded to subclass implementation.

    Returns
    Tensor (structure), if this bijector is injective. If not injective this is not implemented.

    Raises
    TypeError if x's dtype is incompatible with the expected input dtype.
    NotImplementedError if neither _forward_log_det_jacobian nor {_inverse, _inverse_log_det_jacobian} are implemented, or this is a non-injective bijector.
    ValueError if the value of event_ndims is not valid for this bijector.
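To illustrate the event_ndims reduction, take y = exp(x), whose element-wise log-det-Jacobian is simply x. With event_ndims=0 nothing is reduced; with event_ndims=1 the values are summed over the last axis, giving one log-determinant per event. `exp_fldj` is a hypothetical plain-Python helper for a batch of vectors, not TFP code.

```python
def exp_fldj(x_batch, event_ndims):
    # Element-wise log|dy/dx| for y = exp(x) is x itself.
    elementwise = [list(row) for row in x_batch]
    if event_ndims == 0:
        return elementwise  # shape [batch, n]: no reduction
    if event_ndims == 1:
        return [sum(row) for row in elementwise]  # shape [batch]: one value per event
    raise ValueError("this sketch only handles event_ndims in {0, 1}")


x = [[0.5, -1.0, 2.0], [0.0, 1.5, -0.5]]
print(exp_fldj(x, 0))  # [[0.5, -1.0, 2.0], [0.0, 1.5, -0.5]]
print(exp_fldj(x, 1))  # [1.5, 1.0]
```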

    inverse

    Returns the inverse Bijector evaluation, i.e., X = g^{-1}(Y).

    Args
    y Tensor (structure). The input to the 'inverse' evaluation.
    name The name to give this op.
    **kwargs Named arguments forwarded to subclass implementation.

    Returns
    Tensor (structure), if this bijector is injective. If not injective, returns the k-tuple containing the unique k points (x1, ..., xk) such that g(xi) = y.

    Raises
    TypeError if y's structured dtype is incompatible with the expected output dtype.
    NotImplementedError if _inverse is not implemented.

    inverse_dtype

    Returns the dtype returned by inverse for the provided input.

    inverse_event_ndims

    Returns the number of event dimensions produced by inverse.

    Args
    event_ndims Structure of Python and/or Tensor ints, and/or None values. The structure should match that of self.inverse_min_event_ndims, and all non-None values must be greater than or equal to the corresponding value in self.inverse_min_event_ndims.
    **kwargs Optional keyword arguments forwarded to nested bijectors.

    Returns
    inverse_event_ndims Structure of integers and/or None values matching self.forward_min_event_ndims. These are computed using 'prefer static' semantics: if any inputs are None, some or all of the outputs may be None, indicating that the output dimension could not be inferred (conversely, if all inputs are non-None, all outputs will be non-None). If all input event_ndims are Python ints, all of the (non-None) outputs will be Python ints; otherwise, some or all of the outputs may be Tensor ints.

    inverse_event_shape

    Shape of a single sample from a single batch as a TensorShape.

    Same meaning as inverse_event_shape_tensor. May be only partially defined.

    Args
    output_shape TensorShape (structure) indicating event-portion shape passed into inverse function.

    Returns
    inverse_event_shape TensorShape (structure) indicating event-portion shape after applying inverse. Possibly unknown.

    inverse_event_shape_tensor

    Shape of a single sample from a single batch as an int32 1D Tensor.

    Args
    output_shape Tensor, int32 vector (structure) indicating event-portion shape passed into inverse function.
    name name to give to the op

    Returns
    inverse_event_shape_tensor Tensor, int32 vector (structure) indicating event-portion shape after applying inverse.

    inverse_log_det_jacobian

    Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: log(det(dX/dY))(Y). (Recall that: X=g^{-1}(Y).)

    Note that forward_log_det_jacobian is the negative of this function, evaluated at g^{-1}(y).
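The relation follows from the inverse function theorem: the Jacobian of g^{-1} at y is the matrix inverse of the Jacobian of g at x = g^{-1}(y), so the log-determinants differ only in sign. A quick scalar check with g = exp, using only the standard library (the `fldj` and `ildj` functions here are written by hand for this example):

```python
import math


def fldj(x):
    # y = exp(x): dy/dx = exp(x), so log|dy/dx| = x.
    return x


def ildj(y):
    # x = log(y): dx/dy = 1/y, so log|dx/dy| = -log(y).
    return -math.log(y)


for y in [0.25, 1.0, 7.5]:
    x = math.log(y)  # g^{-1}(y)
    assert abs(ildj(y) + fldj(x)) < 1e-12
print("ildj(y) == -fldj(g^{-1}(y)) holds at the sampled points")
```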

    Args
    y Tensor (structure). The input to the 'inverse' Jacobian determinant evaluation.
    event_ndims Optional number of dimensions in the probabilistic events being transformed; this must be greater than or equal to self.inverse_min_event_ndims. If event_ndims is specified, the log Jacobian determinant is summed to produce a scalar log-determinant for each event. Otherwise (if event_ndims is None), no reduction is performed. Multipart bijectors require structured event_ndims, such that the batch rank rank(y[i]) - event_ndims[i] is the same for all elements i of the structured input. In most cases (with the exception of tfb.JointMap) they further require that event_ndims[i] - self.inverse_min_event_ndims[i] is the same for all elements i of the structured input. Default value: None (equivalent to self.inverse_min_event_ndims).
    name The name to give this op.
    **kwargs Named arguments forwarded to subclass implementation.

    Returns
    ildj Tensor, if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, log(det(Dg_i^{-1}(y))), where g_i is the restriction of g to the ith partition Di.

    Raises
    TypeError if y's dtype is incompatible with the expected inverse-dtype.
    NotImplementedError if _inverse_log_det_jacobian is not implemented.
    ValueError if the value of event_ndims is not valid for this bijector.

    parameter_properties

    Returns a dict mapping constructor arg names to property annotations.

    This dict should include an entry for each of the bijector's Tensor-valued constructor arguments.

    Args
    dtype Optional float dtype to assume for continuous-valued parameters. Some constraining bijectors require advance knowledge of the dtype because certain constants (e.g., tfb.Softplus.low) must be instantiated with the same dtype as the values to be transformed.

    Returns
    parameter_properties A str -> tfp.python.internal.parameter_properties.ParameterProperties dict mapping constructor argument names to ParameterProperties instances.

    with_name_scope

    Decorator to automatically enter the module name scope.

    class MyModule(tf.Module):
      @tf.Module.with_name_scope
      def __call__(self, x):
        if not hasattr(self, 'w'):
          self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
        return tf.matmul(x, self.w)
    

    Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

    Args
    method The method to wrap.

    Returns
    The original method wrapped such that it enters the module's name scope.

    __call__

    Applies or composes the Bijector, depending on input type.

    This is a convenience function which applies the Bijector instance in three different ways, depending on the input:

    1. If the input is a tfd.Distribution instance, return tfd.TransformedDistribution(distribution=input, bijector=self).
    2. If the input is a tfb.Bijector instance, return tfb.Chain([self, input]).
    3. Otherwise, return self.forward(input)

    Args
    value A tfd.Distribution, tfb.Bijector, or a (structure of) Tensor.
    name Python str name given to ops created by this function.
    **kwargs Additional keyword arguments passed into the created tfd.TransformedDistribution, tfb.Bijector, or self.forward.

    Returns
    composition A tfd.TransformedDistribution if the input was a tfd.Distribution, a tfb.Chain if the input was a tfb.Bijector, or a (structure of) Tensor computed by self.forward.

    Examples

    sigmoid = tfb.Reciprocal()(
        tfb.Shift(shift=1.)(
          tfb.Exp()(
            tfb.Scale(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.Shift(shift=1.),
    #         tfb.Exp(),
    #         tfb.Scale(scale=-1.),
    #      ])`  # ie, `tfb.Sigmoid()`
    
    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`
    
    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
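The sigmoid example works because tfb.Chain applies its bijectors from last to first, so the composition is x -> -x -> exp(-x) -> 1 + exp(-x) -> 1 / (1 + exp(-x)). A plain-Python check of that ordering, with each bijector replaced by an ordinary function and a hypothetical `chain` helper standing in for tfb.Chain:

```python
import math
from functools import reduce


def chain(*fns):
    # Like tfb.Chain([f1, ..., fn]) applied forward: fn runs first, f1 runs last.
    return lambda x: reduce(lambda acc, f: f(acc), reversed(fns), x)


sigmoid = chain(
    lambda v: 1.0 / v,   # Reciprocal
    lambda v: v + 1.0,   # Shift(shift=1.)
    math.exp,            # Exp
    lambda v: -v,        # Scale(scale=-1.)
)

for x in [-2.0, 0.0, 3.0]:
    assert abs(sigmoid(x) - 1.0 / (1.0 + math.exp(-x))) < 1e-12
print(sigmoid(0.0))  # 0.5
```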
    

__eq__

Return self==value.

__getitem__

__iter__
