Experimental module for distributed log-probability calculations.
Classes
class JointDistributionCoroutine
: A sharding-aware JointDistributionCoroutine.
class JointDistributionNamed
: A sharding-aware JointDistributionNamed.
class JointDistributionSequential
: A sharding-aware JointDistributionSequential.
class ShardedIndependent
: A version of tfd.Independent that folds the device ID into its randomness.
class ShardedSample
: A version of tfd.Sample that shards its output across devices.
Functions
make_sharded_log_prob_parts(...)
: Constructs a log prob parts function that all-reduces over terms.
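The idea behind all-reducing over log prob terms can be illustrated without any distributed runtime. The sketch below is not this module's API: it is a plain-Python simulation in which "devices" are list chunks, and the helper names (`normal_log_prob`, `shard`, `all_reduce_sum`) are hypothetical. It assumes i.i.d. standard normal data split evenly across devices, so the full log prob is the sum of the per-device partial sums.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate normal evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def shard(values, num_shards):
    """Split values into num_shards contiguous, equally sized chunks."""
    size = len(values) // num_shards
    return [values[i * size:(i + 1) * size] for i in range(num_shards)]


def all_reduce_sum(per_device_values):
    """Simulated all-reduce: every device receives the sum over all devices."""
    total = sum(per_device_values)
    return [total for _ in per_device_values]


data = [0.1, -0.4, 1.3, 0.7, -1.1, 0.2, 0.9, -0.3]
num_devices = 4  # hypothetical device count; must divide len(data) here

# Each simulated device only sees its own shard and computes a local term.
local_terms = [
    sum(normal_log_prob(x) for x in chunk)
    for chunk in shard(data, num_devices)
]

# After the all-reduce, every device holds the same global log prob.
global_terms = all_reduce_sum(local_terms)

# Reference value computed without sharding.
unsharded = sum(normal_log_prob(x) for x in data)
```

After the reduction, every entry of `global_terms` agrees with `unsharded` up to floating-point error, which is the property a sharded log prob function needs so that each device can compute consistent gradients or acceptance decisions.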