Augments a prior or proposal distribution's state space with history.
```python
tfp.experimental.mcmc.augment_prior_with_state_history(
    prior, history_size
)
```
The augmented state space is over `tfp.experimental.mcmc.StateWithHistory` namedtuples, which contain the original `state` as well as a `state_history`. The `state_history` is a structure of Tensors matching `state`, each of shape `concat([[num_particles, history_size], state.shape[1:]])`. In other words, previous states for each particle are indexed along `axis=1`, to the right of the particle indices.
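To make the shape rule concrete, here is a small NumPy sketch (plain NumPy, not TFP code) that builds an empty history array for a hypothetical per-particle state with event shape `[2]`. It only illustrates the layout: axis 0 indexes particles and axis 1 indexes the stored states of each particle. We assume the most recent state sits at the last position along `axis=1`, which is consistent with the Fibonacci example on this page reading `state_history[:, -2:]`.

```python
import numpy as np

num_particles, history_size = 4, 3
# Hypothetical per-particle state with event shape [2].
state = np.zeros([num_particles, 2])

# Shape rule: concat([[num_particles, history_size], state.shape[1:]]).
history_shape = tuple(
    int(d) for d in np.concatenate([[num_particles, history_size], state.shape[1:]]))
state_history = np.zeros(history_shape)

# axis=0 indexes particles; axis=1 indexes the stored states of that particle.
print(state_history.shape)     # (4, 3, 2)
print(state_history[0].shape)  # (3, 2): particle 0's history
```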
`prior`: a (joint) distribution over the initial latent state, with an optional batch shape.
As a toy example, let's see how we'd use state history to experiment with stochastic 'Fibonacci sequences'. We'll assume that the sequence starts at a value sampled from a Poisson distribution.
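```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# The sequence starts at a value drawn from Poisson(5).
initial_state_prior = tfd.Poisson(5.)
initial_state_with_history_prior = (
    tfp.experimental.mcmc.augment_prior_with_state_history(
        initial_state_prior, history_size=2))
```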
Note that we've augmented the state space to include a state history of size two. The augmented state space is over instances of `tfp.experimental.mcmc.StateWithHistory`. Initially, the state history simply tiles the initial state: if `s = initial_state_with_history_prior.sample()`, then every entry of `s.state_history` equals `s.state`.
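The tiling behaviour can be mimicked in NumPy for a batch of particles. This is a sketch of the idea only, not TFP's implementation: each particle draws one initial value from Poisson(5), and that value is repeated `history_size` times along `axis=1`. The particle count and seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_particles, history_size = 5, 2

# One draw of the initial value per particle, x0 ~ Poisson(5).
state = rng.poisson(5.0, size=num_particles)

# Tile each particle's initial value history_size times along axis=1.
state_history = np.repeat(state[:, np.newaxis], history_size, axis=1)

print(state)
print(state_history)
```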
Next, we'll define a `transition_fn` that uses the history to sample the next integer in the sequence, also from a Poisson distribution.
```python
@tfp.experimental.mcmc.augment_with_state_history
def fibonacci_transition_fn(_, state_with_history):
  # The next element's expected value is the sum of the two most recent states.
  expected_next_element = tf.reduce_sum(
      state_with_history.state_history[:, -2:], axis=1)
  return tfd.Poisson(rate=expected_next_element)
```
Our transition function must accept `state_with_history` so that it can access the history, but it returns a distribution only over the next state. Decorating it with `tfp.experimental.mcmc.augment_with_state_history` ensures that the state history is automatically propagated.
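One way to picture what the decorated transition does at each step is the NumPy sketch below. The helper `fibonacci_step` is hypothetical and is not TFP's code; it assumes the history behaves as a sliding window in which the oldest entry is dropped and the newly sampled state is appended at the end.

```python
import numpy as np


def fibonacci_step(state_history, rng):
    """Hypothetical NumPy analogue of the decorated fibonacci_transition_fn.

    Samples the next element from Poisson(sum of the last two history
    entries), then slides the window: drop the oldest, append the new state.
    """
    expected_next_element = state_history[:, -2:].sum(axis=1)
    next_state = rng.poisson(expected_next_element)
    new_history = np.concatenate(
        [state_history[:, 1:], next_state[:, np.newaxis]], axis=1)
    return next_state, new_history


rng = np.random.default_rng(1)
history = np.full((3, 2), 5)  # three particles, each with a tiled initial value of 5
for _ in range(4):
    state, history = fibonacci_step(history, rng)
    print(state)
```

Keeping the window update separate from the sampling step mirrors the division of labour described above: the transition only defines the next-state distribution, and the history bookkeeping happens around it.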
Combined with an observation function (which must also now be defined on the `StateWithHistory` space), we can track stochastic Fibonacci sequences and, for example, infer the initial value of a sequence:
```python
def observation_fn(_, state_with_history):
  # Each observation is a Poisson draw around the current state.
  return tfd.Poisson(rate=state_with_history.state)

trajectories, _ = tfp.experimental.mcmc.infer_trajectories(
    observations=tf.convert_to_tensor([4., 11., 16., 23., 40., 69., 100.]),
    initial_state_prior=initial_state_with_history_prior,
    transition_fn=fibonacci_transition_fn,
    observation_fn=observation_fn,
    num_particles=1024)
# Trajectories are indexed [time, particle]; keep the first time step.
# tf.unique_with_counts requires a 1-D input.
inferred_initial_states = trajectories.state[0]
print(tf.unique_with_counts(inferred_initial_states))
```
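For intuition about what `infer_trajectories` computes here, the following is a minimal bootstrap particle filter in plain NumPy. It is a swapped-in technique, not TFP's implementation: it resamples multinomially at every step and carries each particle's initial value along with its lineage, so the surviving values approximate the posterior over the sequence's first element. The function names and parameters are hypothetical; the model (Poisson(5) prior, history tiled at the start, next value from Poisson of the sum of the last two history entries, observations from Poisson(state)) mirrors the example above.

```python
import math

import numpy as np


def poisson_log_prob(k, rate):
    """Log pmf of Poisson(rate) at count k, vectorized over `rate`."""
    # Clip rates away from zero so log(rate) stays finite.
    rate = np.maximum(rate, 1e-8)
    return k * np.log(rate) - rate - math.lgamma(k + 1)


def bootstrap_filter_initial_states(observations, num_particles=1024,
                                    history_size=2, prior_rate=5.0, seed=0):
    """Hypothetical bootstrap filter for the stochastic-Fibonacci toy model.

    Returns the initial value of each surviving particle lineage.
    """
    rng = np.random.default_rng(seed)
    # Initial state x0 ~ Poisson(prior_rate); history tiles x0.
    state = rng.poisson(prior_rate, size=num_particles).astype(float)
    history = np.repeat(state[:, None], history_size, axis=1)
    initial = state.copy()  # initial value of each particle's lineage

    for t, obs in enumerate(observations):
        if t > 0:
            # Propagate: next ~ Poisson(sum of the last two history entries).
            state = rng.poisson(history[:, -2:].sum(axis=1)).astype(float)
            history = np.concatenate([history[:, 1:], state[:, None]], axis=1)
        # Weight particles by the observation model obs ~ Poisson(state).
        log_w = poisson_log_prob(obs, state)
        w = np.exp(log_w - log_w.max())
        w /= w.sum()
        # Multinomial resampling; lineage initial values travel with particles.
        idx = rng.choice(num_particles, size=num_particles, p=w)
        state, history, initial = state[idx], history[idx], initial[idx]
    return initial


if __name__ == "__main__":
    samples = bootstrap_filter_initial_states([4., 11., 16., 23., 40., 69., 100.])
    values, counts = np.unique(samples, return_counts=True)
    print(values, counts)
```

Here `np.unique(..., return_counts=True)` plays the role that `tf.unique_with_counts` plays in the TFP example.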