Constructs a function that broadcasts inputs over named axes.
```python
tfp.experimental.distribute.make_psum_function(
    fn, in_axes, out_axes, out_dtype
)
```
Given a function `fn`, `make_psum_function` returns a new function that applies psums to its output terms according to the axis names provided in `out_axes`. It also adds psums to the vector-Jacobian product (VJP) of the outputs of `fn` with respect to its inputs, according to `in_axes`, whenever the outputs have an axis that is not present in an input.
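The sketch below makes the two psums concrete. It is a minimal, dependency-free simulation under stated assumptions, not TFP's implementation: a named axis is modeled as a Python list holding one value per device, the axis is hypothetically called `"data"`, and `psum`, `local_fn`, and the other helpers are illustrative names rather than TFP or JAX APIs. The output of `fn` is summed over `"data"`. The replicated input `w` lacks that axis, so its per-device gradient must be psummed. The sharded input `x` already carries the axis, so its gradient stays local.

```python
# A dependency-free simulation: each list entry stands in for one device along
# a single named axis (hypothetically called "data"). Nothing here calls TFP.


def psum(per_device_values):
    """Simulated psum: every device receives the sum over all devices."""
    total = sum(per_device_values)
    return [total] * len(per_device_values)


def local_fn(w, x):
    """Per-device term of fn: w is replicated, x is sharded over "data"."""
    return w * x


def local_grad_w(w, x):
    """Local d(local_fn)/dw on one device, before any correction."""
    return x


def local_grad_x(w, x):
    """Local d(local_fn)/dx on one device."""
    return w


def psum_function(w, x_shards):
    """Forward pass: psum the per-device outputs over "data" (the out_axes)."""
    return psum([local_fn(w, x) for x in x_shards])


def psum_function_grad_w(w, x_shards):
    """Gradient for w, which lacks the "data" axis that the output has.

    Each device's local gradient only covers its own shard, so the local
    gradients are psummed to recover the gradient of the summed output.
    """
    return psum([local_grad_w(w, x) for x in x_shards])


def psum_function_grad_x(w, x_shards):
    """Gradient for x, which already has the "data" axis: no psum is needed."""
    return [local_grad_x(w, x) for x in x_shards]


if __name__ == "__main__":
    w, x_shards = 2, [1, 3, 5]
    print(psum_function(w, x_shards))              # [18, 18, 18]
    print([local_grad_w(w, x) for x in x_shards])  # [1, 3, 5] (uncorrected)
    print(psum_function_grad_w(w, x_shards))       # [9, 9, 9]
    print(psum_function_grad_x(w, x_shards))       # [2, 2, 2]
```

Summing the outputs gives 2 × (1 + 3 + 5) = 18, whose derivative with respect to `w` is 1 + 3 + 5 = 9. Only the psummed gradient reports that value on every device. The uncorrected local gradients (1, 3, 5) each see a single shard.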
| Returns |
|---|
| A new function that applies psums to the output of the original function and corrects the gradient with respect to its inputs. |