
tfp.experimental.auto_batching.instructions.PrimOp


Class PrimOp

An arbitrary already-batched computation, a 'primitive operation'.


These are the items of work to which auto-batching is applied. The function must accept and produce Tensors with a batch dimension, and is free to stage any (batched) computation it wants. Restriction: the function must use the same computation substrate as the VM backend. That is, if the VM is staging to XLA, the function will see XLA Tensor handles; if the VM is staging to graph-mode TensorFlow, the function will see TensorFlow Tensors; and so on.
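The requirement that the function consume and produce batched values can be illustrated with a minimal sketch. It assumes plain Python lists stand in for batched Tensors, with the outer list index playing the role of the leading batch dimension. The function name `batched_affine` is hypothetical. A real PrimOp function would receive backend Tensors, such as graph-mode TensorFlow Tensors, and would normally use vectorized ops instead of a Python loop.

```python
# Hypothetical sketch: plain Python lists stand in for batched Tensors.
# The outer list index plays the role of the leading batch dimension.

def batched_affine(xs, scales):
    """Computes x * s for every batch member at once.

    Both arguments carry a leading batch dimension, and so does the result.
    """
    if len(xs) != len(scales):
        raise ValueError("batch sizes differ")
    return [x * s for x, s in zip(xs, scales)]


# A batch of three members, each holding a scalar.
print(batched_affine([1.0, 2.0, 3.0], [10.0, 10.0, 0.5]))  # → [10.0, 20.0, 1.5]
```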

The current values of the vars_out are saved on their respective stacks, and the results are written to the new top. Variables listed in skip_push_mask are instead updated in place, without a push.
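The push-versus-in-place behaviour can be modelled with a small simulation. This sketch assumes each VM variable is a Python list used as a stack, with the last element as the current value. It also assumes vars_out is a flat list of names. `run_prim_op` is a hypothetical helper and is not the TFP VM's implementation. Each stored value is itself a list of batch members, here with batch size 2.

```python
# Hypothetical model of PrimOp execution against per-variable stacks.
# Each VM variable maps to a Python list used as a stack; the last element
# is the current value (the "top"). Each value is a batch of 2 members.

def run_prim_op(stacks, vars_in, vars_out, function, skip_push_mask):
    # Pass the current (top) value of each input variable positionally,
    # in the same order as vars_in.
    args = [stacks[name][-1] for name in vars_in]
    results = function(*args)
    # Assumes vars_out is a flat list, so results is a parallel sequence.
    for name, value in zip(vars_out, results):
        if name in skip_push_mask:
            stacks[name][-1] = value    # overwrite the top in place, no push
        else:
            stacks[name].append(value)  # old value stays below the new top
    return stacks


stacks = {"x": [[1, 2]], "y": [[0, 0]], "count": [[5, 5]]}
run_prim_op(
    stacks,
    vars_in=["x", "count"],
    vars_out=["y", "count"],
    function=lambda x, c: ([2 * v for v in x], [v - 1 for v in c]),
    skip_push_mask={"count"},
)
print(stacks["y"])      # → [[0, 0], [2, 4]]
print(stacks["count"])  # → [[4, 4]]
```

Here `y` gains a new stack entry, while `count` keeps a single entry whose value was overwritten.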

The exact contract for function is as follows:

  • It will be invoked with a list of positional (only) arguments, parallel to vars_in.
  • Each argument will be a pattern of Tensors (meaning either one Tensor or a possibly nested list or tuple of Tensors), corresponding to the Type of that variable.
  • Each Tensor in the argument will have the dtype and shape given in the corresponding TensorType, plus an additional leading batch dimension.
  • Some indices in the batch dimension may contain junk data if the corresponding threads are not executing this instruction. (This is subject to change based on the batch execution strategy.)
  • The function must return a pattern of Tensors, or objects convertible to Tensors.
  • The returned pattern must be compatible with the Types of vars_out.
  • The Tensors in the returned pattern must have dtype and shape compatible with the corresponding TensorTypes of vars_out.
  • The returned Tensors will be broadcast into their respective positions if necessary. The broadcasting includes the batch dimension: a returned Tensor of insufficient rank (e.g., a constant) will be broadcast across batch members. In particular, a Tensor that carries the intended batch size but whose sub-batch shape has too low a rank will broadcast incorrectly and result in an error.
  • If the function raises an exception, it will propagate and abort the entire computation.
  • Even in the TensorFlow backend, the function will be staged several times: at least twice during type inference (to determine the shapes of the Tensors it returns as a function of the shapes of the Tensors it is given), and exactly once during executable graph construction.
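The broadcasting rule in the contract can be checked at the level of shapes. This sketch assumes NumPy/TensorFlow-style broadcasting, in which shapes are aligned from the right and each dimension pair must be equal or the returned dimension must be 1 or absent. The helper `broadcast_to_slot` is hypothetical. It computes the target shape `(batch_size,) + event_shape` of an output variable and either accepts the returned shape or raises.

```python
# Hypothetical shape-level model of broadcasting a returned value into the
# slot of an output variable whose full shape is (batch_size,) + event_shape.
# Assumes NumPy/TensorFlow-style rules: align shapes from the right; each
# returned dimension must equal the target dimension or be 1 (or missing).

def broadcast_to_slot(returned_shape, batch_size, event_shape):
    returned_shape = tuple(returned_shape)
    target = (batch_size,) + tuple(event_shape)
    if len(returned_shape) > len(target):
        raise ValueError(f"rank {len(returned_shape)} exceeds target rank {len(target)}")
    # Pad the returned shape with leading 1s, then compare dimension by dimension.
    padded = (1,) * (len(target) - len(returned_shape)) + returned_shape
    for got, want in zip(padded, target):
        if got not in (1, want):
            raise ValueError(f"cannot broadcast {returned_shape} to {target}")
    return target


# A scalar constant is broadcast across all batch members.
print(broadcast_to_slot((), batch_size=3, event_shape=(2,)))      # → (3, 2)
# A full batch of results passes through unchanged.
print(broadcast_to_slot((3, 2), batch_size=3, event_shape=(2,)))  # → (3, 2)
# Right batch size, but the sub-batch shape is missing a dimension: the 3 is
# aligned against the event dimension 2, which is an error.
try:
    broadcast_to_slot((3,), batch_size=3, event_shape=(2,))
except ValueError as e:
    print(e)  # → cannot broadcast (3,) to (3, 2)
```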

Args:

  • vars_in: list of strings. The names of the VM variables whose current values to pass to the function.
  • vars_out: Pattern of strings. The names of the VM variables in which to save the results returned from function.
  • function: Python callable implementing the computation.
  • skip_push_mask: Set of strings, a subset of vars_out. These VM variables will be updated in place rather than pushed.
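The argument constraints above can be made concrete with a small validator. This sketch assumes a "pattern of strings" means either one string or a possibly nested list or tuple of strings, mirroring the Tensor-pattern wording in the contract. The helpers `flatten_pattern` and `check_prim_op_args` are hypothetical. Whether TFP itself performs such checks is not stated here.

```python
# Hypothetical argument check for a PrimOp-like record; not part of TFP.

def flatten_pattern(pattern):
    """Flattens a string, or a (possibly nested) list/tuple of strings."""
    if isinstance(pattern, str):
        return [pattern]
    names = []
    for item in pattern:
        names.extend(flatten_pattern(item))
    return names


def check_prim_op_args(vars_in, vars_out, function, skip_push_mask):
    if not all(isinstance(name, str) for name in vars_in):
        raise TypeError("vars_in must be a list of strings")
    if not callable(function):
        raise TypeError("function must be callable")
    # skip_push_mask must be a subset of the names appearing in vars_out.
    extra = set(skip_push_mask) - set(flatten_pattern(vars_out))
    if extra:
        raise ValueError(f"skip_push_mask names not in vars_out: {sorted(extra)}")


# Nested vars_out pattern; "lo" is updated in place rather than pushed.
check_prim_op_args(["x"], ["mean", ("lo", "hi")], lambda x: (x, (x, x)), {"lo"})
```

Passing `skip_push_mask={"z"}` with `vars_out="y"` would raise a ValueError naming `'z'`.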

__new__

__new__(
    _cls,
    vars_in,
    vars_out,
    function,
    skip_push_mask
)

Create a new instance of PrimOp(vars_in, vars_out, function, skip_push_mask).

Properties

vars_in

vars_out

function

skip_push_mask

Methods

replace


replace(vars_out=None)

Return a copy of self with vars_out replaced.
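The `__new__(_cls, ...)` signature and the "Create a new instance of PrimOp(...)" docstring suggest that PrimOp is built on `collections.namedtuple`. The sketch below assumes this. It defines a hypothetical stand-in, `PrimOpSketch`, with the same four fields. Its `replace` is built on the namedtuple's `_replace` and returns a new record while leaving the original unchanged. This is an illustration, not TFP's source.

```python
import collections


# Hypothetical stand-in mirroring the documented fields; the real class is
# tfp.experimental.auto_batching.instructions.PrimOp.
class PrimOpSketch(collections.namedtuple(
        "PrimOpSketch", ["vars_in", "vars_out", "function", "skip_push_mask"])):
    __slots__ = ()

    def replace(self, vars_out=None):
        """Returns a copy of self with vars_out replaced."""
        return self._replace(vars_out=vars_out)


op = PrimOpSketch(vars_in=["n"], vars_out="n_minus_1",
                  function=lambda n: n - 1, skip_push_mask=set())
renamed = op.replace(vars_out="m")
print(op.vars_out, renamed.vars_out)  # → n_minus_1 m
print(renamed.function(5))            # → 4
```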