Experimental module for distributed log prob calculations.
class JointDistributionCoroutine: A sharding-aware JointDistributionCoroutine.
class JointDistributionNamed: A sharding-aware JointDistributionNamed.
class JointDistributionSequential: A sharding-aware JointDistributionSequential.
class ShardedIndependent: A version of tfd.Independent that folds device id into its randomness.
class ShardedSample: A version of tfd.Sample that shards its output across devices.
make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
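
As a rough illustration of what "all-reduces over terms" can mean, here is a library-free sketch in plain Python. It does not call this module's API. The helpers `shard`, `all_reduce_sum`, and `sharded_log_prob_parts` are hypothetical names, devices are simulated as list entries, and the model (a standard normal prior on `loc` with a unit-scale normal likelihood) is an assumption chosen only for the example. The point it tries to show: a term computed on sharded data is summed across devices, while a term over a replicated value is left as is.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def shard(values, num_devices):
    """Hypothetical helper: split values into equally sized contiguous shards."""
    if len(values) % num_devices != 0:
        raise ValueError("len(values) must be divisible by num_devices")
    size = len(values) // num_devices
    return [values[i * size:(i + 1) * size] for i in range(num_devices)]


def all_reduce_sum(per_device_values):
    """Simulated all-reduce: every device receives the sum over all devices."""
    total = sum(per_device_values)
    return [total for _ in per_device_values]


def sharded_log_prob_parts(loc, data_shards):
    """Return one (prior_term, likelihood_term) pair per simulated device.

    Assumed model: loc ~ Normal(0, 1), each data point ~ Normal(loc, 1).
    `loc` is replicated on every device, so its term is not reduced.
    The data are sharded, so the local likelihood sums are all-reduced.
    """
    prior_term = normal_log_prob(loc)
    local_terms = [
        sum(normal_log_prob(x, loc) for x in device_data)
        for device_data in data_shards
    ]
    reduced_terms = all_reduce_sum(local_terms)
    return [(prior_term, likelihood) for likelihood in reduced_terms]


data = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.0, 1.0]
parts = sharded_log_prob_parts(loc=0.3, data_shards=shard(data, num_devices=4))

# Reference value computed without any sharding.
unsharded_likelihood = sum(normal_log_prob(x, 0.3) for x in data)
```

After the reduction, every simulated device holds the same likelihood term, and it agrees (up to floating-point rounding) with `unsharded_likelihood`. That agreement is the property a sharded log prob is meant to preserve.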