tfp.experimental.auto_batching.xla.compile_nested_output

Wraps f with tpu.rewrite or xla.compile and propagates the output structure.

tfp.experimental.auto_batching.xla.compile_nested_output(
    f,
    compile_fn=None
)

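xla.compile requires that f return a flat list of Tensors or Ops, although it tolerates nested input arguments. This wrapper captures the nested structure of f's output, hands a flattened version to the compile function, and then repacks the compiled results into the original structure.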

Args:

  • f: Callable to compile. May accept nested inputs and return nested outputs.
  • compile_fn: The function used to compile, e.g. xla.compile or tpu.rewrite. Accepts two arguments, f and inputs.

Returns:

  • g: Callable wrapping f which returns XLA-compiled, nested outputs.
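
The flatten-then-repack idea described above can be sketched in plain Python, without TensorFlow. This is only an illustration of the technique, not the library's implementation: flatten, pack_like, toy_compile_fn, and compile_nested_output_sketch are hypothetical names, and toy_compile_fn is a stand-in that merely imitates the "flat outputs only" restriction of xla.compile.

```python
def flatten(structure):
    """Return the leaves of a nested list/tuple/dict in a deterministic order."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def pack_like(template, leaves):
    """Rebuild a structure shaped like `template` from an iterator of leaves."""
    if isinstance(template, dict):
        return {key: pack_like(template[key], leaves)
                for key in sorted(template)}
    if isinstance(template, (list, tuple)):
        return type(template)(pack_like(item, leaves) for item in template)
    return next(leaves)


def toy_compile_fn(flat_f, inputs):
    """Hypothetical stand-in for a compiler that only accepts flat outputs."""
    outputs = flat_f(*inputs)
    if not isinstance(outputs, list) or any(
            isinstance(o, (list, tuple, dict)) for o in outputs):
        raise ValueError("compiled function must return a flat list")
    return outputs


def compile_nested_output_sketch(f, compile_fn):
    """Wrap f so that compile_fn sees flat outputs but callers get nested ones."""
    def g(*args):
        captured = []  # filled as a side effect when flat_f runs

        def flat_f(*inner_args):
            result = f(*inner_args)
            captured.append(result)   # remember the nested output structure
            return flatten(result)    # hand only a flat list to compile_fn

        flat_outputs = compile_fn(flat_f, list(args))
        # Repack the compiled flat outputs into the captured structure.
        return pack_like(captured[0], iter(flat_outputs))
    return g


def f(x, y):
    # Example function with a nested (dict / tuple / list) output.
    return {"sum": x + y, "pair": (x * y, [x - y])}


g = compile_nested_output_sketch(f, toy_compile_fn)
result = g(3, 4)
```

In this sketch the values stored in the captured structure are used only as a shape template; the leaves returned to the caller are the ones that came back from compile_fn. A real compile_fn such as xla.compile would return compiled Tensors in place of the plain numbers used here.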