Module for the harvest transformation.
This module contains a general-purpose set of tools for transforming functions with a specific side-effect mechanism into pure functions. The names of the transformations in this module are inspired by the Sow/Reap mechanism in Mathematica.
The harvest module exposes two main functions: `sow` and `harvest`. `sow` is used to tag values, and `harvest` can inject values into functions or pull out tagged values.
`harvest` is a very general-purpose transformation focused purely on converting functions that have special side effects (defined using `sow`) and "functionalizing" them. Specifically, a function `f :: (x: X) -> Y` has a set of defined intermediates, or `Sows`. This set can be divided into intermediates you are "collecting" and intermediates you are "injecting", or `Reaps` and `Plants`, respectively. Functionalizing `f` gives you `harvest(f) :: (plants: Plants, x: X) -> Tuple[Y, Reaps]`.
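To make that signature concrete, here is a hand-written sketch of what functionalizing produces for a function with a single sow named `'y'` (the same `f` used in the examples below, which sows `x + 1.` and returns that value plus one). The name `f_harvested` and the `Plants`/`Reaps` aliases are illustrative assumptions; the function is written out by hand in plain Python and is not how the library implements the transformation.

```python
from typing import Dict, Tuple

# Illustrative type aliases (assumptions, not library types).
Plants = Dict[str, float]
Reaps = Dict[str, float]


def f_harvested(plants: Plants, x: float) -> Tuple[float, Reaps]:
    """Hand-functionalized version of f(x) = sow(x + 1., name='y') + 1."""
    reaps: Reaps = {}
    if 'y' in plants:
        # Injection: the planted value replaces the sown intermediate.
        y = plants['y']
    else:
        # Collection: the runtime intermediate is recorded in `reaps`.
        y = x + 1.
        reaps['y'] = y
    return y + 1., reaps


print(f_harvested({'y': 0.}, 1.))  # (1.0, {})
print(f_harvested({}, 1.))         # (3.0, {'y': 2.0})
```

The side effect (recording or overriding `y`) has become an explicit input (`plants`) and an explicit output (`reaps`), which is what makes the result a pure function.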
Generally, most users will not need to use `harvest` directly, but will use wrappers around it.
sow
`sow` is the function used to tag values in a function. It takes a single positional argument, `value`, which is returned as the output, so outside of a tracing context `sow` behaves like the identity function, i.e. `sow(x, ...) == x`. It also takes two mandatory keyword arguments, `tag` and `name`. `tag` is a string used to namespace intermediate values in a function. For example, some intermediates may be useful for probabilistic programming (samples), and others may be useful for logging (summaries). The tag enables `harvest` to interact with only one set of intermediates at a time. The `name` is a string that describes the value you are `sow`-ing. Eventually, when calling `harvest` on a function, the `name` is used as the identifier for the intermediate value.
`sow` also takes an optional string keyword argument, `mode`, which defaults to `'strict'`. The `mode` of a `sow` describes how it behaves when the same name appears multiple times. In `'strict'` mode, `sow` raises an error if the same `(tag, name)` appears more than once. Another option is `'append'`, in which all sows of the same name are appended into a growing array. Finally, there is `'clobber'`, where only the final sown value for a given `(tag, name)` is returned. The last optional argument of `sow` is `key`, which is automatically tied in to the output of `sow` to introduce a fake data dependence. By default, it is `None`.
sow_cond
`sow_cond` is a variant of `sow` that takes an additional positional argument, `pred`. It supports a single mode, `'cond_clobber'`, which is like `'clobber'` but sows values conditionally on `pred`, falling back on zeros if no sow took place. This allows reaping values from loop iterations other than the final one.
harvest
`harvest` is a function transformation that augments the behavior of `sow`s in the function body. Recall that, by default, `sow`s act as identity functions and do not affect the semantics of a function. Harvesting `f` produces a function that can take advantage of the `sow`s present in its execution. `harvest` takes a function `f` and a string `tag`, and only interacts with `sow`s whose tag matches the input `tag`. The returned function can interact with the `sow`s in the function body in either of two ways.
The first is via "injection", where intermediate values in the function can be overridden. `harvest(f)` takes an additional initial argument, `plants`, a dictionary mapping names to values. Each name in `plants` should correspond to a `sow` in `f`, and while running `harvest(f)`, rather than using the value computed at runtime for that `sow`, we substitute in the value from the `plants` dictionary.
The other way in which `harvest(f)` interacts with `sow`s is that if it encounters a `sow` whose tag matches and whose name is not in `plants`, it adds the output of the `sow` to a dictionary, called `reaps`, mapping the `sow` name to its output. At the end of `harvest(f)`'s execution, the `reaps` dictionary contains the outputs of all `sow`s whose values were not injected, or "planted."
The general convention is that, for any given execution of `harvest(f, tag=tag)`, there will be no remaining sows of the given tag if the function were to be reharvested. That is, if we were to nest harvests with the same tag, `harvest(harvest(f, tag='some_tag'), tag='some_tag')`, the outer harvest would have nothing to plant or to reap.
Examples:
Using `sow` and `harvest`
```python
from oryx.core.interpreters.harvest import harvest, sow

def f(x):
    y = sow(x + 1., tag='intermediate', name='y')
    return y + 1.

# Injecting, or "planting" a value for `y`.
harvest(f, tag='intermediate')({'y': 0.}, 1.)  # ==> (1., {})
harvest(f, tag='intermediate')({'y': 0.}, 5.)  # ==> (1., {})

# Collecting, or "reaping" the value of `y`.
harvest(f, tag='intermediate')({}, 1.)  # ==> (3., {'y': 2.})
harvest(f, tag='intermediate')({}, 5.)  # ==> (7., {'y': 6.})
```
Using `reap` and `plant`
`reap` and `plant` are simple wrappers around `harvest`. `reap` only pulls intermediate values without injecting, and `plant` only injects values without collecting intermediate values.
```python
from oryx.core.interpreters.harvest import plant, reap, sow

def f(x):
    y = sow(x + 1., tag='intermediate', name='y')
    return y + 1.

# Injecting, or "planting" a value for `y`.
plant(f, tag='intermediate')({'y': 0.}, 1.)  # ==> 1.
plant(f, tag='intermediate')({'y': 0.}, 5.)  # ==> 1.

# Collecting, or "reaping" the value of `y`.
reap(f, tag='intermediate')(1.)  # ==> {'y': 2.}
reap(f, tag='intermediate')(5.)  # ==> {'y': 6.}
```
Sharp edges
- `harvest` has undefined semantics under automatic differentiation. If a function you're taking the gradient of has a `sow`, it might produce unintuitive results when harvested. To better control gradient semantics, you can use `jax.custom_jvp` or `jax.custom_vjp`. The current implementation sows primals and tangents in the JVP but ignores cotangents in the VJP. These particular semantics are subject to change.
- Planting values into a `pmap` is only partially supported. Harvest tries to plant all the values, assuming they have a leading map dimension.
Classes
class HarvestTrace: An evaluating trace that dispatches to a dynamic context.
class HarvestTracer: A `HarvestTracer` just encapsulates a single value.
Functions
call_and_reap(...): Transforms a function into one that additionally returns its sown values.
nest(...): Wraps a function to create a new scope for harvested values.
plant(...): Transforms a function into one that injects values in place of sown ones.
reap(...): Transforms a function into one that returns its sown values.
sow(...): Marks a value with a name and a tag.