


Class Solver

Base class for an ODE solver.


__init__


Initialize self. See help(type(self)) for accurate signature.



solve


Solves an initial value problem.

An initial value problem consists of a system of ODEs and an initial condition:

dy/dt(t) = ode_fn(t, y(t))
y(initial_time) = initial_state

Here, t (also called time) is a scalar float Tensor and y(t) (also called the state at time t) is an N-D float or complex Tensor.
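To make the problem statement concrete, here is a minimal fixed-step solver for dy/dt(t) = ode_fn(t, y(t)) with y(initial_time) = initial_state. It is a sketch under stated assumptions: the hypothetical rk4_solve function and its num_steps parameter are not part of TFP, it uses the classical fourth-order Runge-Kutta method rather than any tfp.math.ode algorithm, and it represents the state as a flat Python list of floats instead of an N-D Tensor.

```python
import math


def rk4_solve(ode_fn, initial_time, initial_state, solution_times, num_steps=100):
    """Hypothetical fixed-step RK4 solver (not part of TFP).

    Integrates dy/dt = ode_fn(t, y) from initial_time, where y is a list of
    floats, and returns the state at each time in solution_times.
    """
    def axpy(a, x, y):
        # Element-wise y + a * x.
        return [yi + a * xi for xi, yi in zip(x, y)]

    t, y = initial_time, list(initial_state)
    states = []
    for t_out in solution_times:
        h = (t_out - t) / num_steps
        for _ in range(num_steps):
            k1 = ode_fn(t, y)
            k2 = ode_fn(t + h / 2, axpy(h / 2, k1, y))
            k3 = ode_fn(t + h / 2, axpy(h / 2, k2, y))
            k4 = ode_fn(t + h, axpy(h, k3, y))
            y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
                 for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
            t += h
        t = t_out  # snap to the requested time to avoid drift in t
        states.append(y)
    return states


# dy/dt = -y with y(0) = 1 has the exact solution y(t) = exp(-t).
states = rk4_solve(lambda t, y: [-yi for yi in y], 0.0, [1.0], [0.5, 1.0])
print(states[1][0], math.exp(-1.0))
```

Fixed steps keep the sketch short; practical solvers typically adapt the step size to an error tolerance instead.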


The ODE dy/dt(t) = dot(A, y(t)) is solved below.

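import tensorflow as tf
import tensorflow_probability as tfp

t_init, t0, t1 = 0., 0.5, 1.
y_init = tf.constant([1., 1.], dtype=tf.float64)
# A must have the same dtype as y_init; a plain Python list would become float32.
A = tf.constant([[-1., -2.], [-3., -4.]], dtype=tf.float64)

def ode_fn(t, y):
  return tf.linalg.matvec(A, y)

# solve is an instance method, so the solver is constructed first.
results = tfp.math.ode.BDF().solve(ode_fn, t_init, y_init,
                                   solution_times=[t0, t1])
y0 = results.states[0]  # == dot(matrix_exp(A * t0), y_init)
y1 = results.states[1]  # == dot(matrix_exp(A * t1), y_init)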
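The comments on y0 and y1 rely on the closed-form solution y(t) = dot(matrix_exp(A * (t - t_init)), y_init) of this linear ODE. The plain-Python sketch below checks that claim numerically without TensorFlow. The helpers matmul, matvec, matrix_exp and y_exact are hypothetical, and matrix_exp is a truncated Taylor series, which is assumed adequate only because A * t is small here; robust implementations typically use scaling and squaring.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matvec(a, x):
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def matrix_exp(a, num_terms=60):
    """Truncated Taylor series sum_k a^k / k!; adequate only for small ||a||."""
    n = len(a)
    result = [[float(i == j) for j in range(n)] for i in range(n)]
    term = [row[:] for row in result]
    for k in range(1, num_terms):
        # term_k = term_{k-1} @ a / k, i.e. a^k / k!
        term = [[v / k for v in row] for row in matmul(term, a)]
        result = [[r + s for r, s in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


A = [[-1., -2.], [-3., -4.]]
y_init = [1., 1.]
t_init, t0, t1 = 0., 0.5, 1.


def y_exact(t):
    # Closed form of dy/dt = dot(A, y) with y(t_init) = y_init.
    dt = t - t_init
    return matvec(matrix_exp([[aij * dt for aij in row] for row in A]), y_init)


y0, y1 = y_exact(t0), y_exact(t1)

# Check that y_exact satisfies the ODE at t1 using a central difference.
h = 1e-5
dydt = [(p - m) / (2 * h) for p, m in zip(y_exact(t1 + h), y_exact(t1 - h))]
residual = max(abs(d - f) for d, f in zip(dydt, matvec(A, y1)))
print(y0, y1, residual)
```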

Passing solution_times=tfp.math.ode.ChosenBySolver(final_time=1.) instead yields the state at times between t_init and final_time that the solver chooses automatically. In that case, results.states[i] is the state at time results.times[i].
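One way a solver can pick its own output times is adaptive step-size control: each accepted step becomes a stored (time, state) pair. The sketch below assumes an embedded Heun (order 2) / Euler (order 1) pair with a simple step-size controller; the adaptive_solve function and its rtol, atol and first_step parameters are hypothetical, and this is not the error control that tfp.math.ode.BDF actually uses.

```python
import math


def adaptive_solve(ode_fn, initial_time, initial_state, final_time,
                   rtol=1e-6, atol=1e-9, first_step=0.1):
    """Hypothetical adaptive solver returning (times, states) it chose itself.

    The difference between the Heun and Euler updates estimates the local
    error of each step; steps with a scaled error above 1 are rejected.
    """
    t, y, h = initial_time, list(initial_state), first_step
    times, states = [t], [y]
    while t < final_time:
        last = h >= final_time - t
        if last:
            h = final_time - t  # shorten the step to land exactly on final_time
        k1 = ode_fn(t, y)
        y_euler = [yi + h * k for yi, k in zip(y, k1)]
        k2 = ode_fn(t + h, y_euler)
        y_heun = [yi + h / 2 * (a + b) for yi, a, b in zip(y, k1, k2)]
        err = max(abs(yh - ye) / (atol + rtol * max(abs(yi), abs(yh)))
                  for yi, yh, ye in zip(y, y_heun, y_euler))
        if err <= 1.0:
            # Accept: the solver, not the caller, decides this output time.
            t = final_time if last else min(t + h, final_time)
            y = y_heun
            times.append(t)
            states.append(y)
        # First-order error estimate, so scale h by err ** (-1/2), with a safety
        # factor of 0.9 and the change clamped to [0.2, 5].
        h *= min(5.0, max(0.2, 0.9 * max(err, 1e-12) ** -0.5))
    return times, states


times, states = adaptive_solve(lambda t, y: [-v for v in y], 0.0, [1.0], 1.0)
print(len(times), times[-1], states[-1][0], math.exp(-1.0))
```

As with results.times and results.states, states[i] here is the state at times[i].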


The gradient of the result is computed using the adjoint sensitivity method described in Chen et al. (2018) [1].

# tf.gradients requires graph mode; in eager mode, use tf.GradientTape instead.
grad = tf.gradients(y1, y_init)  # == dot(e, J)
# J is the Jacobian of y1 with respect to y_init. In this case, J = matrix_exp(A * t1).
# e = [1, ..., 1] is the row vector of ones.
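The adjoint sensitivity method can be sketched for this linear example in plain Python. Summing the outputs (which is what a gradient of the vector y1 does, giving the row vector e above), the loss is L = dot(e, y(t1)), and the adjoint a(t) = dL/dy(t) satisfies da/dt = -dot(a, df/dy) = -dot(a, A) with a(t1) = e. Integrating it backward from t1 to t_init gives the gradient of L with respect to the initial state, dot(e, J). The sketch assumes a hypothetical fixed-step rk4 helper for both passes (Chen et al. use an adaptive solver and integrate the adjoint alongside the state) and checks the result against central finite differences.

```python
def rk4(f, t_start, y_start, t_end, num_steps=1000):
    """Hypothetical fixed-step RK4 helper; t_end may be less than t_start."""
    h = (t_end - t_start) / num_steps
    t, y = t_start, list(y_start)
    for _ in range(num_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)])
        k3 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)])
        k4 = f(t + h, [yi + h * k for yi, k in zip(y, k3)])
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y


A = [[-1., -2.], [-3., -4.]]
y_init = [1., 1.]
t_init, t1 = 0., 1.


def ode_fn(t, y):
    # dy/dt = dot(A, y)
    return [sum(aij * yj for aij, yj in zip(row, y)) for row in A]


def adjoint_fn(t, a):
    # da/dt = -dot(a, A), where A is the Jacobian df/dy of this linear ODE.
    return [-sum(a[i] * A[i][j] for i in range(len(a))) for j in range(len(a))]


# Backward pass: start from a(t1) = e = [1, ..., 1] and integrate to t_init.
grad_adjoint = rk4(adjoint_fn, t1, [1.0] * len(y_init), t_init)


def loss(y_start):
    # L = sum of the entries of y(t1), starting from y(t_init) = y_start.
    return sum(rk4(ode_fn, t_init, y_start, t1))


# Reference gradient from central finite differences.
eps = 1e-6
grad_fd = []
for j in range(len(y_init)):
    plus = list(y_init)
    plus[j] += eps
    minus = list(y_init)
    minus[j] -= eps
    grad_fd.append((loss(plus) - loss(minus)) / (2 * eps))

print(grad_adjoint, grad_fd)
```

The backward pass costs one extra ODE solve regardless of the number of inputs, whereas finite differences need two forward solves per entry of the initial state.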


[1]: Chen, Tian Qi, et al. "Neural ordinary differential equations." Advances in Neural Information Processing Systems. 2018.


Args:

  • ode_fn: Function of the form ode_fn(t, y). The input t is a scalar float Tensor. The input y and output are both Tensors with the same shape and dtype as initial_state.
  • initial_time: Scalar float Tensor specifying the initial time.
  • initial_state: N-D float or complex Tensor specifying the initial state. The dtype of initial_state must be complex for problems with complex-valued states (even if the initial state is real).
  • solution_times: 1-D float Tensor specifying a list of times. The solver stores the computed state at each of these times in the returned Results object. Must satisfy initial_time <= solution_times[0] and solution_times[i] < solution_times[i+1]. Alternatively, the user can pass tfp.math.ode.ChosenBySolver(final_time) where final_time is a scalar float Tensor satisfying initial_time < final_time. Doing so requests that the solver automatically choose suitable times up to and including final_time at which to store the computed state.
  • jacobian_fn: Optional function of the form jacobian_fn(t, y). The input t is a scalar float Tensor. The input y has the same shape and dtype as initial_state. The output is a 2N-D Tensor whose shape is initial_state.shape + initial_state.shape and whose dtype is the same as initial_state. In particular, the (i1, ..., iN, j1, ..., jN)-th entry of jacobian_fn(t, y) is the derivative of the (i1, ..., iN)-th entry of ode_fn(t, y) with respect to the (j1, ..., jN)-th entry of y. If this argument is left unspecified, the solver automatically computes the Jacobian if and when it is needed. Default value: None.
  • jacobian_sparsity: Optional 2N-D boolean Tensor whose shape is initial_state.shape + initial_state.shape specifying the sparsity pattern of the Jacobian. This argument is ignored if jacobian_fn is specified. Default value: None.
  • batch_ndims: Optional nonnegative integer. When specified, the first batch_ndims dimensions of initial_state are batch dimensions. Default value: None.
  • previous_solver_internal_state: Optional solver-specific argument used to warm-start this invocation of solve. Default value: None.
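The index convention for jacobian_fn, and how a jacobian_sparsity pattern lines up with it, can be illustrated with a finite-difference Jacobian for a state of shape (2, 2), so that the Jacobian has shape (2, 2, 2, 2). This plain-Python sketch uses hypothetical helpers (zeros, get_item, set_item, fd_jacobian) on nested lists instead of Tensors; it is not how TFP computes Jacobians, and it assumes, as the name suggests, that True in the sparsity pattern marks entries that may be nonzero.

```python
import copy
import itertools


def zeros(shape):
    # Nested list of zeros with the given shape tuple.
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def get_item(x, index):
    for i in index:
        x = x[i]
    return x


def set_item(x, index, value):
    for i in index[:-1]:
        x = x[i]
    x[index[-1]] = value


def fd_jacobian(ode_fn, t, y, shape, eps=1e-6):
    """Forward-difference Jacobian with shape `shape + shape`.

    Entry (i1, ..., iN, j1, ..., jN) approximates the derivative of entry
    (i1, ..., iN) of ode_fn(t, y) with respect to entry (j1, ..., jN) of y.
    """
    jac = zeros(shape + shape)
    indices = list(itertools.product(*(range(n) for n in shape)))
    f0 = ode_fn(t, y)
    for j in indices:
        y_pert = copy.deepcopy(y)
        set_item(y_pert, j, get_item(y, j) + eps)
        f1 = ode_fn(t, y_pert)
        for i in indices:
            set_item(jac, i + j, (get_item(f1, i) - get_item(f0, i)) / eps)
    return jac


def ode_fn(t, y):
    # Element-wise dy/dt = -y**2: each output entry depends only on the
    # matching entry of y, so the Jacobian is nonzero only where i == j.
    return [[-v * v for v in row] for row in y]


shape = (2, 2)
y = [[1.0, 2.0], [3.0, 4.0]]
jac = fd_jacobian(ode_fn, 0.0, y, shape)

# Boolean pattern with the same (2, 2, 2, 2) layout as the Jacobian.
indices = list(itertools.product(*(range(n) for n in shape)))
sparsity = zeros(shape + shape)
for i in indices:
    for j in indices:
        set_item(sparsity, i + j, get_item(jac, i + j) != 0.0)

print(get_item(jac, (1, 0, 1, 0)))  # approximately -2 * y[1][0] = -6.0
```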


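Returns:

An object of type Results. Its states field holds the computed states, and its times field holds the times at which they were stored.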