tfp.experimental.mcmc.augment_prior_with_state_history


Augments a prior or proposal distribution's state space with history.

The augmented state space is over tfp.experimental.mcmc.StateWithHistory namedtuples, which contain the original state as well as a state_history. The state_history is a structure of Tensors matching the structure of state, in which each Tensor has shape concat([[num_particles, history_size], state.shape[1:]]). In other words, previous states for each particle are indexed along axis=1, to the right of the particle indices.
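To make this layout concrete, the NumPy sketch below builds a history for a batch of particles by tiling each particle's state along axis=1, which matches the tiling of the initial state described in the example. `StateWithHistory` here is a local stand-in namedtuple with the same two field names, and `tile_state_history` is a hypothetical helper; neither is TFP's implementation.

```python
import collections

import numpy as np

# Hypothetical stand-in for tfp.experimental.mcmc.StateWithHistory: a plain
# namedtuple with the same two fields, used only to illustrate the layout.
StateWithHistory = collections.namedtuple(
    'StateWithHistory', ['state', 'state_history'])


def tile_state_history(state, history_size):
    """Hypothetical helper: repeats each particle's state along axis=1."""
    state = np.asarray(state)
    # Target shape: concat([[num_particles, history_size], state.shape[1:]]).
    target_shape = (state.shape[0], history_size) + state.shape[1:]
    # Insert the history axis right after the particle axis, then broadcast.
    return np.broadcast_to(state[:, np.newaxis, ...], target_shape)


state = np.arange(12.).reshape(4, 3)  # 4 particles, each with a length-3 state
s = StateWithHistory(state=state,
                     state_history=tile_state_history(state, history_size=2))
print(s.state_history.shape)  # (4, 2, 3)
```

Indexing s.state_history[:, k] selects the k-th history entry for every particle at once, which is why the history axis sits immediately to the right of the particle axis.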

Args:
  prior: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN].
  history_size: integer Tensor; the number of steps of history to pass.

Returns:
  augmented_prior: a tfd.JointDistributionNamed instance whose samples are tfp.experimental.mcmc.StateWithHistory namedtuples.

Example

As a toy example, let's see how we'd use state history to experiment with stochastic 'Fibonacci sequences'. We'll assume that the sequence starts at a value sampled from a Poisson distribution.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

initial_state_prior = tfd.Poisson(5.)
initial_state_with_history_prior = (
  tfp.experimental.mcmc.augment_prior_with_state_history(
    initial_state_prior, history_size=2))

Note that we've augmented the state space to include a state history of size two. The augmented state space is over instances of tfp.experimental.mcmc.StateWithHistory. Initially, the state history will simply tile the initial state: if s = initial_state_with_history_prior.sample(), then s.state_history==[s.state, s.state].

Next, we'll define a transition_fn that uses the history to sample the next integer in the sequence, also from a Poisson distribution.

@tfp.experimental.mcmc.augment_with_state_history
def fibonacci_transition_fn(_, state_with_history):
  # Sum the two most recent states for each particle; history is on axis=1.
  expected_next_element = tf.reduce_sum(
    state_with_history.state_history[:, -2:], axis=1)
  return tfd.Poisson(rate=expected_next_element)

Our transition function must accept state_with_history, so that it can access the history, but it returns a distribution only over the next state. Decorating it with augment_with_state_history ensures that the state history is automatically propagated.
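The NumPy sketch below mimics this bookkeeping for a batch of particles. expected_next_element reproduces the arithmetic in fibonacci_transition_fn, while update_state_history is a hypothetical helper that assumes the decorator drops the oldest entry and appends the new state at the end of axis=1. To keep the output deterministic, it steps forward using the Poisson mean instead of sampling.

```python
import numpy as np


def expected_next_element(state_history):
    # Same arithmetic as fibonacci_transition_fn: sum the last two entries
    # along the history axis (axis=1), giving one value per particle.
    return state_history[:, -2:].sum(axis=1)


def update_state_history(state_history, new_state):
    # Hypothetical helper (assumed behavior): drop the oldest entry at
    # index 0 and append the newest state at the end of axis=1.
    return np.concatenate(
        [state_history[:, 1:], new_state[:, np.newaxis]], axis=1)


# Two particles with initial states 1 and 4, tiled into a history of size 2.
state = np.array([1., 4.])
history = np.repeat(state[:, np.newaxis], 2, axis=1)  # [[1, 1], [4, 4]]
means = [state]
for _ in range(4):
    state = expected_next_element(history)  # Poisson mean, no sampling
    history = update_state_history(history, state)
    means.append(state)

trajectory = np.stack(means)  # shape [num_steps, num_particles]
print(trajectory)
```

Starting from the tiled history [1, 1], the first particle's means run 1, 2, 3, 5, 8: because the initial history repeats the first state, the first step doubles it, and every later step adds the two most recent values.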

Combined with an observation function (which must also now be defined on the augmented StateWithHistory space), we can track stochastic Fibonacci sequences and, for example, infer the initial value of a sequence:


def observation_fn(_, state_with_history):
  # Observations depend only on the current state, not on the history.
  return tfd.Poisson(rate=state_with_history.state)

trajectories, _ = tfp.experimental.mcmc.infer_trajectories(
  observations=tf.convert_to_tensor([4., 11., 16., 23., 40., 69., 100.]),
  initial_state_prior=initial_state_with_history_prior,
  transition_fn=fibonacci_transition_fn,
  observation_fn=observation_fn,
  num_particles=1024)
# trajectories.state has shape [num_timesteps, num_particles]; step 0 holds
# each particle's inferred initial value.
inferred_initial_states = trajectories.state[0]
print(tf.unique_with_counts(inferred_initial_states))
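For intuition about what infer_trajectories computes for this model, here is a minimal bootstrap particle filter written in NumPy. It is a sketch under stated assumptions, not TFP's implementation: infer_fibonacci_trajectories is a hypothetical name, it resamples multinomially at every step, it carries each particle's full path through resampling instead of reconstructing ancestry afterwards, and it assumes the history drops its oldest entry and appends the newest state. It also drops the constant -lgamma(y + 1) term of the Poisson log-likelihood, which cancels when the weights are normalized, and clamps the rate away from zero so that log(0) never occurs.

```python
import numpy as np


def infer_fibonacci_trajectories(observations, num_particles=1024,
                                 history_size=2, prior_rate=5., seed=0):
    """Bootstrap particle filter sketch for the stochastic Fibonacci model.

    Returns an array of shape [num_steps, num_particles] holding each
    particle's resampled trajectory (analogous to trajectories.state).
    """
    rng = np.random.default_rng(seed)
    observations = np.asarray(observations, dtype=np.float64)
    num_steps = observations.shape[0]

    # Step 0: sample initial states from the Poisson prior and tile them
    # into a history of shape [num_particles, history_size].
    state = rng.poisson(prior_rate, size=num_particles).astype(np.float64)
    history = np.repeat(state[:, np.newaxis], history_size, axis=1)
    paths = np.zeros((num_particles, num_steps))

    for t in range(num_steps):
        if t > 0:
            # Transition: Poisson with rate equal to the sum of the last two
            # states, then shift the history (assumed update rule).
            state = rng.poisson(history[:, -2:].sum(axis=1)).astype(np.float64)
            history = np.concatenate(
                [history[:, 1:], state[:, np.newaxis]], axis=1)
        paths[:, t] = state

        # Observation: Poisson(rate=state), log-likelihood up to a constant.
        rate = np.maximum(state, 1e-8)
        log_weights = observations[t] * np.log(rate) - rate
        weights = np.exp(log_weights - log_weights.max())
        weights /= weights.sum()

        # Multinomial resampling: paths and histories move together.
        idx = rng.choice(num_particles, size=num_particles, p=weights)
        paths, history = paths[idx], history[idx]

    return paths.T


trajectories = infer_fibonacci_trajectories([4., 11., 16., 23., 40., 69., 100.])
inferred_initial_states = trajectories[0]
values, counts = np.unique(inferred_initial_states, return_counts=True)
print(values, counts)
```

The exact counts depend on the seed and need not match the output of tfp.experimental.mcmc.infer_trajectories.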