tfp.substrates.jax.distributions.independent_joint_distribution_from_structure

Turns a (potentially nested) structure of distributions into a single distribution.

Args:
  structure_of_distributions: An instance of tfd.Distribution, or a nested structure (tuple, list, dict, etc.) in which every leaf is a tfd.Distribution instance.
  validate_args: Python bool. Whether the joint distribution should validate its inputs with assertions. This imposes a runtime cost. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed. Default value: False.

Returns:
  distribution: An instance of tfd.Distribution such that distribution.sample() is equivalent to tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions). If structure_of_distributions is a structure (rather than a single Distribution instance), the result is a JointDistribution with the corresponding structure.

Raises:
  TypeError: If any leaf of the input structure is not a tfd.Distribution instance.