Context object for auto-batching multiple Python functions together.
tfp.experimental.auto_batching.Context()
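Usage:

```python
import tensorflow_probability as tfp

ctx = tfp.experimental.auto_batching.Context()

@ctx.batch(type_inference=lambda arg_types: ...)  # elided: returns a list of result types
def my_single_example_function_1(args):
    ...

@ctx.batch(type_inference=lambda arg_types: ...)
def my_single_example_function_2(args):
    ...

# etc.
```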
Then calling any of the decorated functions will execute a batch computation.
The decorated functions may call each other, including mutually recursively.
See also the `batch` method.
Limitations:
- You must explicitly decorate every function to be auto-batched.
- All calls to them must call them by name (no higher-order auto-batching).
- Auto-batched functions must be defined with `def`, not `lambda`.
Methods
batch
batch(type_inference)
Decorates one function to auto-batch.
The decorated function will run in batch. It accepts all the same arguments, except:
- All arguments must have an additional leading dimension for the batch. (By special dispensation, scalar inputs are promoted to shape [1], which then leads to broadcasting.)
- All the arguments' sizes in the batch dimension must be the same, or 1. The latter are broadcast.
- The returned value will also have a leading batch dimension of the same size.
- The batched function accepts an additional `bool` keyword argument `dry_run`. If present and `True`, it just calls the unbatched version, circumventing the auto-batching system. This can be useful for debugging the program subject to auto-batching.
- The batched function accepts an additional `bool` keyword argument `stackless`. If present and `True`, it invokes the stackless version of the auto-batching system. This can be useful for avoiding stack-maintenance overhead, but in general it recovers less batching and does not work in graph-mode TensorFlow.
- The batched function accepts an additional `int` keyword argument `max_stack_depth` specifying the maximum stack depth (default 15). It is ignored in stackless execution.
- The batched function accepts an additional keyword argument `backend` specifying the backend to use. It must be an instance of `auto_batching.TensorFlowBackend` (the default) or `auto_batching.NumpyBackend`.
- The batched function accepts an additional keyword argument `block_code_cache`, a dict that allows the caching of basic-block rewrites (i.e., `tf.function` + XLA) to persist across calls to the auto-batched function. The default value of `None` results in caching only within a given call to the batched function. Currently, stackless auto-batching ignores the cache completely.
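The batch-dimension rules above can be illustrated without TensorFlow Probability. This is a minimal pure-Python sketch, not the library's implementation: `broadcast_batch_args` and `run_batched` are hypothetical names, a Python list stands in for a tensor's leading batch dimension, and the single-example function is simply mapped over the batch members rather than auto-batched.

```python
from typing import Any, Callable, List


def broadcast_batch_args(*args: Any) -> List[List[Any]]:
    """Hypothetical helper mirroring the documented batch-dimension rules.

    Each argument is a list whose length is its batch size; a bare scalar
    is treated as a batch of size 1 (the "shape [1]" promotion).
    """
    # Promote scalars (anything that is not a list) to a batch of size 1.
    batched = [a if isinstance(a, list) else [a] for a in args]
    sizes = {len(b) for b in batched}
    # All batch sizes must agree, except that size-1 batches broadcast.
    non_unit = sizes - {1}
    if len(non_unit) > 1:
        raise ValueError(f"Incompatible batch sizes: {sorted(sizes)}")
    batch_size = non_unit.pop() if non_unit else 1
    # Repeat each size-1 batch so every argument has `batch_size` members.
    return [b * batch_size if len(b) == 1 else b for b in batched]


def run_batched(fn: Callable[..., Any], *args: Any) -> List[Any]:
    """Applies a single-example `fn` to each batch member (no real batching)."""
    columns = broadcast_batch_args(*args)
    # zip(*columns) yields one tuple of arguments per batch member, so the
    # result has the same batch size as the (broadcast) inputs.
    return [fn(*member) for member in zip(*columns)]


print(run_batched(lambda x, y: x + y, [1, 2, 3], 10))  # [11, 12, 13]
print(run_batched(lambda x: 2 * x, 5))  # [10]
```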
Args | |
---|---|
`type_inference` | A Python callable giving the type signature of the function being auto-batched. The callable will be invoked with a single argument giving the list of `instructions.Type` objects describing the arguments at a particular call site, and must return a list of `instructions.Type` objects describing the values that call site will return. |
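A `type_inference` callable maps the argument types seen at one call site to a list of result types. The sketch below uses a hypothetical `FakeTensorType` stand-in (a dtype plus a per-example shape) in place of the real `instructions.Type` objects, whose structure it does not reproduce; `same_as_first_arg` is an illustrative callable for a function that returns a single value with the type of its first argument.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class FakeTensorType:
    """Hypothetical stand-in for an `instructions.Type` entry (not the real class)."""
    dtype: str
    shape: Tuple[int, ...]  # per-example shape, excluding the batch dimension


def same_as_first_arg(arg_types: List[FakeTensorType]) -> List[FakeTensorType]:
    # Return one entry per returned value: here a single output whose type
    # matches the type of the first argument at this call site.
    return [arg_types[0]]


call_site = [FakeTensorType("int32", ()), FakeTensorType("float32", (3,))]
print(same_as_first_arg(call_site))  # [FakeTensorType(dtype='int32', shape=())]
```

With the real library, a callable of this shape (operating on real `instructions.Type` objects) is what gets passed as `ctx.batch(type_inference=...)`.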
Returns | |
---|---|
`dec` | A decorator that may be applied to a function with the given type signature to auto-batch it. |
Raises | |
---|---|
`ValueError` | If the decorated function predictably cannot be auto-batched, e.g., name-clashing with another function already decorated in this `Context`. |
batch_uncurried
batch_uncurried(function, type_inference)
A non-decorator version of `batch`; see that method for details.
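The relationship between `batch` and `batch_uncurried` is ordinary currying. The toy registry below is hypothetical: `MiniContext` only records functions and does no batching, and its `function_names` is a toy analog. It assumes, for illustration, that the decorator form delegates to the non-decorator form and that reusing a function name raises `ValueError`, as the Raises entry for `batch` describes.

```python
from typing import Callable, Dict, List, Tuple


class MiniContext:
    """Hypothetical toy registry; it records functions but does not auto-batch them."""

    def __init__(self) -> None:
        self._functions: Dict[str, Tuple[Callable, Callable]] = {}

    def batch_uncurried(self, function: Callable, type_inference: Callable) -> Callable:
        name = function.__name__
        # Names must be unique within one context.
        if name in self._functions:
            raise ValueError(f"Function {name!r} is already decorated in this context")
        self._functions[name] = (function, type_inference)
        return function  # the real method returns a batched wrapper instead

    def batch(self, type_inference: Callable) -> Callable[[Callable], Callable]:
        # The decorator form just curries the non-decorator form.
        return lambda function: self.batch_uncurried(function, type_inference)

    def function_names(self) -> List[str]:
        return list(self._functions)


ctx = MiniContext()


@ctx.batch(type_inference=lambda arg_types: [arg_types[0]])
def double(x):
    return 2 * x


def negate(x):
    return -x


ctx.batch_uncurried(negate, lambda arg_types: [arg_types[0]])
print(ctx.function_names())  # ['double', 'negate']
```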
function_names
function_names()
lowered_for_args
lowered_for_args(name, args, backend)
Helper for calling program_lowered that computes the type signature.
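A hypothetical sketch of what such a helper does, assuming for illustration that a type signature records each argument's dtype and per-example shape (its shape without the leading batch dimension). `FakeArray`, `signature_of`, `lowered_for_args_sketch`, and `fake_program_lowered` are illustrative names, not part of `auto_batching`.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for a batched tensor: only dtype and shape matter."""
    dtype: str
    shape: Tuple[int, ...]  # shape[0] is the batch dimension


def signature_of(args: List[FakeArray]) -> List[Tuple[str, Tuple[int, ...]]]:
    # Drop the leading batch dimension: the signature describes one example.
    return [(a.dtype, a.shape[1:]) for a in args]


def lowered_for_args_sketch(program_lowered: Callable[..., Any], name: str,
                            args: List[FakeArray], backend: Any) -> Any:
    # Compute the signature from concrete arguments, then defer to lowering.
    return program_lowered(name, sig=signature_of(args), backend=backend)


def fake_program_lowered(main, sig=None, backend=None):
    # Stand-in that just echoes what it was asked to lower.
    return (main, sig, backend)


args = [FakeArray("float32", (8, 3)), FakeArray("int32", (8,))]
print(lowered_for_args_sketch(fake_program_lowered, "my_fn", args, "numpy"))
# ('my_fn', [('float32', (3,)), ('int32', ())], 'numpy')
```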
module
module()
Constructs an `instructions.Module` for this `Context`.
Returns | |
---|---|
`module` | An `instructions.Module` representing the batched computation defined by all the functions decorated with `batch` in this `Context` so far. |
program
program(main)
Constructs an `instructions.Program` for this `Context`.
This is a helper method, equivalent to `self.module().program(main)`.
Args | |
---|---|
`main` | Python string name of the function that should be the entry point. |
Returns | |
---|---|
`prog` | An `instructions.Program` representing the batched computation defined by all the functions decorated with `batch` in this `Context` so far. Suitable for downstream compilation with other passes in `auto_batching`. |
Raises | |
---|---|
`ValueError` | If the intended main function was not decorated with `batch`. |
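The delegation to `self.module().program(main)` and the `ValueError` for an undecorated entry point can be mimicked with a toy. `ToyModule`, `ToyContext`, and `register` are hypothetical; the toy "program" is only a list of function names starting with the entry point, and `module()` takes a snapshot of what has been registered so far.

```python
from typing import Callable, Dict, List


class ToyModule:
    """Hypothetical stand-in for `instructions.Module`: a bag of named functions."""

    def __init__(self, functions: Dict[str, Callable]) -> None:
        self.functions = dict(functions)  # copy: a snapshot "so far"

    def program(self, main: str) -> List[str]:
        # Refuse an entry point that was never registered.
        if main not in self.functions:
            raise ValueError(f"{main!r} was not decorated with batch")
        # Toy "program": the entry point first, then every other function.
        return [main] + sorted(n for n in self.functions if n != main)


class ToyContext:
    """Hypothetical: `program(main)` is just `module().program(main)`."""

    def __init__(self) -> None:
        self._functions: Dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> None:
        self._functions[name] = fn

    def module(self) -> ToyModule:
        return ToyModule(self._functions)

    def program(self, main: str) -> List[str]:
        return self.module().program(main)


ctx = ToyContext()
ctx.register("is_even", lambda n: n % 2 == 0)
ctx.register("is_odd", lambda n: n % 2 == 1)
print(ctx.program("is_odd"))  # ['is_odd', 'is_even']
```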
program_compiled
program_compiled(main, sig=None, backend=None)
Constructs a compiled `instructions.Program` for this `Context`.
This constructs the program with `self.program(main)`, and then performs type inference and optimization, to emit a result that can be executed by the stackless auto-batching VM.
The point of having this as a method in its own right is that it caches the compilation on the types of the arguments.
If either `sig` or `backend` is omitted or `None`, type inference is skipped. The result is not executable, but it can be enlightening to inspect.
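Caching "on the types of the arguments" amounts to memoizing on the signature. This is a minimal sketch under stated assumptions, not the library's cache: `CompileCache` is hypothetical, its key is `(main, sig, backend)`, and `sig` is passed as a tuple so that it is hashable (the real `sig` is documented as a list).

```python
from typing import Any, Callable, Dict, Hashable, Optional, Tuple


class CompileCache:
    """Hypothetical memoizer for compiled programs, keyed on (main, sig, backend)."""

    def __init__(self, compile_fn: Callable[..., Any]) -> None:
        self._compile_fn = compile_fn
        self._cache: Dict[Tuple[Any, ...], Any] = {}
        self.misses = 0  # number of times compile_fn actually ran

    def get(self, main: str, sig: Optional[tuple] = None,
            backend: Optional[Hashable] = None) -> Any:
        key = (main, sig, backend)
        if key not in self._cache:
            self.misses += 1
            # Type inference only runs when both sig and backend are given.
            infer = sig is not None and backend is not None
            self._cache[key] = self._compile_fn(main, sig, backend, infer)
        return self._cache[key]


cache = CompileCache(lambda main, sig, backend, infer: (main, sig, infer))
sig = (("int32", ()),)  # tuples, so the signature is hashable
first = cache.get("fib", sig=sig, backend="tf")
again = cache.get("fib", sig=sig, backend="tf")  # same key: served from the cache
untyped = cache.get("fib")                       # different key: compiled again
print(cache.misses)  # 2
```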
Args | |
---|---|
`main` | Python string name of the function that should be the entry point. |
`sig` | A list of (patterns of) `instructions.TensorType` aligned with the formal parameters to `main`. |
`backend` | Backend implementation. |
Returns | |
---|---|
`prog` | An `instructions.Program` representing the batched computation defined by all the functions decorated with `batch` in this `Context` so far. Suitable for execution or staging on real data by the auto-batching VM. |
program_lowered
program_lowered(main, sig=None, backend=None)
Constructs a lowered `instructions.Program` for this `Context`.
This constructs the program with `self.program(main)`, and then performs type inference, optimization, and lowering, to emit a result that can be executed (or staged) by the auto-batching VM.
The point of having this as a method in its own right is that it caches the compilation on the types of the arguments.
If either `sig` or `backend` is omitted or `None`, type inference is skipped. The result is not executable, but it can be enlightening to inspect.
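The stage ordering described for this method (construct, optionally infer types, optimize, lower) can be sketched as a plain pass pipeline. `run_pipeline` and the toy passes are hypothetical and merely append their names, so the output shows which stages ran; type inference is skipped when `sig` or `backend` is `None`, as documented.

```python
from typing import Any, Callable, List, Optional

Program = List[str]  # toy "program": a record of the stages applied so far


def run_pipeline(program: Program, sig: Optional[list], backend: Optional[Any],
                 infer_types: Callable[[Program, list, Any], Program],
                 passes: List[Callable[[Program], Program]]) -> Program:
    """Hypothetical driver: optional type inference, then each pass in order."""
    if sig is not None and backend is not None:
        program = infer_types(program, sig, backend)
    for run_pass in passes:
        program = run_pass(program)
    return program


def infer(prog: Program, sig: list, backend: Any) -> Program:
    return prog + ["typed"]


def optimize(prog: Program) -> Program:
    return prog + ["optimized"]


def lower(prog: Program) -> Program:
    return prog + ["lowered"]


print(run_pipeline(["program"], [("int32", ())], "tf", infer, [optimize, lower]))
# ['program', 'typed', 'optimized', 'lowered']
print(run_pipeline(["program"], None, None, infer, [optimize, lower]))
# ['program', 'optimized', 'lowered']
```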
Args | |
---|---|
`main` | Python string name of the function that should be the entry point. |
`sig` | A list of (patterns of) `instructions.TensorType` aligned with the formal parameters to `main`. |
`backend` | Backend implementation. |
Returns | |
---|---|
`prog` | An `instructions.Program` representing the batched computation defined by all the functions decorated with `batch` in this `Context` so far. Suitable for execution or staging on real data by the auto-batching VM. |