Constructs a log prob parts function that all-reduces over terms.
tfp.experimental.distribute.make_sharded_log_prob_parts(
log_prob_parts_fn, axis_names
)
Given a log_prob_parts function, this function returns a new one that
includes all-reduce sums over terms according to whether each term is
sharded, as specified by axis_names. It also adds all-reduce sums for the
gradient of sharded terms with respect to unsharded terms.
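To show why the all-reduce over sharded terms is needed, here is a plain-Python simulation. It does not use TFP or JAX. The model, the two simulated devices, and the helper names (normal_log_prob, local_log_prob_parts, all_reduce_sum) are illustrative assumptions, not library API. The location mu is unsharded and replicated on every device, while the observations x are split across devices. Each device can only score its own slice of x, so that term must be summed across devices. The mu term is already complete on every device and must not be summed, or it would be counted once per device.

```python
import math

def normal_log_prob(x, loc, scale=1.0):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def local_log_prob_parts(value, x_shard):
    """Hypothetical per-device log_prob_parts function.

    `mu` is unsharded (identical on every device); `x_shard` holds only
    this device's slice of the sharded observations.
    """
    mu = value["mu"]
    return {
        "mu": normal_log_prob(mu, 0.0),
        "x": sum(normal_log_prob(xi, mu) for xi in x_shard),
    }

def all_reduce_sum(per_device_values):
    """Simulated all-reduce: every device receives the sum over devices."""
    total = sum(per_device_values)
    return [total] * len(per_device_values)

x_full = [0.5, -1.0, 2.0, 0.3]
shards = [x_full[:2], x_full[2:]]  # two simulated devices
value = {"mu": 0.25}

# Each device computes only locally correct parts.
local_parts = [local_log_prob_parts(value, s) for s in shards]

# The sharded term "x" is all-reduced; the unsharded term "mu" is left as-is.
reduced_x = all_reduce_sum([p["x"] for p in local_parts])
sharded_parts = [
    {"mu": p["mu"], "x": rx} for p, rx in zip(local_parts, reduced_x)
]

# Single-device reference computed on all of the data at once.
reference = local_log_prob_parts(value, x_full)
```

After the reduction, every simulated device holds the same parts as the single-device reference, while the pre-reduction "x" term on any one device does not.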
Args:
  log_prob_parts_fn: A callable that takes a structured value and returns a
    structure of log densities, one for each term, which sum to a locally
    correct log density.
  axis_names: A structure of values that matches the input and output of
    log_prob_parts_fn. Each value in axis_names is either None; a string
    naming a mapped axis (JAX backend) or any non-None value (TF backend);
    or an iterable of these, corresponding to multiple sharding axes. If a
    term's axis name is not None, the returned function adds all-reduce
    sum(s) for that term in the log prob calculation. If it is None, the
    returned function adds an all-reduce sum over the gradient of sharded
    terms with respect to that unsharded value.
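The gradient rule for unsharded values can be checked with a similar plain-Python simulation. The model, the two simulated devices, and the helper names are illustrative assumptions, and gradients are written out analytically instead of using an autodiff library. Assume x_i ~ Normal(mu, 1) for sharded x and mu ~ Normal(0, 1) with mu unsharded. The derivative of the x term with respect to mu is sum_i (x_i - mu). A device holding only part of x sees only part of that sum, so the per-device gradients must be all-reduced before they match the full-data gradient.

```python
import math

def grad_x_term_wrt_mu(mu, x_shard):
    """d/dmu of sum_i log Normal(x_i | mu, 1) over this device's shard."""
    return sum(xi - mu for xi in x_shard)

def grad_mu_term_wrt_mu(mu):
    """d/dmu of log Normal(mu | 0, 1); identical on every device."""
    return -mu

def all_reduce_sum(per_device_values):
    """Simulated all-reduce: every device receives the sum over devices."""
    total = sum(per_device_values)
    return [total] * len(per_device_values)

mu = 0.25
x_full = [0.5, -1.0, 2.0, 0.3]
shards = [x_full[:2], x_full[2:]]  # two simulated devices

# The gradient of the sharded term on each device only reflects its slice of x.
local_grads = [grad_x_term_wrt_mu(mu, s) for s in shards]

# All-reducing that gradient recovers the full-data contribution on every device.
reduced_grads = all_reduce_sum(local_grads)

# The unsharded prior term contributes once, with no reduction.
total_grads = [g + grad_mu_term_wrt_mu(mu) for g in reduced_grads]
expected = grad_x_term_wrt_mu(mu, x_full) + grad_mu_term_wrt_mu(mu)
```

Here each device's local gradient (-1.0 and about 1.8) differs from the full-data value, and only the all-reduced sum plus the prior term matches expected (about 0.55).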
Returns:
  A new log prob parts function that can be run inside a strategy.