tfp.experimental.substrates.jax.bijectors.FFJORD

Implements a continuous normalizing flow X->Y defined via an ODE.

Inherits From: Bijector

tfp.experimental.substrates.jax.bijectors.FFJORD(
    state_time_derivative_fn, ode_solve_fn=None,
    trace_augmentation_fn=trace_jacobian_hutchinson, initial_time=0.0,
    final_time=1.0, validate_args=False, dtype=tf.float32, name='ffjord'
)

This bijector implements a continuous dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and image (Y), respectively, i.e.:

d/dt[state(t)]=state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y

For this transformation, the value of log_det_jacobian follows another differential equation, which reduces its computation to integrating the trace of the Jacobian along the trajectory:

state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
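
The relation above can be checked by hand on a scalar example. The sketch below is plain Python rather than TFP code, and the dynamics dz/dt = tanh(z), the helper `flow`, and the midpoint-rule integration are all assumptions made for illustration. This ODE has the closed-form solution sinh(z(t)) = sinh(z(0)) * exp(t), so the integrated Jacobian trace can be compared with the log-derivative of the flow map obtained by direct differentiation.

```python
import math

def state_time_derivative_fn(t, z):
    # Hypothetical scalar dynamics, chosen because the ODE is solvable by hand.
    return math.tanh(z)

def jacobian_trace(t, z):
    # For a scalar state the Jacobian trace is d/dz tanh(z).
    return 1.0 - math.tanh(z) ** 2

def flow(x, elapsed):
    # Closed-form solution of dz/dt = tanh(z) after `elapsed` time, z(0) = x.
    return math.asinh(math.sinh(x) * math.exp(elapsed))

x, initial_time, final_time = 0.3, 0.0, 1.0
y = flow(x, final_time - initial_time)

# Sanity check: the closed form satisfies the ODE (central difference in time).
dt = 1e-6
z_dot = (flow(x, 0.5 + dt) - flow(x, 0.5 - dt)) / (2 * dt)

# Integrate d/dt[log_det_jac] along the trajectory with the midpoint rule.
n = 2000
h = (final_time - initial_time) / n
log_det_jac = 0.0
for i in range(n):
    t_mid = initial_time + (i + 0.5) * h
    log_det_jac += h * jacobian_trace(t_mid, flow(x, t_mid - initial_time))

# Differentiating the flow map directly gives dy/dx = e * cosh(x) / cosh(y).
expected = 1.0 + math.log(math.cosh(x)) - math.log(math.cosh(y))
```

Here `log_det_jac` and `expected` agree up to the quadrature error, which is the scalar case of the trace formula.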

The FFJORD constructor takes two function arguments, ode_solve_fn and trace_augmentation_fn, that customize integration of the differential equation and trace estimation, respectively.

Differential equation integration is performed by a call to ode_solve_fn. A custom ode_solve_fn must accept the following arguments:

  • ode_fn(time, state): Differential equation to be solved.
  • initial_time: Scalar float or floating Tensor representing the initial time.
  • initial_state: Floating Tensor representing the initial state.
  • solution_times: 1D floating Tensor of solution times.

It must return a Tensor of shape [solution_times.shape, initial_state.shape] holding the state values evaluated at solution_times. In addition, ode_solve_fn must support nested structures. For more details, see the interface of tfp.math.ode.Solver.solve().
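
To make the calling convention concrete, here is a minimal sketch of a function with the four-argument signature described above. It is plain Python, not TFP code: the name `rk4_ode_solve_fn`, the extra `steps_per_interval` parameter, and the use of flat lists of floats in place of Tensors are assumptions for illustration, and nested structures are not handled.

```python
import math

def rk4_ode_solve_fn(ode_fn, initial_time, initial_state, solution_times,
                     steps_per_interval=100):
    """Fixed-step RK4 with the (ode_fn, initial_time, initial_state,
    solution_times) signature.

    Returns one state per entry of solution_times, i.e. a nested list laid
    out as [len(solution_times), len(initial_state)].
    """
    def axpy(a, x, y):
        # Elementwise a * x + y.
        return [a * xi + yi for xi, yi in zip(x, y)]

    t, state = float(initial_time), list(initial_state)
    results = []
    for t_next in solution_times:
        h = (t_next - t) / steps_per_interval
        for _ in range(steps_per_interval):
            k1 = ode_fn(t, state)
            k2 = ode_fn(t + h / 2, axpy(h / 2, k1, state))
            k3 = ode_fn(t + h / 2, axpy(h / 2, k2, state))
            k4 = ode_fn(t + h, axpy(h, k3, state))
            state = [s + h / 6 * (a + 2 * b + 2 * c + d)
                     for s, a, b, c, d in zip(state, k1, k2, k3, k4)]
            t += h
        t = t_next  # reset to avoid accumulated floating-point drift
        results.append(list(state))
    return results

# Hypothetical linear dynamics dz/dt = -z with exact solution z0 * exp(-t).
def ode_fn(t, state):
    return [-s for s in state]

solution = rk4_ode_solve_fn(ode_fn, 0.0, [1.0, 2.0], [1.0])
y = solution[-1]  # state at the single requested solution time

# Integrating backward from t=1 to t=0 recovers the initial state.
x_recovered = rk4_ode_solve_fn(ode_fn, 1.0, y, [0.0])[-1]
```

The step size is negative in the backward call, so the same routine serves for both directions of the flow.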

The trace estimate is computed simultaneously with state_time_derivative using augmented_state_time_derivative_fn, which is generated by trace_augmentation_fn. trace_augmentation_fn takes state_time_derivative_fn, state.shape and state.dtype as arguments and returns an augmented_state_time_derivative_fn callable that computes both state_time_derivative and the unreduced trace_estimation.
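
The following sketch illustrates the idea behind a Hutchinson-style trace_augmentation_fn in plain Python. It is not the library's trace_jacobian_hutchinson: the Rademacher probe vector, the central finite difference used for the Jacobian-vector product, the `seed` and `h` parameters, and the helper names are assumptions for illustration. For a probe eps with independent +/-1 entries, the expectation of eps^T J eps equals Tr(J), and the elementwise products eps_i * (J eps)_i are the unreduced estimate.

```python
import itertools
import math
import random

def make_augmented_ode_fn(ode_fn, eps, h=1e-4):
    def augmented_ode_fn(time, state):
        # Central finite difference for the Jacobian-vector product J @ eps.
        plus = ode_fn(time, [s + h * e for s, e in zip(state, eps)])
        minus = ode_fn(time, [s - h * e for s, e in zip(state, eps)])
        jvp = [(p - m) / (2 * h) for p, m in zip(plus, minus)]
        # Unreduced estimate: summing these entries gives eps^T J eps.
        unreduced_trace = [e * j for e, j in zip(eps, jvp)]
        return ode_fn(time, state), unreduced_trace
    return augmented_ode_fn

def hutchinson_trace_augmentation_fn(ode_fn, state_shape, state_dtype, seed=0):
    # state_shape is assumed to be a one-element tuple (n,); state_dtype is
    # ignored because states here are flat lists of Python floats.
    rng = random.Random(seed)
    eps = [rng.choice((-1.0, 1.0)) for _ in range(state_shape[0])]
    return make_augmented_ode_fn(ode_fn, eps)

# Hypothetical two-dimensional dynamics with a known Jacobian trace.
def ode_fn(time, state):
    z0, z1 = state
    return [math.tanh(z0 + 2.0 * z1), z0 * z0 - 0.5 * z1]

state = [0.1, -0.2]
exact_trace = (1.0 - math.tanh(state[0] + 2.0 * state[1]) ** 2) - 0.5

# Averaging over all four sign patterns in 2-D removes the off-diagonal terms.
estimates = []
for eps in itertools.product((-1.0, 1.0), repeat=2):
    _, unreduced = make_augmented_ode_fn(ode_fn, list(eps))(0.0, state)
    estimates.append(sum(unreduced))
mean_estimate = sum(estimates) / len(estimates)

augmented = hutchinson_trace_augmentation_fn(ode_fn, (2,), float)
derivative, unreduced_trace = augmented(0.0, state)
```

A single probe gives a noisy estimate; in practice an automatic-differentiation product would replace the finite difference.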

Custom ode_solve_fn and trace_augmentation_fn examples:

# custom_solver_fn: `callable(f, t_initial, t_solutions, y_initial, ...)`
# custom_solver_kwargs: Additional arguments to pass to custom_solver_fn.
def ode_solve_fn(ode_fn, initial_time, initial_state, solution_times):
  results = custom_solver_fn(ode_fn, initial_time, solution_times,
                             initial_state, **custom_solver_kwargs)
  return results

ffjord = tfb.FFJORD(state_time_derivative_fn, ode_solve_fn=ode_solve_fn)

# state_time_derivative_fn: `callable(time, state)`
# trace_jac_fn: `callable(time, state)` unreduced jacobian trace function
def trace_augmentation_fn(ode_fn, state_shape, state_dtype):
  def augmented_ode_fn(time, state):
    return ode_fn(time, state), trace_jac_fn(time, state)
  return augmented_ode_fn

ffjord = tfb.FFJORD(state_time_derivative_fn,
                    trace_augmentation_fn=trace_augmentation_fn)

For more details on FFJORD and continuous normalizing flows see [1], [2].

Usage example:

tfd = tfp.distributions
tfb = tfp.bijectors
# state_time_derivative_fn: `Callable(time, state)` -> state_time_derivative
# e.g. Neural network with inputs and outputs of the same shapes and dtypes.

bijector = tfb.FFJORD(state_time_derivative_fn=state_time_derivative_fn)
y = bijector.forward(x)  # forward mapping
x = bijector.inverse(y)  # inverse mapping
base = tfd.Normal(tf.zeros_like(x), tf.ones_like(x))  # Base distribution
transformed_distribution = tfd.TransformedDistribution(base, bijector)
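
What the transformed distribution computes can be sketched with the change-of-variables formula log p_Y(y) = log p_X(g^{-1}(y)) + inverse_log_det_jacobian(y). The example below is plain Python rather than TFP code and assumes a hypothetical scalar flow generated by dz/dt = tanh(z) over unit time, for which the inverse is available in closed form; the function names mirror the bijector methods only for readability.

```python
import math

def inverse(y, elapsed=1.0):
    # Closed-form inverse of the flow generated by dz/dt = tanh(z).
    return math.asinh(math.sinh(y) * math.exp(-elapsed))

def inverse_log_det_jacobian(y, elapsed=1.0):
    x = inverse(y, elapsed)
    # dx/dy = exp(-elapsed) * cosh(y) / cosh(x)
    return -elapsed + math.log(math.cosh(y)) - math.log(math.cosh(x))

def base_log_prob(x):
    # Standard normal log-density.
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def transformed_log_prob(y):
    return base_log_prob(inverse(y)) + inverse_log_det_jacobian(y)

# The transformed density should still integrate to one (Riemann sum on a
# fine grid; the density is negligible at the interval ends).
lo, hi, n = -12.0, 12.0, 4000
h = (hi - lo) / n
total = h * sum(math.exp(transformed_log_prob(lo + i * h)) for i in range(n + 1))
```

The Jacobian term is what keeps the density normalized after the base samples are pushed through the flow.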

References

[1]: Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in Neural Information Processing Systems (pp. 6571-6583).

[2]: Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., & Duvenaud, D. (2018). FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367. arxiv.org/abs/1810.01367

Args:

  • state_time_derivative_fn: Python callable taking arguments time (a scalar representing time) and state (a Tensor representing the state at the given time) and returning the time derivative of the state at the given time.
  • ode_solve_fn: Python callable taking arguments ode_fn (same as state_time_derivative_fn above), initial_time (a scalar representing the initial time of integration), initial_state (a Tensor of floating dtype representing the initial state) and solution_times (1D Tensor of floating dtype representing the times at which to obtain the solution) and returning a Tensor of shape [time_axis, initial_state.shape]. Will take [final_time] as the solution_times argument and state_time_derivative_fn as the ode_fn argument. For details on providing a custom ode_solve_fn see the class docstring. If None, a DormandPrince solver from tfp.math.ode is used. Default value: None
  • trace_augmentation_fn: Python callable taking arguments ode_fn (Python callable, same as state_time_derivative_fn above), state_shape (TensorShape of the state) and dtype (same as the dtype of the state) and returning a Python callable taking arguments time (a scalar representing the time at which the function is evaluated) and state (a Tensor representing the state at the given time) that computes a tuple (ode_fn(time, state), jacobian_trace_estimation). jacobian_trace_estimation should represent the trace of the Jacobian of ode_fn with respect to state. state_time_derivative_fn will be passed as the ode_fn argument. For details on providing a custom trace_augmentation_fn see the class docstring. Default value: tfp.bijectors.ffjord.trace_jacobian_hutchinson
  • initial_time: Scalar float representing the time to which the x value of the bijector corresponds. Passed as initial_time to ode_solve_fn. For the default solver, can be a Python float or a floating scalar Tensor. Default value: 0.
  • final_time: Scalar float representing the time to which the y value of the bijector corresponds. Passed as solution_times to ode_solve_fn. For the default solver, can be a Python float or a floating scalar Tensor. Default value: 1.
  • validate_args: Python 'bool' indicating whether to validate input. Default value: False
  • dtype: tf.DType to prefer when converting args to Tensors. Else, we fall back to a common dtype inferred from the args, finally falling back to float32.
  • name: Python str name prefixed to Ops created by this function.

Attributes:

  • dtype: dtype of Tensors transformable by this bijector.
  • forward_min_event_ndims: Returns the minimal number of dimensions bijector.forward operates on.
  • graph_parents: Returns this Bijector's graph_parents as a Python list.
  • inverse_min_event_ndims: Returns the minimal number of dimensions bijector.inverse operates on.
  • is_constant_jacobian: Returns true iff the Jacobian matrix is not a function of x.
  • name: Returns the string name of this Bijector.
  • trainable_variables
  • validate_args: Returns True if Tensor arguments will be validated.
  • variables

Methods

__call__

__call__(
    value, name=None, **kwargs
)

Applies or composes the Bijector, depending on input type.

This is a convenience function which applies the Bijector instance in three different ways, depending on the input:

  1. If the input is a tfd.Distribution instance, return tfd.TransformedDistribution(distribution=input, bijector=self).
  2. If the input is a tfb.Bijector instance, return tfb.Chain([self, input]).
  3. Otherwise, return self.forward(input)

Args:

  • value: A tfd.Distribution, tfb.Bijector, or a Tensor.
  • name: Python str name given to ops created by this function.
  • **kwargs: Additional keyword arguments passed into the created tfd.TransformedDistribution, tfb.Bijector, or self.forward.

Returns:

  • composition: A tfd.TransformedDistribution if the input was a tfd.Distribution, a tfb.Chain if the input was a tfb.Bijector, or a Tensor computed by self.forward.

Examples

sigmoid = tfb.Reciprocal()(
    tfb.AffineScalar(shift=1.)(
      tfb.Exp()(
        tfb.AffineScalar(scale=-1.))))
# ==> `tfb.Chain([
#         tfb.Reciprocal(),
#         tfb.AffineScalar(shift=1.),
#         tfb.Exp(),
#         tfb.AffineScalar(scale=-1.),
#      ])`  # ie, `tfb.Sigmoid()`

log_normal = tfb.Exp()(tfd.Normal(0, 1))
# ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

tfb.Exp()([-1., 0., 1.])
# ==> tf.exp([-1., 0., 1.])
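
The ordering in the composed sigmoid example can be verified with ordinary functions. This is a plain-Python sketch, not TFP code; the `chain` helper is a hypothetical stand-in that mirrors how tfb.Chain applies its last element first.

```python
import math

def chain(functions):
    """Composes functions right to left, mirroring the ordering of tfb.Chain."""
    def composed(x):
        for fn in reversed(functions):
            x = fn(x)
        return x
    return composed

sigmoid = chain([
    lambda v: 1.0 / v,    # stands in for Reciprocal
    lambda v: v + 1.0,    # stands in for AffineScalar(shift=1.)
    math.exp,             # stands in for Exp
    lambda v: -1.0 * v,   # stands in for AffineScalar(scale=-1.)
])

values = [sigmoid(x) for x in (-1.0, 0.0, 1.0)]
```

Reading the list from bottom to top gives 1 / (1 + exp(-x)), which is why the composition is equivalent to a sigmoid.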

forward

forward(
    x, name='forward', **kwargs
)

Returns the forward Bijector evaluation, i.e., Y = g(X).

Args:

  • x: Tensor. The input to the 'forward' evaluation.
  • name: The name to give this op.
  • **kwargs: Named arguments forwarded to subclass implementation.

Returns:

Tensor.

Raises:

  • TypeError: if self.dtype is specified and x.dtype is not self.dtype.
  • NotImplementedError: if _forward is not implemented.

forward_event_shape

forward_event_shape(
    input_shape
)

Shape of a single sample from a single batch as a TensorShape.

Same meaning as forward_event_shape_tensor. May be only partially defined.

Args:

  • input_shape: TensorShape indicating event-portion shape passed into forward function.

Returns:

  • forward_event_shape_tensor: TensorShape indicating event-portion shape after applying forward. Possibly unknown.

forward_event_shape_tensor

forward_event_shape_tensor(
    input_shape, name='forward_event_shape_tensor'
)

Shape of a single sample from a single batch as an int32 1D Tensor.

Args:

  • input_shape: Tensor, int32 vector indicating event-portion shape passed into forward function.
  • name: name to give to the op

Returns:

  • forward_event_shape_tensor: Tensor, int32 vector indicating event-portion shape after applying forward.

forward_log_det_jacobian

forward_log_det_jacobian(
    x, event_ndims, name='forward_log_det_jacobian', **kwargs
)

Returns the forward_log_det_jacobian.

Args:

  • x: Tensor. The input to the 'forward' Jacobian determinant evaluation.
  • event_ndims: Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.forward_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. the result has rank(x) - event_ndims dimensions.
  • name: The name to give this op.
  • **kwargs: Named arguments forwarded to subclass implementation.

Returns:

Tensor, if this bijector is injective. If not injective this is not implemented.

Raises:

  • TypeError: if self.dtype is specified and x.dtype is not self.dtype.
  • NotImplementedError: if neither _forward_log_det_jacobian nor {_inverse, _inverse_log_det_jacobian} are implemented, or this is a non-injective bijector.
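
The effect of event_ndims on the result shape can be illustrated with a toy elementwise bijector. The sketch below is plain Python, not TFP code: it assumes y = exp(x), whose elementwise log-det-Jacobian is x itself, uses nested lists in place of Tensors, and handles only event_ndims of 0 or 1.

```python
def exp_forward_log_det_jacobian(x, event_ndims):
    """Toy forward_log_det_jacobian for y = exp(x) on a 2-D nested list.

    Elementwise, log|d exp(x)/dx| = x. The last `event_ndims` dimensions are
    summed out, so the result has rank(x) - event_ndims dimensions.
    """
    if event_ndims == 0:
        return [list(row) for row in x]
    if event_ndims == 1:
        return [sum(row) for row in x]
    raise ValueError("this sketch only handles event_ndims in {0, 1}")

x = [[0.0, 1.0, 2.0],
     [0.5, 0.5, 0.5]]

per_element = exp_forward_log_det_jacobian(x, event_ndims=0)  # 2 x 3 values
per_event = exp_forward_log_det_jacobian(x, event_ndims=1)    # [3.0, 1.5]
```

With event_ndims=1 each row is treated as one event, so the two rows reduce to two scalars.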

inverse

inverse(
    y, name='inverse', **kwargs
)

Returns the inverse Bijector evaluation, i.e., X = g^{-1}(Y).

Args:

  • y: Tensor. The input to the 'inverse' evaluation.
  • name: The name to give this op.
  • **kwargs: Named arguments forwarded to subclass implementation.

Returns:

Tensor, if this bijector is injective. If not injective, returns the k-tuple containing the unique k points (x1, ..., xk) such that g(xi) = y.

Raises:

  • TypeError: if self.dtype is specified and y.dtype is not self.dtype.
  • NotImplementedError: if _inverse is not implemented.

inverse_event_shape

inverse_event_shape(
    output_shape
)

Shape of a single sample from a single batch as a TensorShape.

Same meaning as inverse_event_shape_tensor. May be only partially defined.

Args:

  • output_shape: TensorShape indicating event-portion shape passed into inverse function.

Returns:

  • inverse_event_shape_tensor: TensorShape indicating event-portion shape after applying inverse. Possibly unknown.

inverse_event_shape_tensor

inverse_event_shape_tensor(
    output_shape, name='inverse_event_shape_tensor'
)

Shape of a single sample from a single batch as an int32 1D Tensor.

Args:

  • output_shape: Tensor, int32 vector indicating event-portion shape passed into inverse function.
  • name: name to give to the op

Returns:

  • inverse_event_shape_tensor: Tensor, int32 vector indicating event-portion shape after applying inverse.

inverse_log_det_jacobian

inverse_log_det_jacobian(
    y, event_ndims, name='inverse_log_det_jacobian', **kwargs
)

Returns the (log o det o Jacobian o inverse)(y).

Mathematically, returns: log(det(dX/dY))(Y). (Recall that: X=g^{-1}(Y).)

Note that forward_log_det_jacobian is the negative of this function, evaluated at g^{-1}(y).
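
This sign relation is easy to confirm numerically. The sketch below is plain Python rather than TFP code and assumes the scalar map g(x) = exp(x) purely for illustration.

```python
import math

def forward(x):
    return math.exp(x)

def forward_log_det_jacobian(x):
    # log|d exp(x)/dx| = x
    return x

def inverse_log_det_jacobian(y):
    # log|d log(y)/dy| = -log(y)
    return -math.log(y)

x = 0.7
y = forward(x)
fldj = forward_log_det_jacobian(x)
ildj = inverse_log_det_jacobian(y)
# fldj and ildj sum to zero up to floating-point rounding.
```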

Args:

  • y: Tensor. The input to the 'inverse' Jacobian determinant evaluation.
  • event_ndims: Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.inverse_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. the result has rank(y) - event_ndims dimensions.
  • name: The name to give this op.
  • **kwargs: Named arguments forwarded to subclass implementation.

Returns:

  • ildj: Tensor, if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, log(det(Dg_i^{-1}(y))), where g_i is the restriction of g to the ith partition Di.

Raises:

  • TypeError: if self.dtype is specified and y.dtype is not self.dtype.
  • NotImplementedError: if _inverse_log_det_jacobian is not implemented.