tfp.experimental.auto_batching.instructions.PrimOp

An arbitrary already-batched computation, a 'primitive operation'.

These are the items of work on which auto-batching is applied. The function must accept and produce Tensors with a batch dimension, and is free to stage any (batched) computation it wants. Restriction: the function must use the same computation substrate as the VM backend. That is, if the VM is staging to XLA, the function will see XLA Tensor handles; if the VM is staging to graph-mode TensorFlow, the function will see TensorFlow Tensors; etc.
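To make "already batched" concrete, here is a minimal sketch in which NumPy arrays stand in for the backend's Tensors. This is an assumption for illustration only: per the restriction above, a real function would receive XLA or TensorFlow Tensors. The hypothetical `batched_affine` handles every batch member in one vectorized call, and the check confirms it agrees with applying a per-member computation row by row.

```python
import numpy as np


def batched_affine(x, scale):
    """Hypothetical primitive: x has shape [batch, 3], scale has shape [batch]."""
    # One vectorized call covers the whole batch; scale gets a trailing axis
    # so that each scalar lines up with its row of x.
    return x * scale[:, np.newaxis] + 1.0


def per_member_affine(x_row, s):
    """The same computation for a single (unbatched) member."""
    return x_row * s + 1.0


x = np.arange(6.0).reshape(2, 3)  # batch of 2 members, each a length-3 vector
scale = np.array([10.0, -1.0])    # one scalar per batch member

batched = batched_affine(x, scale)
looped = np.stack([per_member_affine(x[i], scale[i]) for i in range(len(x))])
print(np.allclose(batched, looped))  # True
```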

The current values of the variables in vars_out are saved on their respective stacks, and the results are written to the new tops (except for variables listed in skip_push_mask, which are updated in place).
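The push-versus-in-place behavior can be modeled with a toy stack per variable. In this sketch, `ToyStackVar` and `run_prim_op` are hypothetical names, not the VM's implementation; plain Python numbers stand in for batched Tensors, and batching and thread masking are ignored entirely. It only illustrates that pushed variables keep their old value saved underneath, while variables in skip_push_mask have their top overwritten.

```python
class ToyStackVar:
    """Hypothetical model of one VM variable: a stack whose last entry is current."""

    def __init__(self, value):
        self.stack = [value]

    @property
    def top(self):
        return self.stack[-1]

    def push(self, value):
        # The old value stays saved underneath; the new value becomes the top.
        self.stack.append(value)

    def update(self, value):
        # In-place update: overwrite the top without saving the old value.
        self.stack[-1] = value

    def pop(self):
        # Discard the top, restoring the previously saved value.
        return self.stack.pop()


def run_prim_op(env, vars_in, vars_out, function, skip_push_mask=frozenset()):
    """Hypothetical sketch; vars_out is a flat list, function returns a tuple."""
    results = function(*[env[name].top for name in vars_in])
    for name, value in zip(vars_out, results):
        if name in skip_push_mask:
            env[name].update(value)
        else:
            env[name].push(value)


env = {"x": ToyStackVar(3), "y": ToyStackVar(0), "acc": ToyStackVar(100)}
run_prim_op(env, ["x"], ["y", "acc"], lambda x: (2 * x, x + 1),
            skip_push_mask={"acc"})
print(env["y"].stack, env["acc"].stack)  # [0, 6] [4]
```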

The exact contract for the function argument is as follows:

  • It will be invoked with positional arguments only, one per entry in vars_in, in the same order.
  • Each argument will be a pattern of Tensors (that is, either a single Tensor or a possibly nested list or tuple of Tensors), corresponding to the Type of that variable.
  • Each Tensor in the argument will have the dtype and shape given in the corresponding TensorType, and an additional leading batch dimension.
  • Some indices in the batch dimension may contain junk data if the corresponding threads are not executing this instruction [this is subject to change based on the batch execution strategy].
  • The function must return a pattern of Tensors, or objects convertible to Tensors.
  • The returned pattern must be compatible with the Types of vars_out.
  • The Tensors in the returned pattern must have dtype and shape compatible with the corresponding TensorTypes of vars_out.
  • The returned Tensors will be broadcast into their respective positions if necessary. The broadcasting includes the batch dimension: thus, a returned Tensor of insufficient rank (e.g., a constant) will be broadcast across batch members. In particular, a Tensor that carries the intended batch size but whose sub-batch shape has too low a rank will broadcast incorrectly and result in an error.
  • If the function raises an exception, it will propagate and abort the entire computation.
  • Even with the TensorFlow backend, the function will be staged several times: at least twice during type inference (to ascertain the shapes of the Tensors it returns, as a function of the shapes of the Tensors it is given), and exactly once during executable graph construction.
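The contract above can be illustrated with a small sketch. Here NumPy arrays stand in for backend Tensors, and `my_prim`, `broadcast_result`, `BATCH`, and the declared shapes are all hypothetical. `broadcast_result` assumes NumPy-style broadcasting (shapes aligned from the right); the VM's actual broadcasting code may differ. The example shows a tuple-pattern argument, a returned constant that is broadcast across batch members, and the pitfall of a Tensor that carries the batch size but lacks the sub-batch axis.

```python
import numpy as np

BATCH = 4  # hypothetical batch size


def my_prim(x, pair):
    """Hypothetical PrimOp function whose arguments are parallel to (x, pair).

    x:    a single Tensor of shape [BATCH, 3] (declared TensorType shape [3]).
    pair: a tuple pattern (a, b), each of shape [BATCH] (declared shape []).
    Returns the pattern (vector, flag) for vars_out with declared shapes [3], [].
    """
    a, b = pair
    vector = x * (a + b)[:, np.newaxis]  # full shape [BATCH, 3]
    flag = 0.0                           # constant; broadcast across the batch
    return vector, flag


def broadcast_result(value, batch_size, declared_shape):
    """Broadcast a returned value to [batch_size] + declared_shape.

    Assumes NumPy-style broadcasting (shapes aligned from the right).
    """
    target = (batch_size,) + tuple(declared_shape)
    return np.broadcast_to(np.asarray(value), target)


x = np.ones((BATCH, 3))
a, b = np.arange(BATCH, dtype=float), np.ones(BATCH)
vector, flag = my_prim(x, (a, b))
vector_out = broadcast_result(vector, BATCH, [3])  # shape (4, 3)
flag_out = broadcast_result(flag, BATCH, [])       # shape (4,), all zeros

# Pitfall: one value per batch member, but missing the trailing [3] axis.
# Right alignment pairs its size-4 axis with the size-3 axis, so it fails.
per_member = np.arange(BATCH, dtype=float)  # shape [BATCH]
try:
    broadcast_result(per_member, BATCH, [3])
    raised = False
except ValueError:
    raised = True
print(vector_out.shape, flag_out.shape, raised)  # (4, 3) (4,) True
```

Note that when the batch size happens to equal the trailing dimension, right-aligned broadcasting would succeed silently with the wrong alignment, which is why returning full-rank Tensors is the safer habit.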

Args

vars_in: List of strings. The names of the VM variables whose current values are passed to the function.
vars_out: Pattern of strings. The names of the VM variables in which to save the results returned by function.
function: Python callable implementing the computation.
skip_push_mask: Set of strings, a subset of vars_out. These VM variables are updated in place rather than pushed.

Attributes

vars_in: A namedtuple alias for field number 0.
vars_out: A namedtuple alias for field number 1.
function: A namedtuple alias for field number 2.
skip_push_mask: A namedtuple alias for field number 3.
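Because the four attributes are namedtuple field aliases, their layout can be shown with a stand-in built from `collections.namedtuple`. `ToyPrimOp` and the variable names below are hypothetical; the sketch only mirrors the documented field order, not the real class's behavior.

```python
import collections

# Hypothetical stand-in with PrimOp's documented field order (fields 0 to 3).
ToyPrimOp = collections.namedtuple(
    "ToyPrimOp", ["vars_in", "vars_out", "function", "skip_push_mask"])

op = ToyPrimOp(
    vars_in=["x", "y"],                    # names whose values are passed in
    vars_out=("total", "diff"),            # pattern of names receiving results
    function=lambda x, y: (x + y, x - y),  # the batched computation
    skip_push_mask={"diff"},               # subset of vars_out updated in place
)

# Each attribute is an alias for a positional field.
print(op[0] is op.vars_in, op[3] is op.skip_push_mask)  # True True
```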

Methods

replace


Return a copy of self with vars_out replaced.
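The effect of replace can be sketched on a namedtuple stand-in. `ToyPrimOp` is hypothetical, and taking the new vars_out as a single argument is an assumption about the real method's signature; the point is only that the copy shares every other field with the original, which is left unchanged.

```python
import collections


class ToyPrimOp(collections.namedtuple(
        "ToyPrimOp", ["vars_in", "vars_out", "function", "skip_push_mask"])):
    """Hypothetical stand-in for PrimOp, used only to illustrate replace."""

    def replace(self, vars_out):
        # Build a new tuple that differs from self only in vars_out.
        return self._replace(vars_out=vars_out)


op = ToyPrimOp(["x"], "y", lambda x: x + 1, set())
renamed = op.replace("z")
print(op.vars_out, renamed.vars_out)  # y z
```

Since namedtuples are immutable, returning a copy is the only way to "change" a field; the original instruction can keep being used elsewhere.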