Applies an Extended Kalman Filter to observed data.
tfp.experimental.sequential.extended_kalman_filter(
    observations, initial_state_prior, transition_fn, observation_fn,
    transition_jacobian_fn, observation_jacobian_fn, name=None
)
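The Extended Kalman Filter is a nonlinear version of the Kalman filter in which the transition and observation functions are linearized by a first-order Taylor expansion around the current mean of the state estimate. The Jacobians needed for these linearizations are supplied by transition_jacobian_fn and observation_jacobian_fn.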
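To make the linearization concrete, here is a minimal NumPy sketch of one generic EKF predict/update cycle for a single, unbatched state. It assumes deterministic mean functions f and h with additive Gaussian noise of known covariances q and r; the names ekf_step, f, h, q and r are hypothetical. This is a textbook formulation, not TFP's implementation, which may order the predict and update steps or compute the covariance update differently.

```python
import numpy as np


def ekf_step(mean, cov, y, f, f_jac, q, h, h_jac, r):
    """One predict/update cycle of an Extended Kalman Filter (sketch).

    f, h:         deterministic transition and observation mean functions.
    f_jac, h_jac: their Jacobians, shapes [state, state] and [event, state].
    q, r:         additive Gaussian process and observation noise covariances.
    """
    # Predict: push the mean through f and propagate the covariance through
    # the first-order Taylor expansion of f around the current mean.
    F = f_jac(mean)
    pred_mean = f(mean)
    pred_cov = F @ cov @ F.T + q

    # Update: linearize h around the predicted mean.
    H = h_jac(pred_mean)
    obs_mean = h(pred_mean)
    obs_cov = H @ pred_cov @ H.T + r
    # Kalman gain K = P H^T S^{-1}, computed with a solve instead of an
    # explicit inverse (valid because S and P are symmetric).
    gain = np.linalg.solve(obs_cov, H @ pred_cov).T
    filt_mean = pred_mean + gain @ (y - obs_mean)
    filt_cov = (np.eye(mean.shape[-1]) - gain @ H) @ pred_cov
    return filt_mean, filt_cov, pred_mean, pred_cov, obs_mean, obs_cov
```

With linear f and h this reduces to the ordinary Kalman filter step, which is a convenient sanity check for the sketch.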
Args  

observations

a (structure of) Tensors, each of shape
concat([[num_timesteps, b1, ..., bN], [event_size]]) with scalar
event_size and optional batch dimensions b1, ..., bN.
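Note that time is the leading axis. If your data is stored batch-major instead, a minimal NumPy sketch of moving it into the expected layout follows; the random dataset of 3 series, each with 21 steps and event_size 1, is purely hypothetical.

```python
import numpy as np

# Hypothetical dataset stored batch-major: 3 independent series, each with
# 21 timesteps and a length-1 measurement vector per step.
batch_major = np.random.default_rng(0).normal(size=(3, 21, 1)).astype(np.float32)

# Time-major layout [num_timesteps, b1, ..., bN, event_size]: move the time
# axis (axis 1) to the front; the batch and event axes keep their order.
observations = np.moveaxis(batch_major, 1, 0)

num_timesteps, batch_size, event_size = observations.shape
```

With observations shaped like this, initial_state_prior would carry a batch_shape of [3].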

initial_state_prior

a tfd.Distribution instance (typically
MultivariateNormal) with event_shape equal to state_size and an
optional batch_shape of [b1, ..., bN], representing the prior over the
state.

transition_fn

a Python callable that accepts (batched) vectors of length
state_size and returns a tfd.Distribution instance, typically a
MultivariateNormal, representing the distribution over the next state
(the transition mean and process-noise covariance).

observation_fn

a Python callable that accepts a (batched) vector of
length state_size and returns a tfd.Distribution instance, typically
a MultivariateNormal, representing the distribution over the observation
(the observation mean and noise covariance).

transition_jacobian_fn

a Python callable that accepts a (batched) vector
of length state_size and returns a (batched) matrix of shape
[state_size, state_size], representing the Jacobian of the mean of
transition_fn with respect to the state.

observation_jacobian_fn

a Python callable that accepts a (batched) vector
of length state_size and returns a (batched) matrix of shape
[event_size, state_size], representing the Jacobian of the mean of
observation_fn with respect to the state.
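A hand-written Jacobian is easy to get wrong, so it can help to compare it against central finite differences before running the filter. In the NumPy sketch below, numerical_jacobian and observation_mean are hypothetical helpers, not part of TFP. The returned matrix has one row per output and one column per input; for the observation model used in the Examples section (keep only the first state element) it approximately reproduces [[1., 0.]].

```python
import numpy as np


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f at x, shape [len(f(x)), len(x)]."""
    x = np.asarray(x, dtype=np.float64)
    fx = np.asarray(f(x), dtype=np.float64)
    jac = np.zeros((fx.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        # Column j holds the partial derivatives with respect to x[j].
        jac[:, j] = (np.asarray(f(x + step)) - np.asarray(f(x - step))) / (2 * eps)
    return jac


def observation_mean(x):
    # Deterministic mean of the example observation model.
    return x[:1]


# Shape (1, 2); entries approximately [[1., 0.]].
jac = numerical_jacobian(observation_mean, [0.5, -1.2])
```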

name

Python str name for ops created by this method.
Default value: None (i.e., 'extended_kalman_filter').

Returns  

filtered_mean

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [state_size]]). The mean of the
filtered state estimate.

filtered_cov

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]).
The covariance of the filtered state estimate.

predicted_mean

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [state_size]]). The prior
predicted means of the state.

predicted_cov

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]).
The prior predicted covariances of the state estimate.

observation_mean

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [event_size]]). The prior
predicted mean of observations.

observation_cov

a (structure of) Tensor(s) of shape
concat([[num_timesteps, b1, ..., bN], [event_size, event_size]]). The
prior predicted covariance of observations.

log_marginal_likelihood

a (structure of) Tensor(s) of shape
[num_timesteps, b1, ..., bN]. The log likelihood of each step's
observation under the predicted observation distribution, i.e.
conditioned on the observations at earlier timesteps.

timestep

a (structure of) integer Tensor(s) of shape
[num_timesteps, b1, ..., bN] containing time indices.
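In a standard EKF, each per-step log likelihood is the log density of that step's observation under a Gaussian with mean observation_mean and covariance observation_cov; summing over the time axis approximates (and, for a linear-Gaussian model, equals) the log likelihood of the whole sequence. The NumPy sketch below assumes that interpretation and unbatched inputs; gaussian_log_density and total_log_likelihood are hypothetical helper names.

```python
import numpy as np


def gaussian_log_density(y, mean, cov):
    """log N(y; mean, cov) for a single event vector, via a Cholesky factor."""
    diff = np.asarray(y, dtype=np.float64) - np.asarray(mean, dtype=np.float64)
    chol = np.linalg.cholesky(np.asarray(cov, dtype=np.float64))
    # Solve L z = diff, so that z @ z == diff^T cov^{-1} diff.
    z = np.linalg.solve(chol, diff)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    k = diff.size
    return -0.5 * (k * np.log(2.0 * np.pi) + log_det + z @ z)


def total_log_likelihood(observations, observation_mean, observation_cov):
    """Sum of per-step predictive log densities over the leading time axis."""
    per_step = [gaussian_log_density(y, m, c)
                for y, m, c in zip(observations, observation_mean, observation_cov)]
    return np.sum(per_step)
```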

Examples
Estimate a simple nonlinear system. Consider a system defined by the
transition equations y_{t+1} = y_t - 0.1 * w_t**3 and w_{t+1} = w_t
(each with additive Gaussian noise), so that the state can be expressed
as [y, w]. The transition_fn and transition_jacobian_fn can be
expressed as:
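import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def transition_fn(x):
  # Next-state mean [y - 0.1 * w**3, w] with independent Gaussian noise.
  return tfd.MultivariateNormalDiag(
      tf.stack([x[..., 0] - 0.1 * x[..., 1]**3, x[..., 1]], axis=-1),
      scale_diag=[0.7, 0.2])

def transition_jacobian_fn(x):
  # Jacobian of the transition mean: [[1, -0.3 * w**2], [0, 1]].
  # The final reshape to [2, 2] assumes an unbatched state of shape [2].
  return tf.reshape(
      tf.stack(
          [tf.ones(x.shape[:-1]), -0.3 * x[..., 1]**2,
           tf.zeros(x.shape[:-1]), tf.ones(x.shape[:-1])], axis=-1),
      [2, 2])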
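Differentiating the transition mean term by term gives the matrix that transition_jacobian_fn should return; flattening it row by row yields the order of the four entries stacked before the reshape to [2, 2]:

```latex
f(y, w) = \begin{bmatrix} y - 0.1\,w^{3} \\ w \end{bmatrix},
\qquad
J_f(y, w) =
\begin{bmatrix}
  \partial f_1 / \partial y & \partial f_1 / \partial w \\
  \partial f_2 / \partial y & \partial f_2 / \partial w
\end{bmatrix}
=
\begin{bmatrix} 1 & -0.3\,w^{2} \\ 0 & 1 \end{bmatrix}
```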
Assume we take noisy measurements of only the first element of the state.
observation_fn = lambda x: tfd.MultivariateNormalDiag(
    x[..., :1], scale_diag=[1.])
observation_jacobian_fn = lambda x: [[1., 0.]]
We define a prior over the initial state, then synthesize data for 20 steps of the process starting from the zero state.
initial_state_prior = tfd.MultivariateNormalDiag(
    loc=[0., 0.], scale_diag=[1., 0.3])
# Simulate 20 noisy transitions from a zero initial state.
x = [np.zeros((2,), dtype=np.float32)]
for _ in range(20):
  x.append(transition_fn(x[-1]).sample())
x = tf.stack(x)  # shape [21, 2]
observations = observation_fn(x).sample()  # shape [21, 1]
Run the Extended Kalman Filter on the synthesized observations.
results = tfp.experimental.sequential.extended_kalman_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    transition_jacobian_fn,
    observation_jacobian_fn)