
Builds a joint variational posterior with a given event_shape.

This function builds a surrogate posterior by applying a trainable transformation to a standard base distribution and constraining the samples with bijector. The surrogate posterior has event shape equal to the input event_shape.
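The construction can be sketched in plain Python for a two-dimensional posterior. This sketch assumes a standard normal base, a hand-picked lower-triangular scale (the "tril" case), and an Exp bijector as the constraint. The names LOC, SCALE_TRIL, affine, sample and log_prob are hypothetical. This is not TFP's implementation, which builds these pieces from trainable variables and tfb.Bijector objects. The sketch only shows the two steps: an affine map of base samples, followed by a constraining bijector whose log-Jacobian enters the density.

```python
import math
import random

# Hypothetical values for the trainable parameters (chosen for illustration).
LOC = (0.5, -0.2)
SCALE_TRIL = ((0.3, 0.0),
              (0.1, 0.2))  # lower-triangular with a positive diagonal


def affine(z, loc, scale_tril):
    """Returns x = loc + L @ z, using only the lower triangle of L."""
    return [loc[i] + sum(scale_tril[i][j] * z[j] for j in range(i + 1))
            for i in range(len(loc))]


def sample(rng, loc=LOC, scale_tril=SCALE_TRIL):
    """Draws z ~ N(0, I), applies the affine map, then the Exp bijector."""
    z = [rng.gauss(0.0, 1.0) for _ in loc]
    x = affine(z, loc, scale_tril)
    return [math.exp(xi) for xi in x]  # maps R^n onto the positive orthant


def log_prob(y, loc=LOC, scale_tril=SCALE_TRIL):
    """Surrogate log density of y by the change-of-variables formula."""
    n = len(loc)
    x = [math.log(yi) for yi in y]  # inverse of the Exp bijector
    z = []
    for i in range(n):  # forward substitution solves L z = x - loc
        resid = x[i] - loc[i] - sum(scale_tril[i][j] * z[j] for j in range(i))
        z.append(resid / scale_tril[i][i])
    log_base = -0.5 * sum(zi * zi for zi in z) - 0.5 * n * math.log(2 * math.pi)
    log_det_affine = sum(math.log(scale_tril[i][i]) for i in range(n))
    log_det_exp = sum(x)  # log|det d exp(x)/dx| = sum of x
    return log_base - log_det_affine - log_det_exp
```

Setting the off-diagonal entry of SCALE_TRIL to zero gives the mean-field ("diag") case. The implied covariance of x is L L^T in either case.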

This function is a convenience wrapper around build_affine_surrogate_posterior_from_base_distribution that allows the user to pass in the desired posterior event_shape instead of pre-constructed base distributions (at the expense of full control over the base distribution types and parameterizations).

event_shape (Nested) event shape of the posterior.
operators Either a string or a list/tuple containing LinearOperator subclasses, LinearOperator instances, or callables returning LinearOperator instances. Supported string values are "diag" (to create a mean-field surrogate posterior) and "tril" (to create a full-covariance surrogate posterior). A list/tuple may be passed to induce other posterior covariance structures. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance will be created and applied to the base distribution. Otherwise the list must be singly-nested and have a first element of length 1, second element of length 2, etc.; the elements of the outer list are interpreted as rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is created. For complete documentation and examples, see, which receives the operators arg if it is list-like. Default value: "diag".
bijector tfb.Bijector instance, or nested structure of tfb.Bijector instances, that maps (nested) values in R^n to the support of the posterior. (This can be the experimental_default_event_space_bijector of the distribution over the prior latent variables.) Default value: None (i.e., the posterior is over R^n).
base_distribution A tfd.Distribution subclass parameterized by loc and scale. The base distribution of the transformed surrogate has loc=0 and scale=1. Default value: tfd.Normal.
dtype The dtype of the surrogate posterior. Default value: tf.float32.
batch_shape Batch shape (Python tuple, list, or int) of the surrogate posterior, to enable parallel optimization from multiple initializations. Default value: ().
seed Python integer to seed the random number generator for initial values. Default value: None.
validate_args Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_affine_surrogate_posterior').
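The list/tuple rule for operators can be illustrated with two dense-matrix helpers in plain Python. The first places a flat list of square blocks on the diagonal, which is the structure tf.linalg.LinearOperatorBlockDiag represents. The second takes a singly-nested list whose i-th row holds i + 1 blocks, which is the structure tf.linalg.LinearOperatorBlockLowerTriangular represents. The helpers block_diag and block_lower_triangular are hypothetical and work on fixed numbers. In TFP the blocks would be trainable LinearOperators, and a dense matrix need not be formed at all.

```python
def block_diag(blocks):
    """Dense block-diagonal matrix from a flat list of square blocks."""
    n = sum(len(b) for b in blocks)
    out = [[0.0] * n for _ in range(n)]
    offset = 0
    for b in blocks:
        k = len(b)
        for i in range(k):
            for j in range(k):
                out[offset + i][offset + j] = b[i][j]
        offset += k
    return out


def block_lower_triangular(rows):
    """Dense matrix from a singly-nested list whose i-th row holds i + 1 blocks.

    rows[i][j] is the block in block-row i, block-column j (j <= i). The last
    block of each row is the square diagonal block that fixes the row's size.
    """
    for i, row in enumerate(rows):
        if len(row) != i + 1:
            raise ValueError(f"row {i} must have {i + 1} blocks, got {len(row)}")
    sizes = [len(row[-1]) for row in rows]
    starts = [sum(sizes[:i]) for i in range(len(sizes))]
    n = sum(sizes)
    out = [[0.0] * n for _ in range(n)]
    for i, row in enumerate(rows):
        for j, block in enumerate(row):
            # Off-diagonal block (i, j) has shape sizes[i] x sizes[j].
            for r in range(sizes[i]):
                for c in range(sizes[j]):
                    out[starts[i] + r][starts[j] + c] = block[r][c]
    return out
```

A flat list such as [a, b] gives independent groups of latent dimensions. The nested form [[a], [off, b]] adds the block off, which couples the second group to the first.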

surrogate_distribution Trainable tfd.Distribution with event shape equal to event_shape.


import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Define a joint probabilistic model.
Root = tfd.JointDistributionCoroutine.Root
def model_fn():
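  # Explicit names let `experimental_pin(y=...)` below refer to the observed
  # component; unnamed components would otherwise receive default names.
  concentration = yield Root(tfd.Exponential(1., name='concentration'))
  rate = yield Root(tfd.Exponential(1., name='rate'))
  y = yield tfd.Sample(
      tfd.Gamma(concentration=concentration, rate=rate),
      sample_shape=4,
      name='y')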
model = tfd.JointDistributionCoroutine(model_fn)

# Assume the `y` are observed, such that the posterior is a joint distribution
# over `concentration` and `rate`. The posterior event shape is then equal to
# the first two components of the model's event shape.
posterior_event_shape = model.event_shape_tensor()[:-1]

# Constrain the posterior values to be positive using the `Exp` bijector.
bijector = [tfb.Exp(), tfb.Exp()]

# Build a full-covariance surrogate posterior.
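surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))
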
# For an example defining `'operators'` as a list to express an alternative
# covariance structure, see
# `build_affine_surrogate_posterior_from_base_distribution`.

# Fit the model.
y = [0.2, 0.5, 0.3, 0.7]
target_model = model.experimental_pin(y=y)
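# Illustrative optimizer settings; tune them for real problems.
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=100,
    sample_size=10)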
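By default, tfp.vi.fit_surrogate_posterior minimizes a Monte Carlo estimate of the reverse KL divergence from the surrogate to the target. With an unnormalized target this is the negative evidence lower bound, estimated from sample_size surrogate draws per step, and losses records those per-step values. The sketch below computes one such estimate by hand for the Gamma model above. It assumes a mean-field ("diag") log-normal surrogate with hand-picked loc and scale. target_log_prob and negative_elbo are hypothetical helpers, not TFP functions, and no optimization is performed.

```python
import math
import random

# Observed data from the example above.
Y_OBS = (0.2, 0.5, 0.3, 0.7)


def target_log_prob(concentration, rate, y=Y_OBS):
    """Unnormalized log posterior: Exponential(1) priors plus Gamma likelihood."""
    lp = -concentration - rate  # log Exponential(1) density at positive values
    for yi in y:
        lp += (concentration * math.log(rate) - math.lgamma(concentration)
               + (concentration - 1.0) * math.log(yi) - rate * yi)
    return lp


def negative_elbo(loc, scale, num_samples, rng):
    """Monte Carlo estimate of E_q[log q - log p], i.e. the negative ELBO."""
    total = 0.0
    for _ in range(num_samples):
        # Mean-field affine step: x_j = loc_j + scale_j * z_j with z_j ~ N(0, 1).
        x = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]
        log_q = 0.0
        for xi, m, s in zip(x, loc, scale):
            log_normal = (-0.5 * ((xi - m) / s) ** 2 - math.log(s)
                          - 0.5 * math.log(2.0 * math.pi))
            log_q += log_normal - xi  # subtract log|d exp(x)/dx| = x
        theta = [math.exp(xi) for xi in x]  # Exp bijector: positive values
        total += log_q - target_log_prob(*theta)
    return total / num_samples
```

Because the target is unnormalized, this quantity equals KL(q || posterior) minus the constant log evidence. Minimizing it over loc and scale therefore pulls the surrogate toward the posterior.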