tfp.experimental.mcmc.particle_filter

Samples a series of particles representing filtered latent states.

The particle filter samples from the sequence of "filtering" distributions p(state[t] | observations[:t+1]) over latent states: at each point in time, this is the distribution conditioned on all observations up to and including that time. Because particles may be resampled, the particle with index k at time t + 1 is not necessarily descended from the particle with index k at time t. To reconstruct trajectories by tracing back through the resampling process, see tfp.experimental.mcmc.reconstruct_trajectories.
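As a rough illustration of what tracing back through resampling involves, the NumPy sketch below rebuilds full trajectories from the particles and parent_indices outputs, where parent_indices[t, k] is the index at time t - 1 of the parent of the kth particle at time t (as described for the parent_indices return value). The helper name reconstruct_trajectories_np is hypothetical; it assumes unbatched arrays with scalar states and is not the implementation of tfp.experimental.mcmc.reconstruct_trajectories.

```python
import numpy as np


def reconstruct_trajectories_np(particles, parent_indices):
    """Hypothetical helper: trace ancestry backwards through resampling.

    Assumes unbatched arrays with scalar states, both of shape
    [num_timesteps, num_particles]. Column k of the result is the full
    trajectory ending at the kth particle of the final step.
    """
    num_timesteps, num_particles = particles.shape
    trajectories = np.empty_like(particles)
    ancestor = np.arange(num_particles)  # Start from the final particles.
    for t in range(num_timesteps - 1, -1, -1):
        trajectories[t] = particles[t, ancestor]
        # Move each lineage one step back to its parent at time t - 1.
        # (At t == 0 this reads parent_indices[0], but the result is unused.)
        ancestor = parent_indices[t, ancestor]
    return trajectories


particles = np.array([[0.0, 1.0, 2.0],
                      [10.0, 11.0, 12.0],
                      [20.0, 21.0, 22.0]])
parent_indices = np.array([[0, 1, 2],
                           [2, 2, 0],
                           [1, 0, 0]])
print(reconstruct_trajectories_np(particles, parent_indices))
```

In this toy input, the final particle 0 (value 20) descends from particle 1 at step 1 (value 11), which descends from particle 2 at step 0 (value 2), so its reconstructed trajectory is [2, 11, 20].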

Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.

The transition_fn and proposal_fn args, if specified, have signature next_state_dist = fn(step, state), where step is an int Tensor index of the current time step (beginning at zero) and state is the latent state at that step. The return value is a tfd.Distribution instance over the state at step + 1.

Similarly, observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at that step.

Args

observations: a (structure of) Tensor(s), each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]), with optional batch dimensions b1, ..., bN.
initial_state_prior: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN].
transition_fn: callable returning a (joint) distribution over the next latent state.
observation_fn: callable returning a (joint) distribution over the current observation.
num_particles: int Tensor giving the number of particles.
initial_state_proposal: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None.
proposal_fn: callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None.
resample_fn: Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic.
resample_criterion_fn: optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing the log (unnormalized) weights of all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default behavior resamples particles when the current effective sample size falls below half the total number of particles. Note that the resampling criterion is not used at the final step; there, particles are always resampled, so that the returned values are unweighted. Default value: tfp.experimental.mcmc.ess_below_threshold.
rejuvenation_kernel_fn: optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] | y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None.
num_transitions_per_observation: scalar positive int Tensor giving the number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: None.
trace_fn: Python callable defining the values to be traced at each step. It takes a ParticleFilterStepResults tuple and returns a structure of Tensors. The default function returns (particles, log_weights, parent_indices, step_log_likelihood).
step_indices_to_trace: optional int Tensor listing, in increasing order, the indices of steps at which to record the values traced by trace_fn. If None, the default behavior is to trace at every timestep, equivalent to specifying step_indices_to_trace=tf.range(num_timesteps).
seed: Python int seed for random ops.
name: Python str name for ops created by this method. Default value: None (i.e., 'particle_filter').
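To make the default resampling behavior more concrete, here is a minimal NumPy sketch of the effective-sample-size criterion described for resample_criterion_fn, together with systematic resampling of the kind named as the resample_fn default. The helper names ess_below_threshold_np and resample_systematic_np are hypothetical, the inputs are assumed to be a single unbatched vector of num_particles log weights, and the code illustrates the standard techniques rather than reproducing TFP's implementation.

```python
import numpy as np


def _logsumexp(x):
    # Numerically stable log(sum(exp(x))) over a 1-D array.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


def ess_below_threshold_np(log_weights, threshold=0.5):
    """Hypothetical helper: True if the effective sample size (ESS) falls
    below threshold * num_particles. `log_weights` may be unnormalized."""
    log_w = log_weights - _logsumexp(log_weights)  # Normalize the weights.
    ess = 1.0 / np.sum(np.exp(2.0 * log_w))        # ESS = 1 / sum_k w_k**2
    return bool(ess < threshold * log_weights.shape[-1])


def resample_systematic_np(log_weights, rng):
    """Hypothetical helper: systematic resampling. Returns num_particles
    indices, each particle chosen in proportion to its normalized weight."""
    n = log_weights.shape[-1]
    weights = np.exp(log_weights - _logsumexp(log_weights))
    cdf = np.cumsum(weights)
    cdf[-1] = 1.0  # Guard against round-off in the final cumulative sum.
    # One shared uniform offset gives n evenly spaced points in [0, 1).
    points = (rng.uniform() + np.arange(n)) / n
    # Index i is selected for every point in [cdf[i - 1], cdf[i]).
    return np.searchsorted(cdf, points, side="right")


rng = np.random.default_rng(0)
uniform = np.zeros(4)                                     # ESS == 4
degenerate = np.array([0.0, -np.inf, -np.inf, -np.inf])   # ESS == 1
print(ess_below_threshold_np(uniform), ess_below_threshold_np(degenerate))
print(resample_systematic_np(degenerate, rng))
```

With uniform weights the ESS equals the number of particles, so no resampling is triggered; when all weight sits on one particle the ESS drops to 1 and every resampled index points to that particle.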

Returns

particles: a (structure of) Tensor(s) matching the latent state, each of shape concat([[num_timesteps, num_particles, b1, ..., bN], event_shape]), representing (possibly weighted) samples from the series of filtering distributions p(latent_states[t] | observations[:t+1]).
log_weights: float Tensor of shape [num_timesteps, num_particles, b1, ..., bN], such that log_weights[t] contains the logarithms of the normalized importance weights of the particles at time t (so that exp(reduce_logsumexp(log_weights[t], axis=0)) == 1). These may be used in conjunction with particles to compute expectations under the series of filtering distributions.
parent_indices: int Tensor of shape [num_timesteps, num_particles, b1, ..., bN], such that parent_indices[t, k] gives the index of the particle at time t - 1 from which the kth particle at time t is immediately descended. See also tfp.experimental.mcmc.reconstruct_trajectories.
incremental_log_marginal_likelihoods: float Tensor of shape [num_observation_steps, b1, ..., bN], giving the natural logarithm of an unbiased estimate of p(observations[t] | observations[:t]) at each observed timestep t. Note that (by Jensen's inequality) this is smaller in expectation than the true log p(observations[t] | observations[:t]).
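The returned values are typically post-processed along the lines of the NumPy sketch below: weighting particles by exp(log_weights) gives filtered expectations at each step, and, by the chain rule, summing the incremental log marginal likelihoods gives an estimate of log p(observations) over all observed steps. The helper names are hypothetical, and the arrays are assumed to be unbatched with scalar states, so the particle axis is axis 1.

```python
import numpy as np


def filtered_means(particles, log_weights):
    """Hypothetical helper: weighted mean of the particles at each step.

    Assumes arrays of shape [num_timesteps, num_particles] whose log_weights
    are already normalized along the particle axis (axis 1).
    """
    return np.sum(np.exp(log_weights) * particles, axis=1)


def log_marginal_likelihood(incremental_log_marginal_likelihoods):
    """Hypothetical helper: sum the per-observation terms to estimate
    log p(observations) over all observed steps."""
    return np.sum(incremental_log_marginal_likelihoods, axis=0)


particles = np.array([[1.0, 3.0],
                      [2.0, 4.0]])
log_weights = np.log(np.array([[0.5, 0.5],
                               [0.25, 0.75]]))

# Normalization check: logsumexp over the particle axis is 0 at every step.
print(np.allclose(np.log(np.sum(np.exp(log_weights), axis=1)), 0.0))
print(filtered_means(particles, log_weights))
print(log_marginal_likelihood(np.array([-1.0, -2.5])))
```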