Augments a prior or proposal distribution's state space with history.
tfp.experimental.mcmc.augment_prior_with_state_history(
prior, history_size
)
The augmented state space is over tfp.experimental.mcmc.StateWithHistory namedtuples, which contain the original state as well as a state_history.
The state_history is a structure of Tensors matching state, of shape concat([[num_particles, history_size], state.shape[1:]]). In other words, previous states for each particle are indexed along axis=1, to the right of the particle indices.
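The shape convention above can be sketched with plain NumPy. This is only an illustration of the layout, not TFP's implementation; the helper name tile_initial_history and the example sizes (4 particles, a 3-dimensional state) are assumptions made for the sketch.

```python
import numpy as np


def tile_initial_history(state, history_size):
    """Hypothetical helper: repeat each particle's state along a new axis=1."""
    state = np.asarray(state)
    # Insert the history axis directly to the right of the particle axis,
    # then repeat the state `history_size` times along it.
    return np.repeat(state[:, np.newaxis, ...], history_size, axis=1)


num_particles, history_size = 4, 2
state = np.arange(12.0).reshape(num_particles, 3)  # shape [num_particles, 3]
state_history = tile_initial_history(state, history_size)

# Matches concat([[num_particles, history_size], state.shape[1:]]).
expected_shape = (num_particles, history_size) + state.shape[1:]
print(state_history.shape == expected_shape)
```

Each slice state_history[:, k] holds one past state for every particle, which is why indexing into the history happens along axis=1.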
Args | |
---|---|
prior | A (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. |
history_size | Integer Tensor giving the number of steps of history to pass. |
Returns | |
---|---|
augmented_prior | A tfd.JointDistributionNamed instance whose samples are tfp.experimental.mcmc.StateWithHistory namedtuples. |
Example
As a toy example, let's see how we'd use state history to experiment with stochastic 'Fibonacci sequences'. We'll assume that the sequence starts at a value sampled from a Poisson distribution.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

initial_state_prior = tfd.Poisson(5.)
initial_state_with_history_prior = (
    tfp.experimental.mcmc.augment_prior_with_state_history(
        initial_state_prior, history_size=2))
Note that we've augmented the state space to include a state history of size two. The augmented state space is over instances of tfp.experimental.mcmc.StateWithHistory. Initially, the state history will simply tile the initial state: if s = initial_state_with_history_prior.sample(), then s.state_history==[s.state, s.state].
Next, we'll define a transition_fn that uses the history to sample the next integer in the sequence, also from a Poisson distribution.
@tfp.experimental.mcmc.augment_with_state_history
def fibonacci_transition_fn(_, state_with_history):
  # The expected next element is the sum of the two most recent states.
  expected_next_element = tf.reduce_sum(
      state_with_history.state_history[:, -2:], axis=1)
  return tfd.Poisson(rate=expected_next_element)
Our transition function must accept state_with_history, so that it can access the history, but it returns a distribution only over the next state. Decorating it with augment_with_state_history ensures that the state history is automatically propagated.
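To make "propagating the history" concrete, here is a minimal NumPy sketch of a rolling history update. It assumes that the most recent state sits at the end of axis=1 (consistent with the [:, -2:] slice in the transition function) and that the oldest entry is dropped at each step; the helper push_state is hypothetical and is not the code TFP runs.

```python
import numpy as np


def push_state(state_history, new_state):
    """Hypothetical helper: drop the oldest state and append the newest."""
    new_state = np.asarray(new_state)
    # state_history: [num_particles, history_size, ...]
    # new_state:     [num_particles, ...]
    return np.concatenate(
        [state_history[:, 1:], new_state[:, np.newaxis]], axis=1)


# Two particles, each carrying its two most recent values.
state_history = np.array([[1., 1.],
                          [2., 3.]])

# Deterministic stand-in for the Poisson transition: use the rate itself,
# i.e. the sum of the last two states, as the next state.
next_state = state_history[:, -2:].sum(axis=1)
updated_history = push_state(state_history, next_state)
print(updated_history)
```

In the real transition function the next state is a Poisson sample rather than the rate itself, so each particle follows its own noisy version of this recursion.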
Combined with an observation function (which must also now be defined on the augmented StateWithHistory space), we can track stochastic Fibonacci sequences and, for example, infer the initial value of a sequence:
def observation_fn(_, state_with_history):
  return tfd.Poisson(rate=state_with_history.state)

trajectories, _ = tfp.experimental.mcmc.infer_trajectories(
    observations=tf.convert_to_tensor([4., 11., 16., 23., 40., 69., 100.]),
    initial_state_prior=initial_state_with_history_prior,
    transition_fn=fibonacci_transition_fn,
    observation_fn=observation_fn,
    num_particles=1024)

# Slice out the first timestep of every sampled trajectory.
inferred_initial_states = trajectories.state[0]
print(tf.unique_with_counts(inferred_initial_states))