An arbitrary already-batched computation, a 'primitive operation'.
@staticmethod tfp.experimental.auto_batching.instructions.PrimOp( _cls, vars_in, vars_out, function, skip_push_mask )
These are the items of work on which auto-batching is applied. The function must accept and produce Tensors with a batch dimension, and is free to stage any (batched) computation it wants. The function must also use the same computation substrate as the VM backend. That is, if the VM is staging to XLA, the function will see XLA Tensor handles; if the VM is staging to graph-mode TensorFlow, the function will see TensorFlow Tensors; and so on for other backends.
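For illustration only, here is a sketch of a primitive function written against NumPy arrays, which stand in for the Tensors a real VM backend would pass (TensorFlow Tensors or XLA handles). The name affine_prim and the shapes are hypothetical; the point is that every argument and the result carry a leading batch dimension, so one call processes all batch members at once.

```python
import numpy as np


def affine_prim(x, scale):
    """Hypothetical primitive: every argument has a leading batch dimension.

    x:     shape [batch, 3]
    scale: shape [batch]
    Returns one array of shape [batch, 3].
    """
    # Give scale a trailing axis so each member's scalar scales its own row.
    return x * scale[:, np.newaxis] + 1.0


batch_x = np.arange(6.0).reshape(2, 3)  # two batch members, each a 3-vector
batch_scale = np.array([1.0, 10.0])
out = affine_prim(batch_x, batch_scale)
# out[0] is [1, 2, 3]; out[1] is [31, 41, 51]
```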
The current values of the vars_out variables are saved on their respective stacks, and the results are written to the new tops.
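The save-then-write behavior can be modelled with a toy, unbatched sketch in which each VM variable is a Python list used as a stack whose last element is the current value. The helper run_prim_op and this environment layout are hypothetical, not the library's data structures. The sketch also folds in skip_push_mask (overwriting the top of the stack instead of pushing) and ignores the per-thread masking a real batched VM would need.

```python
from typing import AbstractSet, Any, Callable, Dict, List, Sequence

# Toy VM state (hypothetical): each variable name maps to a stack, and the
# last element of the list is the variable's current value.
Env = Dict[str, List[Any]]


def run_prim_op(env: Env,
                vars_in: Sequence[str],
                vars_out: Sequence[str],
                function: Callable[..., Sequence[Any]],
                skip_push_mask: AbstractSet[str] = frozenset()) -> None:
    """Toy, unbatched model of one primitive operation with a flat vars_out."""
    # Pass the current (top-of-stack) value of each input, in vars_in order.
    args = [env[name][-1] for name in vars_in]
    # Assumed: function returns one value per name in vars_out.
    results = function(*args)
    if len(results) != len(vars_out):
        raise ValueError("function returned the wrong number of values")
    for name, value in zip(vars_out, results):
        stack = env.setdefault(name, [])
        if name in skip_push_mask and stack:
            stack[-1] = value    # updated in place: the old value is discarded
        else:
            stack.append(value)  # pushed: the old value stays saved beneath


env = {"x": [3], "y": [0], "acc": [100]}
run_prim_op(env, ["x"], ["y", "acc"], lambda x: (x * 2, x + 1),
            skip_push_mask={"acc"})
print(env)  # {'x': [3], 'y': [0, 6], 'acc': [4]}
```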
The exact contract for the function is as follows:
- It will be invoked with a list of positional (only) arguments, one per name in vars_in and in the same order.
- Each argument will be a pattern of Tensors (meaning either one Tensor or a potentially nested list or tuple of Tensors), corresponding to the Type of that variable.
- Each Tensor in the argument will have the shape given in the corresponding TensorType, plus an additional leading batch dimension.
- Some indices in the batch dimension may contain junk data if the corresponding threads are not executing this instruction [this is subject to change based on the batch execution strategy].
- The function must return a pattern of Tensors, or objects convertible to Tensors.
- The returned pattern must be compatible with the pattern of vars_out.
- The Tensors in the returned pattern must have shapes compatible with the corresponding TensorTypes of the vars_out variables.
- The returned Tensors will be broadcast into their respective positions if necessary. The broadcasting includes the batch dimension, so a returned Tensor of insufficient rank (e.g., a constant) will be broadcast across batch members. In particular, a Tensor that carries the intended batch size but whose sub-batch shape has too low a rank will broadcast incorrectly and will result in an error.
- If the function raises an exception, it will propagate and abort the entire computation.
- Even in the TensorFlow backend, the function will be staged several times: at least twice during type inference (to ascertain the shapes of the Tensors it returns, as a function of the shapes of the Tensors it is given), and exactly once during executable graph construction.
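The broadcasting rule for returned values can be explored with NumPy, whose shapes align from the right. This is a stand-in, not the VM's own code, and the batch size and event shape below are made up. It shows a constant broadcasting across the whole batch, and a per-member vector of rank one failing to broadcast to a [batch, 3] target because its batch axis is matched against the event axis.

```python
import numpy as np

batch_size, event_shape = 4, (3,)
target_shape = (batch_size,) + event_shape  # [batch, 3] expected for one variable

# A scalar constant has insufficient rank, so it is broadcast across the batch
# members (and the event dimension) alike.
const = np.broadcast_to(np.float32(2.0), target_shape)

# A per-member result of shape [batch] lacks the trailing event axis. Shapes
# align from the right, so the batch axis (4) meets the event axis (3) and fails.
per_member = np.arange(batch_size, dtype=np.float32)
try:
    np.broadcast_to(per_member, target_shape)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# Adding the missing axis explicitly gives each member its own value.
fixed = np.broadcast_to(per_member[:, np.newaxis], target_shape)
print(broadcast_failed, fixed[:, 0])  # True [0. 1. 2. 3.]
```

In this NumPy stand-in, if the batch size happened to equal 3, the second broadcast would succeed but place each member's value along the wrong axis, which is why adding the missing axis explicitly is the safer habit.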
vars_in: List of strings. The names of the VM variables whose current values are passed to the function.
vars_out: Pattern of strings. The names of the VM variables in which to save the results returned from the function.
function: Python callable implementing the computation.
skip_push_mask: Set of strings, a subset of vars_out. These VM variables will be updated in place rather than pushed.
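Because vars_out may be a nested pattern, the function's return value must mirror that structure. The hypothetical helper bind_pattern (not part of the library) sketches how such a match might pair each output name with its value, treating any string as a leaf that may itself hold a nested value, as a variable whose Type is a pattern would.

```python
from typing import Any, Dict


def bind_pattern(names: Any, values: Any) -> Dict[str, Any]:
    """Pair a pattern of variable names with a returned pattern of values.

    A name pattern is a str (a leaf) or a possibly nested list/tuple of name
    patterns. Raises ValueError if the returned structure does not match.
    """
    if isinstance(names, str):
        # Leaf: bind the whole value, which may itself be nested (a variable
        # whose Type is a pattern receives a pattern of Tensors).
        return {names: values}
    if not isinstance(values, (list, tuple)) or len(values) != len(names):
        raise ValueError(f"returned {values!r} does not match {names!r}")
    bound: Dict[str, Any] = {}
    for sub_names, sub_values in zip(names, values):
        bound.update(bind_pattern(sub_names, sub_values))
    return bound


vars_out = ["loss", ("mean", "var")]
print(bind_pattern(vars_out, [0.5, (1.0, 2.0)]))
# {'loss': 0.5, 'mean': 1.0, 'var': 2.0}
```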
replace(vars_out=None)
Return a copy of this PrimOp with vars_out replaced.
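The _cls argument in the signature at the top of this page suggests that PrimOp is built on a Python namedtuple, though this page does not say so. Under that assumption, the sketch below uses a hypothetical stand-in, PrimOpSketch, whose replace method returns a new tuple with only vars_out changed, leaving the original untouched. It replaces the field unconditionally; how the real method treats the default of None is not stated here.

```python
import collections

# Hypothetical stand-in with PrimOp's four fields, assuming a namedtuple base.
_Fields = collections.namedtuple(
    "PrimOpSketch", ["vars_in", "vars_out", "function", "skip_push_mask"])


class PrimOpSketch(_Fields):
    __slots__ = ()

    def replace(self, vars_out=None):
        """Return a copy of self with vars_out replaced; self is unchanged."""
        # namedtuple._replace builds a new tuple, copying the other fields.
        return self._replace(vars_out=vars_out)


op = PrimOpSketch(vars_in=["x"], vars_out=["y"],
                  function=lambda x: [x + 1], skip_push_mask=frozenset())
renamed = op.replace(vars_out=["z"])
print(renamed.vars_out, op.vars_out)  # ['z'] ['y']
```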