Implements a continuous normalizing flow X->Y defined via an ODE.
Inherits From: Bijector
tfp.bijectors.FFJORD(
    state_time_derivative_fn, ode_solve_fn=None,
    trace_augmentation_fn=trace_jacobian_hutchinson, initial_time=0.0,
    final_time=1.0, validate_args=False, dtype=tf.float32, name='ffjord'
)
This bijector implements a continuous-time dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and image (Y), i.e.

d/dt[state(t)] = state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y
For this transformation the value of log_det_jacobian follows another differential equation, reducing it to the computation of the trace of the Jacobian along the trajectory:

state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
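For example, for linear dynamics state_time_derivative_fn(t, x) = x @ A the flow is x(t) = x(0) @ expm(t * A), so the exact log-det-Jacobian over [initial_time, final_time] = [0, 1] is Tr(A). A minimal sketch checking this (the matrix A and the shapes are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

A = tf.constant([[0.5, 0.2], [0.0, -0.3]])  # Tr(A) = 0.2
bijector = tfb.FFJORD(state_time_derivative_fn=lambda t, x: tf.matmul(x, A))
x = tf.random.normal([4, 2])
fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
# Each entry of `fldj` is approximately Tr(A) = 0.2, up to the noise of the
# default Hutchinson trace estimator and the ODE solver tolerance.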
The FFJORD constructor takes two functions, ode_solve_fn and trace_augmentation_fn, that customize integration of the differential equation and trace estimation.

Differential equation integration is performed by a call to ode_solve_fn. A custom ode_solve_fn must accept the following arguments:
- ode_fn(time, state): Differential equation to be solved.
- initial_time: Scalar float or floating Tensor representing the initial time.
- initial_state: Floating Tensor representing the initial state.
- solution_times: 1D floating Tensor of solution times.

It must return a Tensor of shape [solution_times.shape, initial_state.shape] representing the state values evaluated at solution_times. In addition, ode_solve_fn must support nested structures. For more details see the interface of tfp.math.ode.Solver.solve().
Trace estimation is computed simultaneously with state_time_derivative, using an augmented_state_time_derivative_fn generated by trace_augmentation_fn. trace_augmentation_fn takes state_time_derivative_fn, state.shape and state.dtype arguments and returns an augmented_state_time_derivative_fn callable that computes both state_time_derivative and the unreduced trace_estimation.
Custom ode_solve_fn and trace_augmentation_fn examples:
# custom_solver_fn: `callable(f, t_initial, t_solutions, y_initial, ...)`
# custom_solver_kwargs: Additional arguments to pass to custom_solver_fn.
def ode_solve_fn(ode_fn, initial_time, initial_state, solution_times):
  results = custom_solver_fn(ode_fn, initial_time, solution_times,
                             initial_state, **custom_solver_kwargs)
  return results

ffjord = tfb.FFJORD(state_time_derivative_fn, ode_solve_fn=ode_solve_fn)
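A concrete variant of ode_solve_fn can route integration through one of the solvers in tfp.math.ode; here is a sketch (the tolerance value is illustrative; Solver.solve returns a Results tuple whose states field holds the trajectory at solution_times):

solver = tfp.math.ode.DormandPrince(atol=1e-5)

def ode_solve_fn(ode_fn, initial_time, initial_state, solution_times):
  results = solver.solve(ode_fn, initial_time, initial_state,
                         solution_times=solution_times)
  return results.states

ffjord = tfb.FFJORD(state_time_derivative_fn, ode_solve_fn=ode_solve_fn)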
# state_time_derivative_fn: `callable(time, state)`
# trace_jac_fn: `callable(time, state)` unreduced jacobian trace function
def trace_augmentation_fn(ode_fn, state_shape, state_dtype):
  def augmented_ode_fn(time, state):
    return ode_fn(time, state), trace_jac_fn(time, state)
  return augmented_ode_fn

ffjord = tfb.FFJORD(state_time_derivative_fn,
                    trace_augmentation_fn=trace_augmentation_fn)
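As a concrete stand-in for trace_jac_fn, here is a sketch of a Hutchinson-style trace_augmentation_fn in the spirit of the default trace_jacobian_hutchinson: it estimates the unreduced Jacobian diagonal from a single noise sample via a vector-Jacobian product (the single-sample choice is illustrative):

def hutchinson_trace_augmentation_fn(ode_fn, state_shape, state_dtype):
  def augmented_ode_fn(time, state):
    noise = tf.random.normal(tf.shape(state), dtype=state_dtype)
    with tf.GradientTape() as tape:
      tape.watch(state)
      state_time_derivative = ode_fn(time, state)
    # vjp = noise^T @ J; noise * vjp estimates diag(J) in expectation.
    vjp = tape.gradient(state_time_derivative, state, output_gradients=noise)
    return state_time_derivative, noise * vjp
  return augmented_ode_fn

ffjord = tfb.FFJORD(state_time_derivative_fn,
                    trace_augmentation_fn=hutchinson_trace_augmentation_fn)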
For more details on FFJORD and continuous normalizing flows see [1], [2].
Usage example:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# state_time_derivative_fn: `Callable(time, state)` -> state_time_derivative
# e.g. Neural network with inputs and outputs of the same shapes and dtypes.
bijector = tfb.FFJORD(state_time_derivative_fn=state_time_derivative_fn)
y = bijector.forward(x)  # Forward mapping.
x = bijector.inverse(y)  # Inverse mapping.
base = tfd.Normal(tf.zeros_like(x), tf.ones_like(x))  # Base distribution.
transformed_distribution = tfd.TransformedDistribution(base, bijector)
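A concrete, trainable state_time_derivative_fn can be as simple as a single linear map (a hypothetical minimal model; any network whose output matches the state's shape and dtype works):

dims = 2
kernel = tf.Variable(tf.random.normal([dims, dims]))

def state_time_derivative_fn(time, state):
  del time  # Autonomous dynamics: the derivative ignores `time` here.
  return tf.matmul(state, kernel)

bijector = tfb.FFJORD(state_time_derivative_fn=state_time_derivative_fn)
x = tf.random.normal([8, dims])
y = bijector.forward(x)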
References
[1]: Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in Neural Information Processing Systems (pp. 6571-6583).
[2]: Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., & Duvenaud, D. (2018). FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367. https://arxiv.org/abs/1810.01367
Args | |
---|---|
state_time_derivative_fn | Python callable taking arguments time (a scalar representing time) and state (a Tensor representing the state at the given time) and returning the time derivative of the state at the given time. |
ode_solve_fn | Python callable taking arguments ode_fn (same as state_time_derivative_fn above), initial_time (a scalar representing the initial time of integration), initial_state (a Tensor of floating dtype representing the initial state) and solution_times (1D Tensor of floating dtype representing the times at which to obtain the solution), and returning a Tensor of shape [time_axis, initial_state.shape]. Will take [final_time] as the solution_times argument and state_time_derivative_fn as the ode_fn argument. For details on providing a custom ode_solve_fn see the class docstring. If None, a DormandPrince solver from tfp.math.ode is used. Default value: None. |
trace_augmentation_fn | Python callable taking arguments ode_fn (a Python callable, same as state_time_derivative_fn above), state_shape (TensorShape of the state) and dtype (dtype of the state), and returning a Python callable taking arguments time (a scalar representing the time at which the function is evaluated) and state (a Tensor representing the state at the given time) that computes a tuple (ode_fn(time, state), jacobian_trace_estimation). jacobian_trace_estimation should represent the trace of the Jacobian of ode_fn with respect to state. state_time_derivative_fn will be passed as the ode_fn argument. For details on providing a custom trace_augmentation_fn see the class docstring. Default value: tfp.bijectors.ffjord.trace_jacobian_hutchinson. |
initial_time | Scalar float representing the time to which the x value of the bijector corresponds. Passed as initial_time to ode_solve_fn. For the default solver this can be a Python float or a floating scalar Tensor. Default value: 0. |
final_time | Scalar float representing the time to which the y value of the bijector corresponds. Passed as solution_times to ode_solve_fn. For the default solver this can be a Python float or a floating scalar Tensor. Default value: 1. |
validate_args | Python bool indicating whether to validate input. Default value: False. |
dtype | tf.DType to prefer when converting args to Tensors. Else, we fall back to a common dtype inferred from the args, finally falling back to float32. |
name | Python str name prefixed to Ops created by this function. |
Attributes | |
---|---|
dtype | |
forward_min_event_ndims | Returns the minimal number of dimensions bijector.forward operates on. Multipart bijectors return structured min_event_ndims. |
graph_parents | Returns this Bijector's graph_parents as a Python list. |
has_static_min_event_ndims | Returns True if the bijector has statically-known min_event_ndims. |
inverse_min_event_ndims | Returns the minimal number of dimensions bijector.inverse operates on. Multipart bijectors return structured min_event_ndims. |
is_constant_jacobian | Returns true iff the Jacobian matrix is not a function of x. |
name | Returns the string name of this Bijector. |
name_scope | Returns a tf.name_scope instance for this class. |
parameters | Dictionary of parameters used to instantiate this Bijector. |
submodules | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
trainable_variables | Sequence of trainable variables owned by this module and its submodules. |
validate_args | Returns True if Tensor arguments will be validated. |
variables | Sequence of variables owned by this module and its submodules. |
Methods
forward
forward(
    x, name='forward', **kwargs
)
Returns the forward Bijector evaluation, i.e., Y = g(X).
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure). | |
Raises | |
---|---|
TypeError | if self.dtype is specified and x.dtype is not self.dtype. |
NotImplementedError | if _forward is not implemented. |
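For FFJORD, forward integrates the ODE from initial_time to final_time, and inverse integrates it backwards; a quick round-trip sketch (the dynamics are illustrative):

bijector = tfb.FFJORD(state_time_derivative_fn=lambda t, x: -x)
x = tf.random.normal([3, 2])
y = bijector.forward(x)            # Integrate the ODE forward in time.
x_roundtrip = bijector.inverse(y)  # ~x, up to solver tolerance.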
forward_dtype
forward_dtype(
    dtype=UNSPECIFIED, name='forward_dtype', **kwargs
)
Returns the dtype returned by forward for the provided input.
forward_event_ndims
forward_event_ndims(
    event_ndims, **kwargs
)
Returns the number of event dimensions produced by forward.
forward_event_shape
forward_event_shape(
    input_shape
)
Shape of a single sample from a single batch as a TensorShape.
Same meaning as forward_event_shape_tensor. May be only partially defined.
Args | |
---|---|
input_shape | TensorShape (structure) indicating event-portion shape passed into forward function. |
Returns | |
---|---|
forward_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying forward. Possibly unknown. |
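Since the ODE maps states to states of the same shape, FFJORD preserves event shapes; for instance (illustrative dynamics):

tfb.FFJORD(lambda t, x: -x).forward_event_shape([5])
# ==> TensorShape([5])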
forward_event_shape_tensor
forward_event_shape_tensor(
    input_shape, name='forward_event_shape_tensor'
)
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
input_shape | Tensor, int32 vector (structure) indicating event-portion shape passed into forward function. |
name | name to give to the op |
Returns | |
---|---|
forward_event_shape_tensor | Tensor, int32 vector (structure) indicating event-portion shape after applying forward. |
forward_log_det_jacobian
forward_log_det_jacobian(
    x, event_ndims, name='forward_log_det_jacobian', **kwargs
)
Returns the forward_log_det_jacobian.
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' Jacobian determinant evaluation. |
event_ndims | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.forward_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape rank(x) - event_ndims dimensions. Multipart bijectors require structured event_ndims, such that rank(x[i]) - rank(event_ndims[i]) is the same for all elements i of the structured input. Furthermore, the first event_ndims[i] of each x[i].shape must be the same for all i (broadcasting is not allowed). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective this is not implemented. | |
Raises | |
---|---|
TypeError | if y's dtype is incompatible with the expected output dtype. |
NotImplementedError | if neither _forward_log_det_jacobian nor {_inverse, _inverse_log_det_jacobian} are implemented, or this is a non-injective bijector. |
inverse
inverse(
    y, name='inverse', **kwargs
)
Returns the inverse Bijector evaluation, i.e., X = g^{-1}(Y).
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective, returns the k-tuple containing the unique k points (x1, ..., xk) such that g(xi) = y. | |
Raises | |
---|---|
TypeError | if y's structured dtype is incompatible with the expected output dtype. |
NotImplementedError | if _inverse is not implemented. |
inverse_dtype
inverse_dtype(
    dtype=UNSPECIFIED, name='inverse_dtype', **kwargs
)
Returns the dtype returned by inverse for the provided input.
inverse_event_ndims
inverse_event_ndims(
    event_ndims, **kwargs
)
Returns the number of event dimensions produced by inverse.
inverse_event_shape
inverse_event_shape(
    output_shape
)
Shape of a single sample from a single batch as a TensorShape.
Same meaning as inverse_event_shape_tensor. May be only partially defined.
Args | |
---|---|
output_shape | TensorShape (structure) indicating event-portion shape passed into inverse function. |
Returns | |
---|---|
inverse_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying inverse. Possibly unknown. |
inverse_event_shape_tensor
inverse_event_shape_tensor(
    output_shape, name='inverse_event_shape_tensor'
)
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
output_shape | Tensor, int32 vector (structure) indicating event-portion shape passed into inverse function. |
name | name to give to the op |
Returns | |
---|---|
inverse_event_shape_tensor | Tensor, int32 vector (structure) indicating event-portion shape after applying inverse. |
inverse_log_det_jacobian
inverse_log_det_jacobian(
    y, event_ndims, name='inverse_log_det_jacobian', **kwargs
)
Returns the (log o det o Jacobian o inverse)(y).
Mathematically, returns: log(det(dX/dY))(Y). (Recall that: X = g^{-1}(Y).)
Note that forward_log_det_jacobian is the negative of this function, evaluated at g^{-1}(y).
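For FFJORD this identity can be checked directly, keeping in mind that equality is only approximate because the default trace estimator is stochastic and each call draws a fresh estimate:

y = bijector.forward(x)
ildj = bijector.inverse_log_det_jacobian(y, event_ndims=1)
fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
# ildj is approximately equal to -fldj.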
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' Jacobian determinant evaluation. |
event_ndims | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.inverse_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape rank(y) - event_ndims dimensions. Multipart bijectors require structured event_ndims, such that rank(y[i]) - rank(event_ndims[i]) is the same for all elements i of the structured input. Furthermore, the first event_ndims[i] of each y[i].shape must be the same for all i (broadcasting is not allowed). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
ildj | Tensor, if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, log(det(Dg_i^{-1}(y))), where g_i is the restriction of g to the ith partition Di. |
Raises | |
---|---|
TypeError | if x's dtype is incompatible with the expected inverse-dtype. |
NotImplementedError | if _inverse_log_det_jacobian is not implemented. |
with_name_scope
@classmethod
with_name_scope(
    method
)
Decorator to automatically enter the module name scope.
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>
Args | |
---|---|
method | The method to wrap. |
Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. | |
__call__
__call__(
    value, name=None, **kwargs
)
Applies or composes the Bijector, depending on input type.
This is a convenience function which applies the Bijector instance in three different ways, depending on the input:
- If the input is a tfd.Distribution instance, return tfd.TransformedDistribution(distribution=input, bijector=self).
- If the input is a tfb.Bijector instance, return tfb.Chain([self, input]).
- Otherwise, return self.forward(input).
Args | |
---|---|
value | A tfd.Distribution, tfb.Bijector, or a (structure of) Tensor. |
name | Python str name given to ops created by this function. |
**kwargs | Additional keyword arguments passed into the created tfd.TransformedDistribution, tfb.Bijector, or self.forward. |
Returns | |
---|---|
composition | A tfd.TransformedDistribution if the input was a tfd.Distribution, a tfb.Chain if the input was a tfb.Bijector, or a (structure of) Tensor computed by self.forward. |
Examples
sigmoid = tfb.Reciprocal()(
    tfb.Shift(shift=1.)(
        tfb.Exp()(
            tfb.Scale(scale=-1.))))
# ==> `tfb.Chain([
#        tfb.Reciprocal(),
#        tfb.Shift(shift=1.),
#        tfb.Exp(),
#        tfb.Scale(scale=-1.),
#      ])`  # ie, `tfb.Sigmoid()`

log_normal = tfb.Exp()(tfd.Normal(0, 1))
# ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

tfb.Exp()([-1., 0., 1.])
# ==> tf.exp([-1., 0., 1.])
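FFJORD composes the same way; for example, building a continuous normalizing flow over a standard normal base (state_time_derivative_fn as defined earlier):

flow = tfb.FFJORD(state_time_derivative_fn)(tfd.Normal(0., 1.))
# ==> `tfd.TransformedDistribution(tfd.Normal(0., 1.), tfb.FFJORD(...))`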
__eq__
__eq__(
    other
)
Return self==value.