Experimental module for distributed log-probability calculations.
Classes
class JointDistributionCoroutine: A sharding-aware JointDistributionCoroutine.
class JointDistributionNamed: A sharding-aware JointDistributionNamed.
class JointDistributionSequential: A sharding-aware JointDistributionSequential.
class Sharded: A meta-distribution meant for use in an SPMD distributed context.
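The reason a sharded distribution needs a cross-device reduction can be seen without any distributed runtime. The sketch below is plain Python, not the `tfp.experimental.distribute` API: a list stands in for devices, `psum` is a hypothetical helper that simulates an all-reduce sum, and the strided sharding scheme is an assumption chosen for illustration. Each device evaluates the log prob of its own slice of i.i.d. data, and only after summing across devices does every device hold the log prob of the full dataset.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    # Log density of a univariate Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def psum(per_device_values):
    # Hypothetical simulated all-reduce: every device receives the sum
    # of the values held by all devices.
    total = sum(per_device_values)
    return [total] * len(per_device_values)


data = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1]
num_devices = 2
# Assumed sharding scheme: device i holds every num_devices-th element.
shards = [data[i::num_devices] for i in range(num_devices)]

# Each device evaluates the log prob of its own shard only.
local_log_probs = [sum(normal_log_prob(x) for x in shard) for shard in shards]
# After the all-reduce, every device holds the log prob of the full dataset.
reduced = psum(local_log_probs)
full = sum(normal_log_prob(x) for x in data)
```

A single device's `local_log_probs` entry is only a partial sum; `reduced` matches `full` on every device up to floating-point rounding.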
Functions
make_pbroadcast_function(...): Constructs a function that broadcasts inputs over named axes.
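One way to think about broadcasting a value over a named axis is: in the forward pass every device gets an identical copy, and in the backward pass the gradients from all copies are summed, because the one underlying value influenced every device's result. The sketch below illustrates that idea in plain Python and is not TFP's implementation; `pbroadcast_forward` and `pbroadcast_backward` are hypothetical names, the model is an assumed Normal(mu, 1) likelihood, and its gradient with respect to `mu` is written out analytically as the sum of `(x - mu)`.

```python
def pbroadcast_forward(value, num_devices):
    # Forward pass: every device receives an identical copy of the value.
    return [value] * num_devices


def pbroadcast_backward(per_device_grads):
    # Backward pass: the single replicated value influenced every device's
    # output, so its gradient is the sum of the per-device gradients.
    return sum(per_device_grads)


data = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1]
num_devices = 2
shards = [data[i::num_devices] for i in range(num_devices)]

mu = 0.4  # A parameter replicated on every device (e.g. a global location).
copies = pbroadcast_forward(mu, num_devices)

# d/dmu of sum_i log N(x_i; mu, 1) is sum_i (x_i - mu); each device
# differentiates only its local terms.
local_grads = [sum(x - m for x in shard) for m, shard in zip(copies, shards)]
grad_mu = pbroadcast_backward(local_grads)

# Reference: the gradient of the full-data log prob, computed in one place.
expected = sum(x - mu for x in data)
```

Without the summation in the backward pass, each device would see only the gradient contribution of its own shard.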
make_psum_function(...): Constructs a function that psums (all-reduce sums) its outputs over named axes.
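"Over named axes" matters when devices form a multi-dimensional mesh: a psum over one axis sums only across devices that differ along that axis. The sketch below simulates only the psum primitive in plain Python, not the TFP function-wrapping API. The 2 x 3 mesh with axes `"data"` and `"model"` and the helper `psum_over_axis` are hypothetical.

```python
def psum_over_axis(values, axis_names, axis):
    # values: dict mapping each device's mesh coordinate (a tuple) to a local value.
    # Sums over devices that differ only along `axis`; every device in a group
    # receives that group's total, as with an all-reduce over one named axis.
    k = axis_names.index(axis)
    totals = {}
    for coord, v in values.items():
        group = coord[:k] + coord[k + 1:]
        totals[group] = totals.get(group, 0.0) + v
    return {coord: totals[coord[:k] + coord[k + 1:]] for coord in values}


axis_names = ("data", "model")  # Hypothetical 2 x 3 device mesh.
values = {(d, m): float(10 * d + m) for d in range(2) for m in range(3)}

over_data = psum_over_axis(values, axis_names, "data")
over_model = psum_over_axis(values, axis_names, "model")
```

Summing over `"data"` gives 10, 12, and 14 for model columns 0, 1, and 2, with each total replicated on both data rows. Summing over `"model"` gives 3 on every device in data row 0 and 33 on every device in row 1.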
make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
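For a joint model, not every log-prob term should be all-reduced. A term computed from a slice of the data must be summed across devices, but a term for a replicated global variable is already identical everywhere and would be over-counted by a sum. The sketch below shows that selective reduction in plain Python. It assumes a hypothetical model `mu ~ Normal(0, 1)` (global) and `x_i ~ Normal(mu, 1)` (sharded), and the helper `reduce_log_prob_parts` with its `is_sharded` flags is illustrative, not the TFP signature.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    # Log density of a univariate Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def reduce_log_prob_parts(per_device_parts, is_sharded):
    # per_device_parts: one list of log-prob terms per device, in model order.
    # Sharded terms were computed on a slice of the data, so they are summed
    # across devices. Unsharded terms are already identical on every device
    # and are kept as-is; summing them would count them num_devices times.
    reduced = []
    for j, sharded in enumerate(is_sharded):
        if sharded:
            reduced.append(sum(parts[j] for parts in per_device_parts))
        else:
            reduced.append(per_device_parts[0][j])
    # Every device ends up holding the same reduced parts.
    return [list(reduced) for _ in per_device_parts]


# Hypothetical model: mu ~ Normal(0, 1) (global), x_i ~ Normal(mu, 1) (sharded).
data = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1]
num_devices = 2
shards = [data[i::num_devices] for i in range(num_devices)]
mu = 0.4

# Each device computes [prior term, likelihood term for its own shard].
per_device_parts = [
    [normal_log_prob(mu), sum(normal_log_prob(x, loc=mu) for x in shard)]
    for shard in shards
]
parts = reduce_log_prob_parts(per_device_parts, is_sharded=[False, True])

# Reference: the same two parts computed on a single device with all the data.
expected = [normal_log_prob(mu), sum(normal_log_prob(x, loc=mu) for x in data)]
```

Each device's reduced parts match `expected`. Naively summing the prior term across the two devices would double it.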