Constructs a function that broadcasts inputs over named axes.
tfp.experimental.distribute.make_psum_function( fn, in_axes, out_axes, out_dtype )
Given a function fn, make_psum_function returns a new function that includes psums over terms according to the axis names provided in out_axes. It also adds psums for the vector-Jacobian product of the outputs of fn with respect to its inputs, according to in_axes, if there are axes in the outputs that are not present in an input.
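The gradient correction is easiest to see in a small simulation. The sketch below is plain Python and does not call TensorFlow Probability. It assumes one input w that is replicated (no named axis, as with an in_axes entry of None) and one input x that is sharded over a hypothetical axis named 'data', with the output also carrying 'data'. Each list entry stands for one device. The helpers psum, local_fn, local_grads and psum_function are hypothetical illustrations, not TFP APIs.

```python
# Hypothetical simulation of named-axis psums; each list entry is one "device".

def psum(per_device_values):
    """Simulated psum: every device receives the sum over the named axis."""
    total = sum(per_device_values)
    return [total] * len(per_device_values)


def local_fn(w, x_shard):
    """Per-device computation: w is replicated, x_shard is this device's slice."""
    return w * sum(x_shard)


def local_grads(w, x_shard):
    """Hand-derived local gradients of local_fn with respect to (w, x_shard)."""
    dw = sum(x_shard)          # d/dw of w * sum(x_shard)
    dx = [w] * len(x_shard)    # d/dx_j of w * sum(x_shard)
    return dw, dx


def psum_function(w, x_shards):
    """Assumed behavior for in_axes=(None, 'data') and out_axes='data'."""
    # The output carries the 'data' axis, so it is psummed across devices.
    outs = psum([local_fn(w, xs) for xs in x_shards])
    grads = [local_grads(w, xs) for xs in x_shards]
    # w lacks the 'data' axis that the output has, so its gradient is psummed.
    dws = psum([g[0] for g in grads])
    # x is sharded along 'data', so each device keeps its local gradient.
    dxs = [g[1] for g in grads]
    return outs, dws, dxs


# Two devices, each holding half of x = [1, 2, 3, 4].
outs, dws, dxs = psum_function(2.0, [[1.0, 2.0], [3.0, 4.0]])
print(outs)  # [20.0, 20.0]: equals 2 * (1 + 2 + 3 + 4)
print(dws)   # [10.0, 10.0]: equals sum(x), the full gradient for w
print(dxs)   # [[2.0, 2.0], [2.0, 2.0]]: local gradients, no psum needed
```

Without the psum on dws, each device would see only its partial gradient (3.0 or 7.0) for the replicated w, which disagrees with the gradient of the unsharded computation.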
Returns: A new function that applies psums to the outputs of the original function and corrects the gradient with respect to its inputs.