

Returns an initialization Distribution for starting a Markov chain.

This initialization scheme follows Stan: we sample every latent independently, uniformly from -2 to 2 in its unconstrained space, and then transform into constrained space to construct an initial state that can be passed to sample_chain or other MCMC drivers.
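The scheme described above can be illustrated without any library code. The following is a minimal sketch in plain Python, not the TFP implementation: the latent names (`scale`, `prob`, `loc`) and the helper functions are hypothetical, and `exp` and the logistic sigmoid stand in for whatever constraining bijectors a real model would supply.

```python
import math
import random


def to_positive(x):
    """Hypothetical constraining map for a positive latent (exp)."""
    return math.exp(x)


def to_unit_interval(x):
    """Hypothetical constraining map for a latent in (0, 1) (sigmoid)."""
    return 1.0 / (1.0 + math.exp(-x))


def identity(x):
    """Constraining map for an unconstrained real latent."""
    return x


# Assumed example model: one map per latent, from unconstrained
# to constrained space.
CONSTRAINING_MAPS = {
    "scale": to_positive,
    "prob": to_unit_interval,
    "loc": identity,
}


def init_state(rng, low=-2.0, high=2.0):
    """Draws every latent independently and uniformly from [low, high] in
    unconstrained space, then maps each draw into constrained space."""
    return {
        name: constrain(rng.uniform(low, high))
        for name, constrain in CONSTRAINING_MAPS.items()
    }


rng = random.Random(0)
# One initial state per chain, for 100 chains.
states = [init_state(rng) for _ in range(100)]
```

Because every draw starts in [-2, 2] before being transformed, each constrained value is guaranteed to be in the support of its latent: `scale` lies in [exp(-2), exp(2)] and `prob` lies strictly between 0 and 1.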

The argument signature is arranged to let the user pass either a JointDistribution describing their model, if it's in that form, or the essential information necessary for the sampling, namely a bijector (from unconstrained to constrained space) and the desired shape and dtype of each sample (specified in constrained space).

model A Distribution (typically a JointDistribution) giving the model to be initialized. If supplied, it is queried for its default event space bijector, its event shape, and its dtype. If not supplied, those three elements must be supplied instead.
constraining_bijector A (typically multipart) Bijector giving the mapping from unconstrained to constrained space. If supplied together with a model, acts as an override. A nested structure of Bijectors is accepted, and interpreted as applying in parallel to a corresponding structure of state parts (see JointMap for details).
event_shapes A structure of shapes giving the (constrained) event space shape of the desired samples. Must be an acceptable input to constraining_bijector.inverse_event_shape. If supplied together with model, acts as an override.
dtypes A structure of dtypes giving the (constrained) dtypes of the desired samples. Must be an acceptable input to constraining_bijector.inverse_dtype. If supplied together with model, acts as an override.

init_dist A Distribution representing the initialization distribution, in constrained space. Samples from this Distribution are valid initial states for a Markov chain targeting the model.


Initialize 100 chains by sampling uniformly from [-2, 2] in unconstrained space, for a model expressed as a JointDistributionCoroutine:

@tfp.distributions.JointDistributionCoroutine
def model():
  ...  # model body elided in the source

init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
states = tfp.mcmc.sample_chain(
  current_state=init_dist.sample(100, seed=[4, 8]),