Python callable to wrap, having signature
new_state_dist = fn(step, state_with_history, **kwargs) where
state_with_history is a StateWithHistory namedtuple.
Python callable wrapping fn, having signature
new_state_with_history_dist = augmented_fn(step, state_with_history,
**kwargs). The return value is a tfd.JointDistributionNamed instance
over tfp.experimental.mcmc.StateWithHistory namedtuples, in which the
state_history component is rotated by one position: the (previously oldest)
state at the initial position is discarded and the new state is appended at
the final position.
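The history bookkeeping described above can be sketched in plain Python. This is a minimal sketch, not the library's implementation. It assumes a hypothetical StateWithHistory stand-in with the two fields named above (state, state_history) holding plain floats rather than Tensors. It also uses a hypothetical augment_fn_with_history helper whose wrapped fn returns a concrete new state instead of a distribution, so only the rotation of state_history is shown.

```python
from typing import Callable, NamedTuple, Tuple


class StateWithHistory(NamedTuple):
    # Hypothetical stand-in for tfp.experimental.mcmc.StateWithHistory;
    # here the fields hold plain Python values instead of Tensors.
    state: float
    state_history: Tuple[float, ...]


def rotate_history(history: Tuple[float, ...],
                   new_state: float) -> Tuple[float, ...]:
    # Discard the oldest state (position 0) and append the new state at the
    # final position, so the history keeps a fixed length.
    return tuple(history[1:]) + (new_state,)


def augment_fn_with_history(
        fn: Callable[..., float]) -> Callable[..., StateWithHistory]:
    # Hypothetical helper: `fn` returns a concrete new state (not a
    # distribution), so only the state_history bookkeeping is illustrated.
    def augmented_fn(step: int, state_with_history: StateWithHistory,
                     **kwargs) -> StateWithHistory:
        new_state = fn(step, state_with_history, **kwargs)
        return StateWithHistory(
            state=new_state,
            state_history=rotate_history(state_with_history.state_history,
                                         new_state))
    return augmented_fn
```

For example, wrapping lambda step, s: s.state + 1.0 and applying it once to StateWithHistory(state=3.0, state_history=(1.0, 2.0, 3.0)) gives state 4.0 and state_history (2.0, 3.0, 4.0): the oldest entry 1.0 is dropped and the new state is appended at the end.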