tfp.experimental.auto_batching.stack_optimization.fuse_pop_push

Fuses pop+push sequences in the given Program.

A stack pop followed by a stack push (with no intervening read) is equivalent to just updating the top of the stack. The latter is more efficient for FULL variables, because it just updates the cache for the top, and avoids gathering from and scattering to the backing stack Tensor.
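The equivalence can be illustrated with a toy model. This is a minimal sketch and not TFP's implementation: `CachedTopStack`, its methods, and the `backing_ops` counter are hypothetical names, and a plain Python list stands in for the backing stack Tensor.

```python
class CachedTopStack:
    """Toy stack that keeps its top element in a separate cache slot."""

    def __init__(self, initial):
        self.top = initial     # cached top of the stack
        self.backing = []      # stands in for the backing stack Tensor
        self.backing_ops = 0   # counts reads/writes of the backing store

    def push(self, value):
        # Spill the old top to the backing store ("scatter"), then cache the new top.
        self.backing.append(self.top)
        self.backing_ops += 1
        self.top = value

    def pop(self):
        # Restore the previous top from the backing store ("gather").
        self.top = self.backing.pop()
        self.backing_ops += 1

    def update(self, value):
        # Overwrite the cached top only; the backing store is untouched.
        self.top = value


# Two stacks in the same starting state: top = 1, backing = [0].
a = CachedTopStack(0)
a.push(1)
b = CachedTopStack(0)
b.push(1)

# Unfused: pop, then push, with no read in between.
a.pop()
a.push(2)

# Fused: a single update of the top.
b.update(2)
```

Both stacks end in the same state (top `2`, backing `[0]`), but in this toy model the unfused sequence touches the backing store twice more than the fused one does.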

This pass mutates the ControlFlowGraph of the input Program to convert pop+push sequences into updates. The pass will work despite intervening instructions that interact with other Variables, but will not cross basic block boundaries. As a side effect, the pass moves non-optimized pops to the last place in their basic block where they are still sound. Moving these pops has no effect on the runtime behavior of the program.
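The shape of such a pass can be sketched on a toy instruction format. The sketch rests on several assumptions: instructions are hypothetical tuples such as `("pop", var)`, `("push", var, value)`, and `("read", var)`, a program is a list of basic blocks, and a fused pair is written as a hypothetical `("update", var, value)` instruction. It does not use TFP's instruction classes, and it leaves out the relocation of non-optimized pops.

```python
def fuse_block(block):
    """Return a new instruction list with pop+push pairs on one variable fused."""
    result = list(block)
    i = 0
    while i < len(result):
        instr = result[i]
        if instr[0] == "pop":
            var = instr[1]
            # Skip instructions on other variables; stop at the next one
            # that touches the popped variable.
            j = i + 1
            while j < len(result) and result[j][1] != var:
                j += 1
            if j < len(result) and result[j][0] == "push":
                # Replace the push with an in-place update and drop the pop.
                result[j] = ("update", var, result[j][2])
                del result[i]
                continue  # re-examine whatever now sits at index i
        i += 1
    return result


def fuse_program(blocks):
    """Fuse each basic block independently; pairs never span two blocks."""
    return [fuse_block(block) for block in blocks]


block = [
    ("pop", "x"),
    ("read", "y"),       # touches a different variable, so it does not block fusion
    ("push", "x", 5),
    ("pop", "y"),
    ("read", "y"),       # intervening read of y, so this pop+push is left alone
    ("push", "y", 7),
]
fused_block = fuse_block(block)

# A pop and a push in different basic blocks are not fused.
split = [[("pop", "x")], [("push", "x", 1)]]
fused_split = fuse_program(split)
```

Here the pair on `x` becomes a single update, while the pair on `y` is kept because a read separates the pop from the push.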

program A lowered Program whose pop+push sequences are to be fused. Blocks in the program may be mutated.

fused A Program with statically redundant pop+push sequences eliminated in favor of PrimOps with non-trivial skip_push_mask fields.

ValueError If the input Program has not been lowered (i.e., contains FunctionCallOp), or is ill-formed (e.g., contains invalid instructions).