Decorates a function to take observation_history.
tfp.experimental.mcmc.augment_with_observation_history(
observations, history_size, num_transitions_per_observation=1
)
Args |
observations
|
a (structure of) Tensors, each of shape
concat([[num_observation_steps, b1, ..., bN], event_shape]) with
optional batch dimensions b1, ..., bN.
|
history_size
|
integer Tensor; the number of steps of history to pass.
|
num_transitions_per_observation
|
integer Tensor; the number of
state transitions between regular observation points. A value of 1
indicates that there is an observation at every timestep, a value of
2 that every other step is observed, and so on. Values greater than 1
may be used with an appropriately chosen transition function to
approximate continuous-time dynamics. The initial and final steps
(steps 0 and num_timesteps - 1) are always observed.
Default value: 1.
|
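The relationship between observations and timesteps described for num_transitions_per_observation can be sketched in plain Python. This is an illustration only, not library code: the helper names observed_steps and total_timesteps are hypothetical, and the sketch assumes observations are evenly spaced starting at step 0, which is how the description above reads.

```python
def observed_steps(num_observation_steps, num_transitions_per_observation=1):
    """Timestep index at which each observation falls (assumed evenly spaced from step 0)."""
    return [i * num_transitions_per_observation
            for i in range(num_observation_steps)]


def total_timesteps(num_observation_steps, num_transitions_per_observation=1):
    """Number of timesteps such that both the first and the last step are observed."""
    return 1 + num_transitions_per_observation * (num_observation_steps - 1)


# Four observations with two transitions between consecutive observations.
steps = observed_steps(4, num_transitions_per_observation=2)
n = total_timesteps(4, num_transitions_per_observation=2)
print(steps, n)
```

Under these assumptions the last observed step always equals total_timesteps - 1, matching the statement that the initial and final steps are always observed.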
Returns |
augment_fn
|
Python callable such that augmented_fn = augment_fn(fn).
When called, augmented_fn invokes fn
with an additional observation_history keyword argument, whose value is a
Tensor of shape concat([[history_size, b1, ..., bN], event_shape])
containing up to the most recent history_size observations.
|
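The decorator pattern described under Returns can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: it uses a Python list in place of Tensors, and it assumes that the wrapped function receives the current step as its first positional argument and that the history window covers the observations strictly before the current observation index. The names augment_with_observation_history_sketch and describe are hypothetical.

```python
def augment_with_observation_history_sketch(
        observations, history_size, num_transitions_per_observation=1):
    """Plain-Python stand-in that mimics the documented decorator behavior."""

    def augment_fn(fn):
        def augmented_fn(step, *args, **kwargs):
            # Assumption: map the timestep onto an observation index.
            observation_idx = step // num_transitions_per_observation
            # Assumption: the window holds up to history_size observations
            # strictly before observation_idx.
            start = max(0, observation_idx - history_size)
            history = observations[start:observation_idx]
            return fn(step, *args, observation_history=history, **kwargs)
        return augmented_fn

    return augment_fn


@augment_with_observation_history_sketch([10.0, 11.0, 12.0, 13.0], history_size=2)
def describe(step, observation_history):
    # fn sees the extra keyword argument supplied by the decorator.
    return list(observation_history)


print(describe(0), describe(1), describe(3))
```

Early steps receive fewer than history_size entries in this sketch, which corresponds to the "up to the most recent history_size observations" wording above. With the real function, observations would be a (structure of) Tensors rather than a list.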