Experimental module for computing log probabilities in a distributed (SPMD) setting.
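The decomposition that makes distributed log-prob calculations possible can be sketched without TFP or real accelerators. For independent observations, the total log probability is a sum over data points, so it splits into one partial sum per shard. Each device scores only its own shard, and an all-reduce ("psum") then gives every device the full total. This is a minimal pure-Python simulation under that independence assumption. Devices are plain lists, and the helpers `shard`, `psum`, and `distributed_log_prob` are hypothetical illustrations, not part of this module's API.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def shard(data, num_devices):
    """Hypothetical helper: split data into equal, contiguous shards."""
    assert len(data) % num_devices == 0, "data must divide evenly across devices"
    size = len(data) // num_devices
    return [data[i * size:(i + 1) * size] for i in range(num_devices)]


def psum(values):
    """Simulated all-reduce: every device receives the sum over all devices."""
    total = sum(values)
    return [total] * len(values)


def distributed_log_prob(data, loc, scale, num_devices):
    # Each simulated device scores only its local shard of the data.
    local = [sum(normal_log_prob(x, loc, scale) for x in s)
             for s in shard(data, num_devices)]
    # The all-reduce makes the full-data log prob available on every device.
    return psum(local)


data = [0.5, -1.2, 0.3, 2.0, -0.7, 1.1, 0.0, -0.4]
per_device = distributed_log_prob(data, loc=0.1, scale=1.5, num_devices=4)
# The serial result, computed on one "device" for comparison.
serial = sum(normal_log_prob(x, 0.1, 1.5) for x in data)
```

Every entry of `per_device` matches `serial` up to floating-point reordering. This is why the per-shard terms can be reduced with a sum rather than gathered onto one device.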
class JointDistributionCoroutine: A sharding-aware JointDistributionCoroutine.
class JointDistributionNamed: A sharding-aware JointDistributionNamed.
class JointDistributionSequential: A sharding-aware JointDistributionSequential.
class Sharded: A meta-distribution meant for use in an SPMD distributed context.
make_pbroadcast_function(...): Constructs a function that broadcasts inputs over named axes.
make_psum_function(...): Constructs a function that psums (all-reduces) its outputs over named axes.
make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
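The helpers above address a subtlety in gradients. A parameter that is replicated on every device (for example, a global `loc`) but used by a sharded likelihood term only sees the local part of its gradient on each device. Because the transpose of a broadcast is a sum, the correct gradient is the psum of the per-device gradients. An unsharded term such as the parameter's prior is already identical on every device, so it must be counted once, not summed. As we understand it, this is the bookkeeping the broadcast/psum helpers automate. The pure-Python sketch below checks the arithmetic for a Normal likelihood with a Normal(0, 1) prior on `loc`. The model and the names `local_grad` and `sharded_grad_loc` are illustrative assumptions, not this module's API.

```python
def local_grad(shard, loc, scale):
    """Gradient of one device's sharded log-likelihood term w.r.t. loc.

    d/d(loc) log Normal(x; loc, scale) = (x - loc) / scale**2.
    """
    return sum((x - loc) / scale ** 2 for x in shard)


def sharded_grad_loc(shards, loc, scale):
    # Sharded term: each device computes its local gradient, and a
    # simulated psum over the data axis combines them.
    likelihood_grad = sum(local_grad(s, loc, scale) for s in shards)
    # Unsharded term: the prior Normal(loc; 0, 1) has gradient -loc. It is
    # the same on every device, so it is added once, not summed per device.
    prior_grad = -loc
    return prior_grad + likelihood_grad


data = [0.5, -1.2, 0.3, 2.0, -0.7, 1.1, 0.0, -0.4]
shards = [data[0:2], data[2:4], data[4:6], data[6:8]]
g = sharded_grad_loc(shards, loc=0.1, scale=1.5)

# Serial reference gradient over the full dataset.
serial = -0.1 + sum((x - 0.1) / 1.5 ** 2 for x in data)

# Mistake to avoid: summing the prior gradient on every device over-counts it.
over_counted = sum(-0.1 + local_grad(s, 0.1, 1.5) for s in shards)
```

Here `g` agrees with `serial`, while `over_counted` is off by three extra copies of the prior gradient (one per additional device).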