
Builds a variational posterior by linearly transforming base distributions.

This function builds a surrogate posterior by applying a trainable transformation to a base distribution (typically a tfd.JointDistribution) or nested structure of base distributions, and constraining the samples with bijector. Note that the distributions must have event shapes corresponding to the pretransformed surrogate posterior: if bijector contains a shape-changing bijector, then the corresponding base distribution event shape is the inverse event shape of the bijector applied to the desired surrogate posterior shape. The surrogate posterior is constructed as follows:
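For example, a fill-triangular bijector (such as tfb.FillTriangular) maps a vector of length n(n+1)/2 to an n x n lower-triangular matrix, so a surrogate posterior over 3 x 3 lower-triangular matrices needs a base distribution with event shape [6]. The helper below is a hypothetical plain-Python sketch of that shape arithmetic only; it is not part of TFP.

```python
def fill_triangular_base_event_shape(matrix_shape):
    """Event shape a fill-triangular bijector consumes to produce an
    n x n lower-triangular matrix (hypothetical helper for illustration)."""
    n, m = matrix_shape
    if n != m:
        raise ValueError("expected a square matrix shape, got %r" % (matrix_shape,))
    # An n x n lower-triangular matrix has n * (n + 1) / 2 free entries.
    return [n * (n + 1) // 2]


print(fill_triangular_base_event_shape([3, 3]))  # [6]
```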

  1. Flatten the base distribution event shapes to vectors, and pack the base distributions into a tfd.JointDistribution.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base distribution.
  3. Apply the constraining bijectors and return the resulting trainable tfd.TransformedDistribution instance.
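The three steps can be sketched in NumPy for a base of standard Normals. This is a hypothetical illustration of the math (flatten, linear map, constrain), not TFP's implementation: the event shapes [] and [2], the single dense lower-triangular matrix standing in for the blockwise LinearOperator, the fixed (rather than trainable) parameters, and the exp constraint on the first component are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed base: standard Normal components with event shapes [] and [2].
event_shapes = [(), (2,)]
sizes = [int(np.prod(s)) for s in event_shapes]  # [1, 2]
n = sum(sizes)  # 3

# Parameters of the affine map y = loc + L z. In TFP these would be
# trainable variables; here they are fixed for illustration.
loc = np.array([0.5, -1.0, 2.0])
scale_tril = np.tril(0.1 * rng.standard_normal((n, n)))
np.fill_diagonal(scale_tril, 1.0)  # nonzero diagonal keeps L invertible


def sample_surrogate(num_samples):
    # Step 1: sample each base component, flatten it, and concatenate.
    parts = [rng.standard_normal((num_samples,) + s).reshape(num_samples, -1)
             for s in event_shapes]
    z = np.concatenate(parts, axis=-1)  # shape [num_samples, n]

    # Step 2: apply the lower-triangular linear operator plus a shift.
    y = loc + z @ scale_tril.T  # row-wise y_i = loc + L @ z_i

    # Split the flat vector back into components with their event shapes.
    flat_parts = np.split(y, np.cumsum(sizes)[:-1], axis=-1)
    unconstrained = [p.reshape((num_samples,) + s)
                     for p, s in zip(flat_parts, event_shapes)]

    # Step 3: constrain; exp makes the first component positive, and the
    # second is left on R^2 (identity).
    return np.exp(unconstrained[0]), unconstrained[1]


positive, real_vec = sample_surrogate(4)
print(positive.shape, real_vec.shape)  # (4,) (4, 2)
```

Because the same matrix L mixes all flattened coordinates, the unconstrained samples have covariance L L^T, which is what lets a non-diagonal operator capture correlations between latent variables.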

Args

base_distribution tfd.Distribution instance (typically a tfd.JointDistribution), or a nested structure of tfd.Distribution instances.
operators Either a string or a list/tuple containing LinearOperator subclasses, LinearOperator instances, or callables returning LinearOperator instances. Supported string values are "diag" (to create a mean-field surrogate posterior) and "tril" (to create a full-covariance surrogate posterior). A list/tuple may be passed to induce other posterior covariance structures. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance will be created and applied to the base distribution. Otherwise the list must be singly-nested and have a first element of length 1, second element of length 2, etc.; the elements of the outer list are interpreted as rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is created. For complete documentation and examples, see, which receives the operators arg if it is list-like. Default value: "diag".
bijector tfb.Bijector instance, or nested structure of tfb.Bijector instances, that maps (nested) values in R^n to the support of the posterior. (This can be the experimental_default_event_space_bijector of the distribution over the prior latent variables.) Default value: None (i.e., the posterior is over R^n).
initial_unconstrained_loc_fn Optional Python callable with signature initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed) used to sample real-valued initializations for the unconstrained location of each variable. Default value: functools.partial(tf.random.stateless_uniform, minval=-2., maxval=2., dtype=tf.float32).
seed Python integer to seed the random number generator for initial values. Default value: None.
validate_args Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_affine_surrogate_posterior_from_base_distribution').
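To make the flat versus nested forms of operators concrete, the sketch below assembles the dense matrix that a block-diagonal or block lower-triangular structure represents, using fixed NumPy arrays in place of LinearOperator instances. The function name assemble_block_operator is hypothetical; TFP itself builds tf.linalg.LinearOperatorBlockDiag or tf.linalg.LinearOperatorBlockLowerTriangular objects rather than dense matrices.

```python
import numpy as np


def assemble_block_operator(blocks):
    """Return the dense matrix for a flat list of square blocks (block
    diagonal) or a singly-nested list of rows (block lower triangular)."""
    if all(isinstance(b, np.ndarray) for b in blocks):
        # Flat list: each block sits on the diagonal, zeros elsewhere.
        sizes = [b.shape[0] for b in blocks]
        rows = [[np.zeros((sizes[i], sizes[j])) for j in range(i)] + [blocks[i]]
                for i in range(len(blocks))]
    else:
        rows = [list(row) for row in blocks]

    # Row i of a block lower-triangular structure must hold i + 1 blocks.
    for i, row in enumerate(rows):
        if len(row) != i + 1:
            raise ValueError(f"row {i} must have {i + 1} blocks, got {len(row)}")

    sizes = [row[-1].shape[0] for row in rows]  # diagonal block sizes
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    out = np.zeros((offsets[-1], offsets[-1]))
    for i, row in enumerate(rows):
        for j, block in enumerate(row):
            out[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]] = block
    return out


# Full covariance between two scalars, plus a 2-vector that is correlated
# with them but independent within itself (diagonal last block).
rows = [
    [np.array([[1.0]])],
    [np.array([[2.0]]), np.array([[3.0]])],
    [np.ones((2, 1)), 4.0 * np.ones((2, 1)), np.diag([5.0, 6.0])],
]
print(assemble_block_operator(rows))
```

The nested example mirrors the kind of structure used for the Eight Schools model: dense blocks where correlations are wanted and a diagonal block where the components of one variable are treated as independent.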

Returns

surrogate_distribution Trainable tfd.JointDistribution instance.

Raises

NotImplementedError Base distributions with mixed dtypes are not supported.


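import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Fit a multivariate Normal surrogate posterior on the Eight Schools model
# [1].

treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

def model_fn():
  avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
  log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
  school_effects = yield tfd.Sample(
      tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
      sample_shape=[8],
      name='school_effects')
  treatment_effects = yield tfd.Independent(
      tfd.Normal(loc=school_effects, scale=treatment_stddevs),
      reinterpreted_batch_ndims=1,
      name='treatment_effects')
model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

# Pin the observed values in the model.
target_model = model.experimental_pin(treatment_effects=treatment_effects)

# Define a lower triangular structure of `LinearOperator` subclasses that
# models full covariance among latent variables except for the 8 dimensions
# of `school_effects`, which are modeled as independent (using
# `LinearOperatorDiag`). Row i of the nested list holds i + 1 blocks.
operators = [
  [tf.linalg.LinearOperatorLowerTriangular],
  [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
  [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
   tf.linalg.LinearOperatorDiag]]

# Constrain the posterior values to the support of the prior.
bijector = target_model.experimental_default_event_space_bijector()

# Base distribution: standard Normals with the unconstrained event shapes of
# the latent variables. All latents here are unconstrained Normals, so the
# bijector above does not change shapes. (This definition is an assumption;
# the example as published does not show how `base_distribution` is built.)
base_distribution = tf.nest.map_structure(
    lambda shape: tfd.Sample(tfd.Normal(0., 1.), sample_shape=shape),
    target_model.event_shape)

# Build a full-covariance surrogate posterior.
surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

# Fit the model.
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=100,
    sample_size=10)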


[1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin. Bayesian Data Analysis, Third Edition. Chapman and Hall/CRC, 2013.