Module: oryx.core.interpreters.harvest

Module for the harvest transformation.

This module contains a general-purpose set of tools for transforming functions with a specific side-effect mechanism into pure functions. The names of the transformations in this module are inspired by the Sow/Reap mechanism in Mathematica.

The harvest module exposes two main functions: sow and harvest. sow is used to tag values and harvest can inject values into functions or pull out tagged values.

harvest is a general-purpose transformation focused purely on taking functions that have special side effects (defined using sow) and "functionalizing" them. Specifically, a function f :: (x: X) -> Y has a set of defined intermediates, or Sows. This set can be divided into intermediates you are "collecting" and intermediates you are "injecting", or Reaps and Plants respectively. Functionalizing f gives you harvest(f) :: (plants: Plants, x: X) -> Tuple[Y, Reaps]. Most users will not need to use harvest directly, and will instead use wrappers around it.
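As a rough illustration of that signature, the sketch below functionalizes a small function by hand, in plain Python and without Oryx. The names f and harvested_f are illustrative only; harvest derives the equivalent function automatically rather than requiring it to be written out.

```python
from typing import Dict, Tuple


def f(x: float) -> float:
    # Imagine this line reads: y = sow(x + 1., tag='intermediate', name='y')
    y = x + 1.
    return y + 1.


def harvested_f(plants: Dict[str, float], x: float) -> Tuple[float, Dict[str, float]]:
    """Hand-written stand-in for harvest(f): (plants, x) -> (output, reaps)."""
    reaps: Dict[str, float] = {}
    if 'y' in plants:
        # Injection: use the planted value instead of computing x + 1.
        y = plants['y']
    else:
        # Collection: compute the intermediate and record it under its name.
        y = x + 1.
        reaps['y'] = y
    return y + 1., reaps


print(harvested_f({}, 1.))         # (3.0, {'y': 2.0})
print(harvested_f({'y': 0.}, 5.))  # (1.0, {})
```

The side effect (tagging y) has become an explicit input and an explicit output, which is what makes the harvested function pure.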

sow

sow is the function used to tag values in a function. It takes a single positional argument, value, and returns it as its output, so outside of a tracing context sow behaves like the identity function, i.e. sow(x, ...) == x. It also takes two mandatory keyword arguments, tag and name. tag is a string used to namespace intermediate values in a function. For example, some intermediates may be useful for probabilistic programming (samples), and others may be useful for logging (summaries). The tag enables harvest to interact with only one set of intermediates at a time. name is a string that identifies the value you are sowing; when harvest is called on a function, the name is used as the identifier for the intermediate value.

sow also takes an optional string keyword argument, mode, which defaults to 'strict'. The mode of a sow describes how it behaves when the same name appears multiple times. In 'strict' mode, sow will error if the same (tag, name) appears more than once. In 'append' mode, all sows of the same name are appended into a growing array. In 'clobber' mode, only the final sown value for a given (tag, name) is returned.

The last optional argument to sow is key, which defaults to None. When provided, it is automatically tied in to the output of sow to introduce a fake data dependence.
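The three modes can be pictured with a toy model in plain Python. This is only a sketch of the bookkeeping described above, not Oryx's implementation: record is a hypothetical helper, the reaps store is an ordinary dictionary, and 'append' collects into a Python list where Oryx would build an array.

```python
def record(reaps, name, value, mode='strict'):
    """Toy model of how one sown value lands in a reaps dictionary."""
    if mode == 'strict':
        # The same name may be recorded at most once.
        if name in reaps:
            raise ValueError(f'Name already sown in strict mode: {name}')
        reaps[name] = value
    elif mode == 'append':
        # Every value sown under this name is kept, in order.
        reaps.setdefault(name, []).append(value)
    elif mode == 'clobber':
        # Later values silently overwrite earlier ones.
        reaps[name] = value
    else:
        raise ValueError(f'Unknown mode: {mode}')
    return value  # like sow, hand the value back unchanged


appended, clobbered, strict = {}, {}, {}
for v in (1., 2., 3.):
    record(appended, 'y', v, mode='append')
    record(clobbered, 'y', v, mode='clobber')

print(appended)   # {'y': [1.0, 2.0, 3.0]}
print(clobbered)  # {'y': 3.0}

record(strict, 'y', 1.)
try:
    record(strict, 'y', 2.)  # same name twice in 'strict' mode
except ValueError as e:
    print('error:', e)
```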

harvest

harvest is a function transformation that augments the behavior of sows in the function body. Recall that, by default, sows act as identity functions and do not affect the semantics of a function. Harvesting f produces a function that can take advantage of the sows present in its execution.

harvest takes in a function f and a string tag, and only interacts with sows whose tag matches the input tag. The returned function can interact with the sows in the function body in either of two ways.

The first is "injection", where intermediate values in the function can be overridden. harvest(f) takes an additional initial argument, plants, a dictionary mapping names to values. Each name in plants should correspond to a sow in f; while running harvest(f), rather than using the value computed at runtime for that sow, we substitute the value from the plants dictionary.

The second is collection: if harvest(f) encounters a sow whose tag matches and whose name is not in plants, it adds the output of the sow to a dictionary, called reaps, that maps the sow name to its output. At the end of harvest(f)'s execution, the reaps dictionary contains the outputs of all sows whose values were not injected, or "planted."

The general convention is that, after any given execution of harvest(f, tag=tag), no sows of the given tag remain if the function were to be reharvested. That is, if we were to nest harvests with the same tag, as in harvest(harvest(f, tag='some_tag'), tag='some_tag'), the outer harvest would have nothing to plant or to reap.

Examples:

Using sow and harvest

from oryx.core.interpreters.harvest import harvest, sow

def f(x):
  y = sow(x + 1., tag='intermediate', name='y')
  return y + 1.

# Injecting, or "planting" a value for `y`.
harvest(f, tag='intermediate')({'y': 0.}, 1.)  # ==> (1., {})
harvest(f, tag='intermediate')({'y': 0.}, 5.)  # ==> (1., {})

# Collecting, or "reaping" the value of `y`.
harvest(f, tag='intermediate')({}, 1.)  # ==> (3., {'y': 2.})
harvest(f, tag='intermediate')({}, 5.)  # ==> (7., {'y': 6.})

Using reap and plant.

reap and plant are simple wrappers around harvest. reap only pulls intermediate values without injecting, and plant only injects values without collecting intermediate values.

from oryx.core.interpreters.harvest import plant, reap, sow

def f(x):
  y = sow(x + 1., tag='intermediate', name='y')
  return y + 1.

# Injecting, or "planting" a value for `y`.
plant(f, tag='intermediate')({'y': 0.}, 1.)  # ==> 1.
plant(f, tag='intermediate')({'y': 0.}, 5.)  # ==> 1.

# Collecting, or "reaping" the value of `y`.
reap(f, tag='intermediate')(1.)  # ==> {'y': 2.}
reap(f, tag='intermediate')(5.)  # ==> {'y': 6.}
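To make the "simple wrappers" relationship concrete, here is a minimal plain-Python sketch. It assumes a hand-written stand-in, harvested_f, in place of harvest(f, tag='intermediate'), and the helper names plant_only and reap_only are hypothetical. Oryx's own plant and reap take f and tag directly; this only shows which half of harvest's interface each one keeps.

```python
def harvested_f(plants, x):
    """Hand-written stand-in for harvest(f, tag='intermediate')."""
    reaps = {}
    if 'y' in plants:
        y = plants['y']
    else:
        y = x + 1.
        reaps['y'] = y
    return y + 1., reaps


def plant_only(harvested):
    """Keep the plants argument, drop the reaps output."""
    def wrapped(plants, *args):
        out, _ = harvested(plants, *args)
        return out
    return wrapped


def reap_only(harvested):
    """Plant nothing, return only the reaps output."""
    def wrapped(*args):
        _, reaps = harvested({}, *args)
        return reaps
    return wrapped


print(plant_only(harvested_f)({'y': 0.}, 5.))  # 1.0
print(reap_only(harvested_f)(5.))              # {'y': 6.0}
```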

Sharp edges

  • harvest has undefined semantics under autodifferentiation. If a function you're taking the gradient of has a sow, it might produce unintuitive results when harvested. To better control gradient semantics, you can use jax.custom_jvp or jax.custom_vjp. The current implementation sows primals and tangents in the JVP but ignores cotangents in the VJP. These particular semantics are subject to change.
  • Planting values into a pmap only partially works. harvest tries to plant all the values, assuming they have a leading map dimension.

Classes

class HarvestTrace: An evaluating trace that dispatches to a dynamic context.

class HarvestTracer: A HarvestTracer just encapsulates a single value.

Functions

call_and_reap(...): Transforms a function into one that additionally returns its sown values.

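harvest(...): Transforms a function into one that takes planted values as an additional first argument and additionally returns its reaped values.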

nest(...): Wraps a function to create a new scope for harvested values.

plant(...): Transforms a function into one that injects values in place of sown ones.

reap(...): Transforms a function into one that returns its sown values.

sow(...): Marks a value with a name and a tag.