An arbitrary already-batched computation, a 'primitive operation'.
These are the items of work on which auto-batching is applied. The
function must accept and produce Tensors with a batch dimension,
and is free to stage any (batched) computation it wants. Restriction: the
function must use the same computation substrate
as the VM backend. That is, if the VM is staging to XLA, the
function will see XLA Tensor handles; if the VM is staging to
graph-mode TensorFlow, the
function will see TensorFlow Tensors; etc.
The current values of the
vars_out are saved on their respective
stacks, and the results written to the new top.
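The save-then-write behavior (and the in-place alternative selected by skip_push_mask, described below) can be sketched with plain Python lists standing in for the VM's per-variable stacks. This is an illustrative sketch only, not the real auto-batching implementation:

```python
# Illustrative sketch: plain Python lists stand in for the VM's
# per-variable stacks; names here are invented for the example.
stacks = {'z': [1.0]}  # variable 'z' currently holds 1.0 on top

def write_result(var, value, skip_push=False):
    if skip_push:
        stacks[var][-1] = value    # update in place (skip_push_mask behavior)
    else:
        stacks[var].append(value)  # old value saved below the new top

write_result('z', 7.0)
print(stacks['z'])  # [1.0, 7.0]: old value preserved, result on new top
```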
The exact contract for
function is as follows:
- It will be invoked with a list of positional (only) arguments,
parallel to vars_in.
- Each argument will be a pattern of Tensors (meaning, either one
Tensor or a (potentially nested) list or tuple of Tensors),
corresponding to the
Type of that variable.
- Each Tensor in the argument will have the dtype and shape given in the
corresponding TensorType, and an additional leading batch dimension.
- Some indices in the batch dimension may contain junk data, if the
corresponding threads are not executing this instruction [this is
subject to change based on the batch execution strategy].
- The function must return a pattern of Tensors, or objects
convertible to Tensors.
- The returned pattern must be compatible with the Type of vars_out.
- The Tensors in the returned pattern must have dtypes compatible with
the corresponding TensorTypes.
- The returned Tensors will be broadcast into their respective
positions if necessary. The broadcasting includes the batch
dimension: Thus, a returned Tensor of insufficient rank (e.g., a
constant) will be broadcast across batch members. In particular,
a Tensor that carries the intended batch size but whose sub-batch
shape is too low rank will broadcast incorrectly, and will result
in an error.
- If the
function raises an exception, it will propagate and abort
the entire computation.
- Even in the TensorFlow backend, the
function will be staged
several times: at least twice during type inference (to ascertain
the shapes of the Tensors it intends to return, as a function of the
shapes of the Tensors it is given), and exactly once during
executable graph construction.
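As a concrete illustration of this contract, here is a toy primitive sketched with NumPy arrays standing in for backend Tensors; the function name, shapes, and values are invented for the example:

```python
import numpy as np

# Toy primitive: batched scale-and-shift, a * x + b.
# Per the contract, every argument carries a leading batch dimension,
# and the second argument is a pattern (here, a tuple) of Tensors.
def scale_shift(x, coeffs):
    a, b = coeffs
    return a * x + b  # result keeps the leading batch dimension

batch = 4
x = np.ones((batch, 3))       # a batch of 3-vectors
a = np.full((batch, 1), 2.0)  # batched scalar coefficients
b = np.zeros((batch, 1))
out = scale_shift(x, (a, b))
print(out.shape)  # (4, 3)
```

Per the broadcasting rule above, returning a bare constant (e.g. `return 3.0`) would also be accepted and broadcast across batch members.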
vars_in: list of strings. The names of the VM variables whose current values to pass to the function.
vars_out: Pattern of strings. The names of the VM variables in which to save the results returned from function.
function: Python callable implementing the computation.
skip_push_mask: Set of strings, a subset of
vars_out. These VM variables will be updated in place rather than pushed.
__new__(_cls, vars_in, vars_out, function, skip_push_mask)
Create new instance of PrimOp(vars_in, vars_out, function, skip_push_mask)
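Since __new__ has the shape of a namedtuple constructor, construction can be sketched with a stand-in namedtuple. This is a structural illustration only, not the real class:

```python
import collections

# Stand-in with the same field order as PrimOp.__new__ above;
# the real class lives in the auto-batching instructions module.
PrimOp = collections.namedtuple(
    'PrimOp', ['vars_in', 'vars_out', 'function', 'skip_push_mask'])

# A primitive op reading VM variables 'x' and 'y', writing 'z' (pushed),
# computed by a batched addition.
op = PrimOp(
    vars_in=['x', 'y'],
    vars_out='z',
    function=lambda x, y: x + y,
    skip_push_mask=set())

print(op.vars_out)  # z
# As for any namedtuple, _replace returns a copy with fields swapped.
op2 = op._replace(vars_out='w')
```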
Return a copy of the PrimOp object replacing specified fields with new values.