tfp.experimental.auto_batching.dsl.ProgramBuilder

Class ProgramBuilder

An auto-batching DSL context.

Auto-batching DSL operations are methods on the ProgramBuilder object. It's used like this:

ab = dsl.ProgramBuilder()

def fib_type(arg_types):
  return arg_types[0]

with ab.function(type_inference=fib_type) as fibonacci:
  n = ab.param('n')
  ab.var.cond = ab.primop(lambda n: n > 1)
  with ab.if_(ab.var.cond):
    ab.var.nm1 = ab.primop(lambda n: n - 1)
    ab.var.fibm1 = ab.call(fibonacci, [ab.var.nm1])
    ab.var.nm2 = ab.primop(lambda n: n - 2)
    ab.var.fibm2 = ab.call(fibonacci, [ab.var.nm2])
    ab.var.ans = ab.primop(lambda fibm1, fibm2: fibm1 + fibm2)
  with ab.else_():
    ab.var.ans = ab.const(1)
  ab.return_(ab.var.ans)

prog = ab.program(main=fibonacci)
# Now `prog` is a well-formed `instructions.Program`, and the context
# `ab` is no longer needed.

Note that the sequence of method calls on ProgramBuilder corresponds to the source code of the Program being defined, not to its runtime behavior. This is because (a) functions are defined with a context manager (rather than a Python function), whose body executes immediately and exactly once; and (b) function call instructions (and primitive operations) are only recorded, not entered recursively.
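
The difference between recording and running can be illustrated with a toy recorder. This is a minimal sketch in plain Python, not the TFP implementation: RecordingBuilder and its methods are hypothetical names invented for this illustration, and the real ProgramBuilder records much richer instruction objects.

```python
import contextlib


class RecordingBuilder:
    """Hypothetical stand-in that only records instructions, never runs them."""

    def __init__(self):
        self.instructions = []

    @contextlib.contextmanager
    def function(self, name):
        # The body of the `with` block runs immediately, exactly once.
        self.instructions.append(("begin", name))
        yield name
        self.instructions.append(("end", name))

    def call(self, function, args):
        # A call is stored as data; the callee's body is not entered.
        self.instructions.append(("call", function, tuple(args)))


rb = RecordingBuilder()
with rb.function("fib") as fib:
    rb.call(fib, ["nm1"])  # recursive call: recorded, not executed
    rb.call(fib, ["nm2"])

for instruction in rb.instructions:
    print(instruction)
```

Even though the function "calls itself" twice, the recorded list stays finite, because each call only appends one entry.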

__init__

__init__()

Creates an empty ProgramBuilder.

Properties

var

Auto-batching variables visible in the current scope.

Overrides setattr and getattr to provide a smooth interface to reading and defining variables:

  • ProgramBuilder.var.foo = ProgramBuilder.{call,primop} records an assignment to the auto-batched variable foo, possibly binding it, and

  • ProgramBuilder.var.foo reads from the auto-batched variable foo (if it is bound).

Example:

ab = dsl.ProgramBuilder()

ab.var.seven = ab.const(7)

Returns:

  • vars: A _MagicVars instance representing the local scope as above.
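
The attribute interception described above can be sketched with a small class. This is an illustration under stated assumptions, not the actual _MagicVars code: ToyVars is a hypothetical name, and the choices that a read returns the variable's name and that an unbound read raises ValueError are assumptions made for this sketch.

```python
class ToyVars:
    """Hypothetical scope object: attribute writes are logged as assignments."""

    def __init__(self):
        # Bypass our own __setattr__ while setting up internal state.
        object.__setattr__(self, "_bound", {})
        object.__setattr__(self, "_log", [])

    def __setattr__(self, name, op):
        # `scope.foo = op` records an assignment and binds `foo`.
        self._log.append((name, op))
        self._bound[name] = op

    def __getattr__(self, name):
        # Reached only when normal lookup fails, i.e. for DSL variable names.
        bound = object.__getattribute__(self, "_bound")
        if name not in bound:
            raise ValueError(f"variable {name!r} is not bound")
        return name


scope = ToyVars()
scope.seven = "const(7)"  # stands in for an operation object
print(scope.seven)
```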

Methods

__call__

__call__(pattern)

Prepares a multi-value return.

Example:

ab = dsl.ProgramBuilder()

ab((ab.var.two, ab.var.four)).pattern = ab.const((2, 4))

The protocol is to create a magic pattern object by invoking the ProgramBuilder as a callable, passing the pattern to bind, and then to assign to the pattern attribute of the returned object the operation whose values it should accept.

The API takes this form to work around a limitation of embedding a DSL in Python: the assignment syntax = can be overridden only for fields of objects, not for function calls. It would have been nicer to write ab.pattern(...) = ..., but that is syntactically invalid Python. Hence the pattern token comes at the end of the phrase rather than at the beginning.
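
The workaround can be sketched in plain Python as follows. PatternBuilder and ToyPattern are hypothetical names for this illustration only; the real _MagicPattern is assumed to do more validation than this.

```python
class ToyPattern:
    """Hypothetical object returned by calling the builder with a pattern."""

    def __init__(self, names, log):
        # Bypass __setattr__ below while storing internal state.
        object.__setattr__(self, "_names", names)
        object.__setattr__(self, "_log", log)

    def __setattr__(self, attr, op):
        # Only `.pattern = op` is meaningful; it records the binding.
        if attr != "pattern":
            raise AttributeError(attr)
        self._log.append((self._names, op))


class PatternBuilder:
    def __init__(self):
        self.log = []

    def __call__(self, names):
        return ToyPattern(names, self.log)


pb = PatternBuilder()
# `pb(...) = ...` would be a SyntaxError, but assigning to an attribute
# of the call's result is valid Python.
pb(("two", "four")).pattern = "const((2, 4))"
print(pb.log)
```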

Args:

  • pattern: A pattern of variables (e.g., from ab.var.name) to bind.

Returns:

  • pat_object: A _MagicPattern instance representing the putative binding. Invoke the pattern = attribute setter on that instance to actually bind this pattern as the output of a primop, const, or call.

call

call(
    function,
    vars_in,
    vars_out=None
)

Registers a function call instruction.

Example:

ab = dsl.ProgramBuilder()

# Define a function
with ab.function(...) as func:
  ...
  # Call it (recursively)
  ab.var.thing = ab.call(func, ...)
  ...

Args:

  • function: The instructions.Function object representing the function to call.
  • vars_in: Python strings giving the variables to pass in as inputs.
  • vars_out: A pattern of Python strings, giving the auto-batched variable(s) to which to write the result of the call. Defaults to the empty list.

Raises:

  • ValueError: If the call references undefined auto-batched variables.

Returns:

  • op: An instructions.FunctionCallOp representing the call. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local gets added to the list of output variables.

const

const(
    value,
    vars_out=None
)

Records a constant or set of constants.

Like primop, the output variables can be specified explicitly via the vars_out argument or implicitly by assigning the return value to some ProgramBuilder.var.foo.

Args:

  • value: A Python list of the constants to record.
  • vars_out: A pattern of Python strings, giving the auto-batched variable(s) to which to write the constant(s). Defaults to the empty list.

Returns:

  • op: An instructions.PrimOp instance representing this operation. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local gets added to the list of output variables.

declare_function

declare_function(
    name=None,
    type_inference=None
)

Forward-declares a function to be defined later with define_function.

This is useful for defining mutually recursive functions:

ab = dsl.ProgramBuilder()

foo = ab.declare_function(...)

with ab.function(...) as bar:
  ...
  ab.call(foo, ...)

with ab.define_function(foo):
  ...
  ab.call(bar, ...)

It is an error to call but never define a declared function.
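
The declare-then-define pattern, and the error for a function that is called but never defined, can be mimicked with a toy model. Everything here (DeclBuilder, ToyFunction, check) is hypothetical and only illustrates the idea; it makes no claim about when or how the real ProgramBuilder detects the error.

```python
import contextlib


class ToyFunction:
    def __init__(self, name):
        self.name = name
        self.body = None  # None until a definition is registered


class DeclBuilder:
    def __init__(self):
        self.called = []
        self._current = None

    def declare_function(self, name):
        return ToyFunction(name)

    @contextlib.contextmanager
    def define_function(self, function):
        function.body = []
        self._current = function
        yield function
        self._current = None

    def call(self, function):
        # Only a reference to the callee is recorded, so the callee
        # does not need a body yet.
        self.called.append(function)
        self._current.body.append(("call", function.name))

    def check(self):
        for function in self.called:
            if function.body is None:
                raise ValueError(f"{function.name} was called but never defined")


db = DeclBuilder()
foo = db.declare_function("foo")
bar = db.declare_function("bar")
with db.define_function(bar):
    db.call(foo)  # foo is declared but not yet defined
with db.define_function(foo):
    db.call(bar)
db.check()  # passes: both functions now have bodies
```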

Args:

  • name: Optional string naming this function when the program is printed.
  • type_inference: A Python callable giving the type signature of the function being defined. See function.

Returns:

  • function: An instructions.Function object representing the function being declared. It can be passed to call to call it, and to define_function to define it.

define_function

define_function(
    *args,
    **kwds
)

Registers a definition for a previously declared function.

Usually, one would use the function method to declare and define a function at the same time. Explicit use of define_function is only useful for mutual recursion or controlling code order separately from the call graph.

Example:

ab = dsl.ProgramBuilder()

foo = ab.declare_function(...)

with ab.function(...) as bar:
  ...
  ab.call(foo, ...)

with ab.define_function(foo):
  ...
  ab.call(bar, ...)

Function bodies appear in the compiled instructions.Program in order of definition, not declaration.

Note:

  • The formal parameters are given by calling ab.param inside the with block.
  • The body of the with block registers the body of the function being defined.
  • The last statement registered in the with block must be a call to ab.return_, or the Function will be malformed.

Args:

  • function: The function (from declare_function) to define.

Yields:

  • function: The function being defined, by symmetry with the context.function method.

Raises:

  • ValueError: If invoked while defining a function, if the function argument has already been defined, or if the function definition does not end in a return_.

else_

else_(
    *args,
    **kwds
)

Records the false branch of a conditional operation.

The true branch must be recorded (by if_) as the immediately preceding operation at the same nesting depth.

Example:

ab = dsl.ProgramBuilder()

ab.var.false = ab.const(False)
with ab.if_(ab.var.false):
  ...
with ab.else_():
  ...  # The body of the `with` statement gives the `false` branch

Args:

  • else_name: Optional Python string naming the false branch when the program is printed. Overrides the else_name, if any, given in the corresponding if_.
  • continue_name: Optional Python string naming the continuation after the if when the program is printed. Overrides the continue_name, if any, given in the corresponding if_.

Raises:

  • ValueError: If not immediately preceded by an if_.

Yields:

Nothing.
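
The "immediately preceding operation at the same nesting depth" rule can be sketched by keeping one list of recorded operations per open block. BranchBuilder is a hypothetical toy for illustration, not the way ProgramBuilder tracks branches internally.

```python
import contextlib


class BranchBuilder:
    def __init__(self):
        # One list of recorded operations per open nesting level.
        self._levels = [[]]

    def op(self, name):
        self._levels[-1].append(name)

    @contextlib.contextmanager
    def _block(self, kind):
        self._levels.append([])
        yield
        body = self._levels.pop()
        self._levels[-1].append((kind, body))

    def if_(self):
        return self._block("if")

    def else_(self):
        # Inspect the last operation recorded at the current depth.
        current = self._levels[-1]
        last = current[-1] if current else None
        if not (isinstance(last, tuple) and last[0] == "if"):
            raise ValueError("else_ must immediately follow an if_")
        return self._block("else")


bb = BranchBuilder()
with bb.if_():
    bb.op("then")
with bb.else_():
    bb.op("otherwise")
```

Recording any other operation between the two blocks would make the else_ call raise, because the last entry at that depth would no longer be an if.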

function

function(
    *args,
    **kwds
)

Registers a definition of an auto-batchable function.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as f:
  ab.param('n')
  ...
  ab.return_(...)

Note:

  • The as clause (here f) binds an instructions.Function object representing the function being defined (see Yields).
  • The formal parameters are given by calling param inside the with block.
  • The body of the with block registers the body of the function being defined.
  • The last statement registered in the with block must be a call to return_, or the Function will be malformed.

The function method is a shorthand of declare_function followed by define_function. The example is equivalent to:

ab = dsl.ProgramBuilder()

f = ab.declare_function(...)
with ab.define_function(f):
  ab.param('n')
  ...
  ab.return_(...)

Args:

  • name: Optional string naming this function when the program is printed.
  • type_inference: A Python callable giving the type signature of the function being defined. The callable will be invoked with a single argument giving the list of instructions.Type objects describing the arguments at a particular call site, and must return a list of instructions.Type objects describing the values that call site will return.

Raises:

  • ValueError: If invoked while defining a function, or if the function definition does not end in a return_.

Yields:

  • function: An instructions.Function object representing the function being defined. It can be passed to call to call it (including recursively). Note that Python scopes the name bound by the as clause to the definition enclosing the with statement, so a function bound this way can be referred to after its body as well.

if_

if_(
    *args,
    **kwds
)

Records a conditional operation and its true branch.

The false branch, if present, must be guarded by a call to else_.

Example:

ab = dsl.ProgramBuilder()

ab.var.true = ab.const(True)
with ab.if_(ab.var.true):
  ...  # The body of the `with` statement gives the `true` branch
with ab.else_():  # The else_ clause is optional
  ...

Args:

  • condition: Python string giving the boolean variable that holds the branch condition.
  • then_name: Optional Python string naming the true branch when the program is printed.
  • else_name: Optional Python string naming the false branch when the program is printed.
  • continue_name: Optional Python string naming the continuation after the if when the program is printed.

Yields:

Nothing.

Raises:

  • ValueError: If trying to condition on a variable that has not been written to.

local

local(
    name=None,
    define=True
)

Declares a local variable in the current scope.

This should typically not be needed, because ProgramBuilder.var.foo = can bind variables; however, it may be helpful for a multivalue return (see primop or call).

Args:

  • name: Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued. This variable can later be referred to with ProgramBuilder.var.name, as well as through any Python binding of the returned value.
  • define: Boolean giving whether to mark this variable defined on creation. Default True. Setting False is useful for speculatively uniquing a variable on its first appearance, before knowing whether said appearance is a write (in which case the variable becomes defined) or a read (which raises an error).

Returns:

  • var: A Python string representing this variable. Suitable for passing to primop, call, if_, and return_.
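
Name uniquing and the define flag can be illustrated with a toy name supply. NameSupply, the "var" fallback name, and the name_N numbering scheme are all assumptions made for this sketch; the names that ProgramBuilder actually generates may look different.

```python
import collections


class NameSupply:
    """Hypothetical helper that hands out unique variable names."""

    def __init__(self):
        self._counts = collections.Counter()
        self.defined = set()

    def local(self, name=None, define=True):
        base = name if name is not None else "var"
        self._counts[base] += 1
        unique = f"{base}_{self._counts[base]}"
        if define:
            # Mark the variable as written, so later reads are legal.
            self.defined.add(unique)
        return unique


supply = NameSupply()
first = supply.local("x")
second = supply.local("x")  # same mnemonic, different variable
pending = supply.local("y", define=False)  # uniqued, but not yet defined
```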

locals_

locals_(
    count,
    name=None
)

Declares several variables at once.

This is a convenience method standing for several invocations of local.

Args:

  • count: Python int. The number of distinct variables to return.
  • name: Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued.

Returns:

  • vars: A list of count Python strings representing these variables. Suitable for passing to primop, call, if_, and return_.

module

module()

Returns the registered function definitions as an instructions.Module.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as foo:
  ...  # Do stuff

module = ab.module()

Raises:

  • ValueError: If invoked inside a function definition.

Returns:

  • module: The instructions.Module corresponding to all the definitions accumulated in this ProgramBuilder.

param

param(name=None)

Declares a parameter of the function currently being defined.

This makes a local variable like local, but also makes it an input of the nearest enclosing function (created by with ProgramBuilder.function()). This is a separate method from function because the DSL wants to create Python bindings for the function name itself and all of its input parameters, and there is no way to convince the with syntax to do that.

Args:

  • name: Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued.

Returns:

  • var: A Python string representing this variable. Suitable for passing to primop, call, if_, and return_.

primop

primop(
    f,
    vars_in=None,
    vars_out=None
)

Records a primitive operation.

Example:

ab = dsl.ProgramBuilder()

ab.var.five = ab.const(5)
# Implicit output binding
ab.var.ten = ab.primop(lambda five: five + five)
# Explicit output binding
ab.primop(lambda: (5, 10), vars_out=[ab.var.five, ab.var.ten])

Args:

  • f: A Python callable, the primitive operation to perform. Can be an inline lambda expression in simple cases. Must return a list or tuple of results, one for each intended output variable.
  • vars_in: A list of Python strings, giving the auto-batched variables to pass into the callable when invoking it. If absent, primop will try to infer it by inspecting the argument list of the callable and matching against variables bound in the local scope.
  • vars_out: A pattern of Python strings, giving the auto-batched variable(s) to which to write the result of the callable. Defaults to the empty list.

Raises:

  • ValueError: If the definition is invalid, if the primop references undefined auto-batched variables, or if auto-detection of input variables fails.

Returns:

  • op: An instructions.PrimOp instance representing this operation. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local becomes the output pattern.
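
The inference of vars_in from a callable's argument list can be approximated with the standard inspect module. This is a sketch of the idea only: infer_vars_in is a hypothetical helper, and whether primop itself uses inspect.signature or another mechanism is not stated here.

```python
import inspect


def infer_vars_in(f, bound):
    """Match the parameter names of `f` against the bound variable names."""
    names = list(inspect.signature(f).parameters)
    missing = [name for name in names if name not in bound]
    if missing:
        raise ValueError(f"undefined variables: {missing}")
    return names


bound = {"five", "ten"}
inferred = infer_vars_in(lambda five, ten: five + ten, bound)
print(inferred)
```

This is why the lambdas in the examples name their parameters after existing auto-batched variables: the parameter names are the only link between the callable and its inputs.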

program

program(main)

Returns the registered program as an instructions.Program.

This is a helper method, equivalent to self.module().program(main).

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as main:
  ...  # Do the stuff

program = ab.program(main)

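Args:

  • main: The instructions.Function object (as yielded by function or returned by declare_function) to use as the entry point of the program.
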
Raises:

  • ValueError: If invoked inside a function definition, or if the intended main function was not defined.

Returns:

  • program: The instructions.Program corresponding to all the definitions accumulated in this ProgramBuilder.

return_

return_(vars_out)

Records a function return instruction.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as f:
  ...
  ab.var.result = ...
  ab.return_(ab.var.result)

A return_ command must occur at the top level of the function definition (not inside any if_s), and must be the last statement therein. You can always achieve this by assigning to a dedicated variable for the answer where you would otherwise return (and massaging your control flow).

Args:

  • vars_out: Pattern of Python strings giving the auto-batched variables to return.

Raises:

  • ValueError: If invoked more than once in a function body, or if trying to return variables that have not been written to.
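
The restructuring suggested for return_ (assign to a dedicated answer variable instead of returning from inside a branch) is the usual single-exit transformation. The following plain-Python analogy, with hypothetical function names, shows the same rewrite outside the DSL.

```python
def early_return(n):
    # Returns from inside a branch: not expressible with a top-level return_.
    if n > 1:
        return n - 1
    return 1


def single_exit(n):
    # Each branch assigns the answer; the only return is the last statement.
    if n > 1:
        ans = n - 1
    else:
        ans = 1
    return ans


for n in range(-3, 6):
    assert single_exit(n) == early_return(n)
```

The fibonacci example at the top of this page follows the second shape: both branches write ab.var.ans, and ab.return_ appears once at the top level.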