tfp.experimental.auto_batching.xla.compile_nested_output

Wraps f with a tpu.rewrite or xla.compile and propagates f's output structure.

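xla.compile requires f to return a flat list of Tensors or Ops, although it tolerates nested input arguments. To work around this, the wrapper captures the structure of f's output so that the compiled, flat results can be returned in that same nested structure.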

Args:
  f: Callable to compile. It may accept nested inputs and return nested outputs.
  compile_fn: The function used to compile, i.e. xla.compile or tpu.rewrite. It accepts two args, f and inputs.

Returns:
  g: Callable wrapping f that returns XLA-compiled, nested outputs.