Module: tfp.experimental.distribute

Experimental module for distributed (sharded) log prob calculations.

Classes

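class JointDistributionCoroutine: A sharding-aware version of tfd.JointDistributionCoroutine (built on JDMixin) that shards the log_prob calculation.

class JointDistributionNamed: A sharding-aware version of tfd.JointDistributionNamed (built on JDMixin) that shards the log_prob calculation.

class JointDistributionSequential: A sharding-aware version of tfd.JointDistributionSequential (built on JDMixin) that shards the log_prob calculation.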

class ShardedIndependent: A version of tfd.Independent that folds device id into its randomness.

class ShardedSample: A version of tfd.Sample that shards its output across devices.
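The two ideas behind ShardedIndependent and ShardedSample, folding the device id into the randomness and splitting an iid sample across devices, can be illustrated with a conceptual sketch in plain Python. This is not the TFP API: devices are emulated as list indices, `sample_shard` and `local_log_prob` are hypothetical helpers, and the seed-folding formula is an illustrative assumption, not TFP's scheme.

```python
import math
import random


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def sample_shard(seed, device_id, shard_size):
    """Hypothetical helper: draw one device's shard of an iid Normal(0, 1) sample.

    Folding device_id into the seed (an assumed, illustrative scheme) gives each
    device its own random stream, so shards are not copies of each other.
    """
    rng = random.Random(seed * 1_000_003 + device_id)
    return [rng.gauss(0.0, 1.0) for _ in range(shard_size)]


def local_log_prob(shard):
    """Log prob of the values held by a single device."""
    return sum(normal_log_prob(x) for x in shard)


num_devices, shard_size, seed = 4, 3, 42
shards = [sample_shard(seed, d, shard_size) for d in range(num_devices)]

# Emulate the cross-device all-reduce with a plain sum over per-device terms.
total = sum(local_log_prob(s) for s in shards)

# The reduced value matches the log prob of the full, unsharded sample.
full = [x for s in shards for x in s]
print(math.isclose(total, sum(normal_log_prob(x) for x in full)))
```

Because the sample is iid, its log prob factorizes over elements, which is what makes summing per-device partial results equal to the full log prob.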

Functions

make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
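What "all-reduces over terms" means can be sketched conceptually for a toy model w ~ Normal(0, 1), y_i ~ Normal(w, 1) with the observations y split across devices. This is plain Python, not TFP's implementation: devices are emulated as list entries, and `local_log_prob_parts`, `all_reduce_parts`, and the `sharded_terms` argument are hypothetical. In a real distributed run the reduction would be a collective over the device axis (for example a psum), and which terms count as sharded is determined by the library rather than passed in by hand.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def local_log_prob_parts(w, data_shard):
    """Per-device log prob parts for the toy model."""
    return {
        # Global term: identical on every device (replicated, not sharded).
        "w": normal_log_prob(w),
        # Sharded term: each device only sees its own slice of the data.
        "y": sum(normal_log_prob(y, loc=w) for y in data_shard),
    }


def all_reduce_parts(parts_per_device, sharded_terms):
    """Hypothetical stand-in for the all-reduce over log prob parts.

    Sharded terms are summed across devices; replicated terms are taken
    from a single device so they are not counted once per device.
    """
    reduced = {}
    for name in parts_per_device[0]:
        if name in sharded_terms:
            reduced[name] = sum(p[name] for p in parts_per_device)
        else:
            reduced[name] = parts_per_device[0][name]
    return reduced


data = [0.3, -1.2, 0.8, 2.1, -0.4, 0.0]
shards = [data[0:2], data[2:4], data[4:6]]  # three emulated devices
w = 0.5

parts = all_reduce_parts(
    [local_log_prob_parts(w, s) for s in shards], sharded_terms={"y"}
)
print(parts)
```

Summing the replicated prior term across devices would count it three times here, which is why the reduction must distinguish sharded terms from replicated ones.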