Decorates a function to enable defining a custom inverse.
```python
oryx.core.custom_inverse(f)
```
A `custom_inverse`-decorated function is semantically identical to the original except when it is inverted with `core.inverse`. By default, `core.inverse(custom_inverse(f))` will programmatically invert the body of `f`, but the decorated function has two additional methods that can override that behavior: `def_inverse_unary` and `def_inverse_and_ildj`.
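The "identical except when inverted" semantics can be modeled with a toy, pure-Python wrapper. `ToyCustomInverse` and `toy_inverse` are hypothetical names invented for this sketch and are not Oryx classes. The toy also differs from `core.inverse` in one important way: it cannot invert a function body on its own. It only uses an override that has been registered.

```python
# Hypothetical toy model of custom_inverse semantics; NOT Oryx's implementation.
from typing import Callable, Optional


class ToyCustomInverse:
  """Wraps f: calling it runs f unchanged; an inverse override is stored aside."""

  def __init__(self, f: Callable[[float], float]):
    self.f = f
    self.f_inv: Optional[Callable[[float], float]] = None

  def __call__(self, x: float) -> float:
    # Forward evaluation is identical to the undecorated function.
    return self.f(x)

  def def_inverse_unary(self, f_inv: Callable[[float], float]) -> None:
    # Record a user-supplied inverse that toy_inverse will prefer.
    self.f_inv = f_inv


def toy_inverse(fn: ToyCustomInverse) -> Callable[[float], float]:
  """Hypothetical stand-in for core.inverse: uses the override if registered."""
  if fn.f_inv is not None:
    return fn.f_inv
  # Oryx would invert the body of f here; the toy model simply gives up.
  raise NotImplementedError("toy model cannot invert the body of f")


@ToyCustomInverse
def add_one(x: float) -> float:
  return x + 1.0


add_one.def_inverse_unary(lambda y: y * 2)  # deliberately "silly" override
print(add_one(2.0))               # 3.0: forward pass unchanged
print(toy_inverse(add_one)(2.0))  # 4.0: override used only when inverting
```

The override lives on the wrapper object, so ordinary calls never see it. Only the inversion path consults it.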
def_inverse_unary
`def_inverse_unary` is applicable if `f` is a unary function. `def_inverse_unary` takes an optional `f_inv` function, which is a unary function from the output of `f` to the input of `f`.
Example:
```python
from oryx.core import custom_inverse, inverse

@custom_inverse
def add_one(x):
  return x + 1.

add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
inverse(add_one)(2.)  # ==> 4.
```
With a unary `f_inv` function, Oryx will automatically compute an inverse log-det Jacobian using `core.ildj(core.inverse(f_inv))`, but a user can also override the Jacobian term by providing the optional `f_ildj` keyword argument to `def_inverse_unary`.
Example:
```python
import jax.numpy as jnp
from oryx.core import custom_inverse, ildj

@custom_inverse
def add_one(x):
  return x + 1.

add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
ildj(add_one)(2.)  # ==> 1.
```
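Without the `f_ildj` override, the automatic rule `core.ildj(core.inverse(f_inv))` should give, for a scalar function, `log|d f_inv / dy|`. For `lambda x: x * 2` that is `log 2 ≈ 0.693` rather than the overridden `1.`. The plain-Python sketch below estimates this quantity with a central finite difference. `unary_ildj` is a hypothetical helper for illustration only; Oryx computes the term through `core.ildj`, not by numerical differentiation.

```python
import math
from typing import Callable


def unary_ildj(f_inv: Callable[[float], float], y: float, eps: float = 1e-6) -> float:
  """Hypothetical helper: log|d f_inv/dy| at y via a central finite difference."""
  slope = (f_inv(y + eps) - f_inv(y - eps)) / (2 * eps)
  return math.log(abs(slope))


# The "silly" inverse x * 2 has slope 2 everywhere, so its ILDJ is log(2).
print(unary_ildj(lambda x: x * 2, 2.0))  # approximately 0.6931
```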
def_inverse_and_ildj
A more general way to define a custom inverse or ILDJ is `def_inverse_and_ildj`, which lets the user invert functions whose inputs and outputs are only partially known. Take `add = lambda x, y: x + y`: it cannot be inverted from the output alone, but it can be inverted when one of the two inputs is known.

`def_inverse_and_ildj` takes a single function, `f_ildj`, as its argument. `f_ildj` receives three arguments:

- `invals`: values corresponding to `f`'s inputs.
- `outvals`: values corresponding to `f`'s outputs.
- `out_ildjs`: the inverse diagonal log-Jacobian values for each of the `outvals`.

Any value that is unknown is passed as `None`. `f_ildj` should return a tuple `(new_invals, new_inildjs)` containing the known values of the inputs and their corresponding inverse diagonal log-Jacobian values (with the same shape as `invals`). If these values cannot be computed (e.g. too many values are `None`), `f_ildj` can raise a `NonInvertibleError`, which signals Oryx to give up trying to invert the function for this set of values.
Example:
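```python
from functools import partial

import jax.numpy as jnp
from oryx.core import custom_inverse, inverse
# Import path assumed; check where your Oryx version exposes NonInvertibleError.
from oryx.core.interpreters.inverse.custom_inverse import NonInvertibleError

@custom_inverse
def add(x, y):
  return x + y

def add_ildj(invals, outvals, out_ildjs):
  x, y = invals
  z = outvals          # add has a single output
  z_ildj = out_ildjs
  if x is None and y is None:
    # The output alone does not determine both inputs.
    raise NonInvertibleError()
  if x is None:
    return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
  if y is None:
    return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))

add.def_inverse_and_ildj(add_ildj)
inverse(partial(add, 1.))(2.)  # ==> 1.
inverse(partial(add, 1.))(1.)  # ==> 0.
```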
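In the `add` example the Jacobian terms are zero because each solved input changes with unit slope in `z`. For multiplication they are not. Solving `x = z / y` contributes `log|dx/dz| = -log|y|`. The sketch below writes an `f_ildj`-style rule for `z = x * y` in plain Python so it runs without Oryx. `mul_ildj` is a hypothetical example, and the local `NonInvertibleError` class is a stand-in for Oryx's exception, not an import from the library.

```python
import math


class NonInvertibleError(Exception):
  """Stand-in for Oryx's NonInvertibleError so this sketch runs on its own."""


def mul_ildj(invals, outvals, out_ildjs):
  """Hypothetical f_ildj-style rule for z = x * y on scalar floats.

  Assumes whichever input is known is nonzero.
  """
  x, y = invals
  z = outvals         # f has a single output
  z_ildj = out_ildjs  # ILDJ already accumulated for z
  if x is None and y is None:
    # Only z is known: x * y = z has infinitely many solutions.
    raise NonInvertibleError()
  if x is None:
    # x = z / y, so dx/dz = 1 / y and log|dx/dz| = -log|y|.
    return (z / y, y), (z_ildj - math.log(abs(y)), 0.0)
  if y is None:
    # Symmetric case: y = z / x contributes -log|x|.
    return (x, z / x), (0.0, z_ildj - math.log(abs(x)))


print(mul_ildj((None, 2.0), 6.0, 0.0))  # ((3.0, 2.0), (-0.6931471805599453, 0.0))

try:
  mul_ildj((None, None), 6.0, 0.0)
except NonInvertibleError:
  print("not invertible from the output alone")
```

To use a rule like this with Oryx, you would pass it to `def_inverse_and_ildj` on a `custom_inverse`-decorated multiply. You would presumably also swap `math.log`/`abs` for `jnp.log`/`jnp.abs` so the rule handles JAX arrays; that adaptation is an assumption, not something documented on this page.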
Args | |
---|---|
`f` | A function for which we'd like to define a custom inverse. |
Returns | |
---|---|
A `CustomInverse` object whose inverse can be overridden with `def_inverse_unary` or `def_inverse_and_ildj`. | |