Constructs a log prob parts function that all-reduces over terms.
log_prob_parts_fn, is_sharded, axis_name=None
Given a log_prob_parts function, this function returns a new one that
includes all-reduce sums over terms according to the is_sharded property. It
also adds all-reduce sums for the gradient of sharded terms w.r.t. unsharded
terms.
Args:
log_prob_parts_fn: a callable that takes in a structured value and returns a
structure of log densities, one for each term, that sum to a locally correct
log-density.
is_sharded: a structure of boolean values that matches the input and output
of log_prob_parts_fn. If a term of log_prob_parts_fn has a corresponding
is_sharded value set to True, the returned function adds an all-reduce sum
for that term in the log prob calculation. If it is False, the returned
function instead adds an all-reduce sum over the gradient of sharded terms
w.r.t. the unsharded value.
axis_name: a str used for the axis name in the JAX backend. Unused in the
TensorFlow backend.
Returns:
A new log prob parts function that can be run inside of a strategy.
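The forward behaviour described above (sharded terms are all-reduced, unsharded terms are left as-is because every shard already computes the same value) can be illustrated with a single-process sketch. It assumes shards are simulated as list entries and that the collective all-reduce is replaced by a plain Python sum; `make_sharded_log_prob_parts_sketch`, `simulated_all_reduce_sum`, and the dict-based structures are hypothetical names chosen for illustration, not the backend implementation, and the sketch covers only the forward pass, not the gradient handling.

```python
from typing import Callable, Dict, List

LogProbParts = Dict[str, float]


def simulated_all_reduce_sum(per_shard_values: List[float]) -> List[float]:
    # Hypothetical stand-in for a collective all-reduce: every simulated
    # shard receives the sum of the values contributed by all shards.
    total = sum(per_shard_values)
    return [total] * len(per_shard_values)


def make_sharded_log_prob_parts_sketch(
    log_prob_parts_fn: Callable[[Dict[str, object]], LogProbParts],
    is_sharded: Dict[str, bool],
) -> Callable[[List[Dict[str, object]]], List[LogProbParts]]:
    """Hypothetical single-process analogue (forward pass only)."""

    def sharded_fn(per_shard_values):
        # Each simulated shard evaluates its own local log prob parts.
        local = [log_prob_parts_fn(v) for v in per_shard_values]
        result = [dict(parts) for parts in local]
        for name, sharded in is_sharded.items():
            if sharded:
                # Sharded term: all-reduce so each shard sees the full sum.
                reduced = simulated_all_reduce_sum([p[name] for p in local])
                for k, total in enumerate(reduced):
                    result[k][name] = total
            # Unsharded terms are replicated, so they are left untouched.
        return result

    return sharded_fn


# Usage: "mu" is an unsharded scalar; "x" is data split across two shards.
def local_parts(value):
    mu, xs = value["mu"], value["x"]
    return {
        "mu": -0.5 * mu ** 2,                          # prior term (constants dropped)
        "x": sum(-0.5 * (x - mu) ** 2 for x in xs),    # shard-local likelihood
    }


fn = make_sharded_log_prob_parts_sketch(local_parts, {"mu": False, "x": True})
out = fn([{"mu": 0.0, "x": [1.0, 2.0]}, {"mu": 0.0, "x": [3.0]}])
```

After the simulated all-reduce, summing the parts on any one shard gives the global log-density rather than a shard-local one.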
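The gradient side of the is_sharded description can be checked with a small worked case. The sketch below assumes a hypothetical model with an unsharded scalar `mu` (prior term `-0.5 * mu**2`) and a sharded data term `-0.5 * sum((x - mu)**2)`, uses hand-derived gradients instead of autodiff, and again replaces the collective with a plain sum over simulated shards; all function names are illustrative only. Only the sharded term's gradient w.r.t. `mu` is all-reduced: the prior gradient is already identical on every shard, so summing it would count it once per shard.

```python
from typing import List


def simulated_all_reduce_sum(per_shard_values: List[float]) -> List[float]:
    # Hypothetical stand-in for a collective all-reduce across shards.
    total = sum(per_shard_values)
    return [total] * len(per_shard_values)


def prior_grad(mu: float) -> float:
    # d/dmu of the unsharded term -0.5 * mu**2.
    return -mu


def local_likelihood_grad(mu: float, xs: List[float]) -> float:
    # d/dmu of the shard-local sharded term -0.5 * sum((x - mu)**2).
    return sum(x - mu for x in xs)


def grad_wrt_unsharded_mu(mu: float, data_shards: List[List[float]]) -> List[float]:
    """Per-shard gradient of the total log-density w.r.t. the unsharded mu."""
    local_grads = [local_likelihood_grad(mu, xs) for xs in data_shards]
    # All-reduce only the gradient of the sharded term.
    reduced = simulated_all_reduce_sum(local_grads)
    return [prior_grad(mu) + g for g in reduced]


def naive_local_grads(mu: float, data_shards: List[List[float]]) -> List[float]:
    # Without the all-reduce each shard sees only its own data's contribution.
    return [prior_grad(mu) + local_likelihood_grad(mu, xs) for xs in data_shards]


shards = [[1.0, 2.0], [3.0]]
grads = grad_wrt_unsharded_mu(0.5, shards)
naive = naive_local_grads(0.5, shards)
```

With `mu = 0.5`, both shards obtain the same gradient as an unsharded computation over all three data points (`-0.5 + 4.5 = 4.0`), whereas the naive per-shard gradients disagree with each other and with the true value.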