Constructs a function that broadcasts inputs over named axes.
tfp.experimental.distribute.make_pbroadcast_function(
fn, in_axes, out_axes, out_dtype
)
Given a function `fn`, `make_pbroadcast_function` returns a new function that applies `pbroadcast` to input terms according to the axis names provided in `in_axes` and `out_axes`. For each output axis of each term in the output of `fn`, inputs that do not already have that axis are pbroadcast over it before that term is computed.
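The broadcasting rule above comes down to a set difference: for each output term, an input is broadcast over the term's named axes that it lacks. The following is a hypothetical pure-Python sketch of that rule, not TFP's implementation. `make_pbroadcast_sketch`, `missing_axes`, and the caller-supplied `pbroadcast` callback are illustrative names. Axis names are modeled as frozensets of strings, `fn` is assumed to return a flat tuple of terms, and nested structures and `out_dtype` are not modeled.

```python
from typing import Any, Callable, FrozenSet, List, Sequence, Tuple

AxisSet = FrozenSet[str]


def missing_axes(term_axes: AxisSet, input_axes: AxisSet) -> AxisSet:
    """Named axes that an output term has but a given input does not."""
    return frozenset(term_axes - input_axes)


def make_pbroadcast_sketch(
    fn: Callable[..., Sequence[Any]],
    in_axes: Sequence[AxisSet],
    out_axes: Sequence[AxisSet],
    pbroadcast: Callable[[Any, AxisSet], Any],
) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical illustration of the per-term broadcasting rule.

    in_axes[j] names the axes of argument j; out_axes[i] names the axes of
    output term i. pbroadcast(x, axes) is a stand-in (assumption) for a real
    named-axis broadcast primitive.
    """
    def wrapped(*args: Any) -> Tuple[Any, ...]:
        if len(args) != len(in_axes):
            raise ValueError("expected one in_axes entry per argument")
        outputs: List[Any] = []
        for i, term_axes in enumerate(out_axes):
            local_args = []
            for x, axes in zip(args, in_axes):
                needed = missing_axes(term_axes, axes)
                # Broadcast only inputs that lack some of this term's axes.
                local_args.append(pbroadcast(x, needed) if needed else x)
            # Evaluate fn on the (possibly broadcast) inputs; keep term i.
            outputs.append(fn(*local_args)[i])
        return tuple(outputs)

    return wrapped


# Usage: a fake pbroadcast that records which inputs get broadcast.
calls = []


def fake_pbroadcast(x, axes):
    calls.append((x, tuple(sorted(axes))))
    return x


def fn(a, b):
    return (a * b, a + b)


g = make_pbroadcast_sketch(
    fn,
    in_axes=[frozenset({"data"}), frozenset()],
    out_axes=[frozenset({"data"}), frozenset()],
    pbroadcast=fake_pbroadcast,
)
result = g(2.0, 3.0)
print(result)  # (6.0, 5.0)
print(calls)   # [(3.0, ('data',))]
```

Only `b` is broadcast, and only for the first term: that term carries the `"data"` axis, which `a` already has and `b` lacks, while the second term has no named axes. Evaluating `fn` once per output term mirrors the per-term wording of the description; a real implementation may organize this work differently.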
| Returns |
|---|
| A new function that applies pbroadcasts to the inputs of the original function. |