tfp.experimental.sequential.extended_kalman_filter

Applies an Extended Kalman Filter to observed data.

The Extended Kalman Filter is a nonlinear version of the Kalman filter, in which the transition and observation functions are linearized by a first-order Taylor expansion around the current mean of the state estimate.
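For orientation, the recursion can be written in a common textbook form. The symbols below are notation assumed here, not taken from the TFP source: f and h denote the means of transition_fn and observation_fn, Q_t and R_t their covariances, F_t and H_t the matrices returned by transition_jacobian_fn and observation_jacobian_fn, z_t the observation at step t, and m and P the mean and covariance of the state estimate. TFP's implementation may differ in detail, for example in how the first timestep is handled.

```latex
\begin{aligned}
\text{Predict:}\quad
  m_{t|t-1} &= f(m_{t-1|t-1}), \\
  P_{t|t-1} &= F_t \, P_{t-1|t-1} \, F_t^{\top} + Q_t,
  \qquad F_t = \left.\frac{\partial f}{\partial x}\right|_{x = m_{t-1|t-1}}, \\[4pt]
\text{Update:}\quad
  S_t &= H_t \, P_{t|t-1} \, H_t^{\top} + R_t,
  \qquad H_t = \left.\frac{\partial h}{\partial x}\right|_{x = m_{t|t-1}}, \\
  K_t &= P_{t|t-1} \, H_t^{\top} \, S_t^{-1}, \\
  m_{t|t} &= m_{t|t-1} + K_t \left( z_t - h(m_{t|t-1}) \right), \\
  P_{t|t} &= (I - K_t H_t) \, P_{t|t-1}.
\end{aligned}
```

In this notation, m_{t|t} and P_{t|t} play the role of the filtered mean and covariance, and m_{t|t-1} and P_{t|t-1} the role of the predicted mean and covariance.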

Args

observations: a (structure of) Tensors, each of shape concat([[num_timesteps, b1, ..., bN], [event_size]]), with scalar event_size and optional batch dimensions b1, ..., bN.
initial_state_prior: a tfd.Distribution instance (typically MultivariateNormal) with event_shape equal to state_size and an optional batch_shape of [b1, ..., bN], representing the prior over the state.
transition_fn: a Python callable that accepts a (batched) vector of length state_size and returns a tfd.Distribution instance, typically a MultivariateNormal, representing the state transition (its mean and covariance).
observation_fn: a Python callable that accepts a (batched) vector of length state_size and returns a tfd.Distribution instance, typically a MultivariateNormal, representing the observation model (its mean and covariance).
transition_jacobian_fn: a Python callable that accepts a (batched) vector of length state_size and returns a (batched) matrix of shape [state_size, state_size], representing the Jacobian of transition_fn.
observation_jacobian_fn: a Python callable that accepts a (batched) vector of length state_size and returns a (batched) matrix of shape [event_size, state_size], representing the Jacobian of observation_fn.
name: Python str name for ops created by this method. Default value: None (i.e., 'extended_kalman_filter').

Returns

filtered_mean: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [state_size]]). The mean of the filtered state estimate.
filtered_cov: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]). The covariance of the filtered state estimate.
predicted_mean: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [state_size]]). The prior predicted mean of the state.
predicted_cov: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]). The prior predicted covariance of the state estimate.
observation_mean: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [event_size]]). The prior predicted mean of the observations.
observation_cov: a (structure of) Tensor(s) of shape concat([[num_timesteps, b1, ..., bN], [event_size, event_size]]). The prior predicted covariance of the observations.
log_marginal_likelihood: a (structure of) Tensor(s) of shape [num_timesteps, b1, ..., bN]. The log marginal likelihood of the observation at each timestep.
timestep: a (structure of) integer Tensor(s) of shape [num_timesteps, b1, ..., bN] containing time indices.

Examples

Estimate a simple nonlinear system. Consider a system defined by the transition equations y_{t+1} = y_t - 0.1 * w_t**3 and w_{t+1} = w_t, so that the state can be expressed as [y, w]. The transition_fn and transition_jacobian_fn can be written as:

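import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def transition_fn(x):
  return tfd.MultivariateNormalDiag(
      tf.stack(
          [x[..., 0] - 0.1 * x[..., 1]**3, x[..., 1]], axis=-1),
      scale_diag=[0.7, 0.2])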

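def transition_jacobian_fn(x):
  # Jacobian of the transition mean [y - 0.1 * w**3, w] with respect to [y, w]:
  #   [[1, -0.3 * w**2],
  #    [0,  1         ]]
  w = x[..., 1]
  jacobian_entries = tf.stack(
      [tf.ones_like(w), -0.3 * w**2,
       tf.zeros_like(w), tf.ones_like(w)], axis=-1)
  # Reshape the trailing 4 entries into a [2, 2] matrix, keeping batch dims.
  return tf.reshape(
      jacobian_entries, tf.concat([tf.shape(x)[:-1], [2, 2]], axis=0))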
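
A hand-written Jacobian is easy to get wrong, so it can be worth comparing it against finite differences before passing it to the filter. The following is a minimal NumPy-only sketch for the mean of the transition above. The helper names (transition_mean, analytic_jacobian, numerical_jacobian) are hypothetical and not part of TFP.

```python
import numpy as np


def transition_mean(x):
    """Mean of the example transition: [y - 0.1 * w**3, w]."""
    y, w = x
    return np.array([y - 0.1 * w**3, w])


def analytic_jacobian(x):
    """Hand-derived Jacobian of transition_mean with respect to [y, w]."""
    _, w = x
    return np.array([[1.0, -0.3 * w**2],
                     [0.0, 1.0]])


def numerical_jacobian(fn, x, eps=1e-6):
    """Central finite-difference Jacobian; column j holds d fn / d x[j]."""
    x = np.asarray(x, dtype=np.float64)
    cols = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        cols.append((fn(x + step) - fn(x - step)) / (2.0 * eps))
    return np.stack(cols, axis=-1)


# Compare the two Jacobians at an arbitrary test point.
x0 = np.array([0.5, -1.2])
max_abs_err = np.max(
    np.abs(numerical_jacobian(transition_mean, x0) - analytic_jacobian(x0)))
print("max abs difference:", max_abs_err)
```

A difference far above floating-point noise would point to a mistake in the analytic derivative or in the row and column ordering.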

Assume we take noisy measurements of only the first element of the state.

observation_fn = lambda x: tfd.MultivariateNormalDiag(
    x[..., :1], scale_diag=[1.])
observation_jacobian_fn = lambda x: [[1., 0.]]

We define a prior over the initial state and synthesize data for 20 steps of the process, starting from the zero state.

initial_state_prior = tfd.MultivariateNormalDiag(0., scale_diag=[1., 0.3])

x = [np.zeros((2,), dtype=np.float32)]
for t in range(20):
  x.append(transition_fn(x[-1]).sample())
x = tf.stack(x)

observations = observation_fn(x).sample()

Run the Extended Kalman filter on the synthesized observed data.

results = tfp.experimental.sequential.extended_kalman_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    transition_jacobian_fn,
    observation_jacobian_fn)
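
To make the estimated quantities concrete, here is a NumPy-only sketch of a textbook EKF recursion (predict with the transition Jacobian, then update with the observation Jacobian) for the same example system. It is not TFP's implementation: the names f, F, H, Q, R and ekf_step are hypothetical, the random seed is arbitrary, and applying a transition before the first update is an assumed convention.

```python
import numpy as np

rng = np.random.default_rng(0)

Q = np.diag([0.7**2, 0.2**2])  # transition noise covariance (scale_diag squared)
R = np.array([[1.0]])          # observation noise covariance
H = np.array([[1.0, 0.0]])     # observation Jacobian, shape [event_size, state_size]


def f(x):
    """Transition mean for the state [y, w]."""
    return np.array([x[0] - 0.1 * x[1]**3, x[1]])


def F(x):
    """Jacobian of f with respect to [y, w]."""
    return np.array([[1.0, -0.3 * x[1]**2],
                     [0.0, 1.0]])


def ekf_step(mean, cov, z):
    """One predict/update step of a textbook EKF (assumed convention)."""
    # Predict: push the mean through f and the covariance through F.
    jac = F(mean)
    pred_mean = f(mean)
    pred_cov = jac @ cov @ jac.T + Q
    # Update: the observation model is linear here, so h(x) equals H @ x.
    innovation_cov = H @ pred_cov @ H.T + R
    gain = pred_cov @ H.T @ np.linalg.inv(innovation_cov)
    new_mean = pred_mean + gain @ (z - H @ pred_mean)
    new_cov = (np.eye(2) - gain @ H) @ pred_cov
    return new_mean, new_cov


# Simulate 20 steps of the process, observing only the first state element.
x = np.zeros(2)
observations = []
for _ in range(20):
    x = f(x) + rng.normal(size=2) * np.array([0.7, 0.2])
    observations.append(x[:1] + rng.normal(size=1))

# Run the filter, starting from the prior mean and covariance.
mean, cov = np.zeros(2), np.diag([1.0**2, 0.3**2])
filtered_means, filtered_covs = [], []
for z in observations:
    mean, cov = ekf_step(mean, cov, z)
    filtered_means.append(mean)
    filtered_covs.append(cov)

filtered_means = np.stack(filtered_means)  # shape [num_timesteps, state_size]
filtered_covs = np.stack(filtered_covs)    # shape [num_timesteps, state_size, state_size]
print(filtered_means.shape, filtered_covs.shape)
```

The stacked arrays follow the same layout as the filtered_mean and filtered_cov shapes documented above, without batch dimensions.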