The SNAPER criterion from [1].
```python
tfp.experimental.mcmc.snaper_criterion(
    previous_state,
    proposed_state,
    accept_prob,
    trajectory_length,
    direction,
    state_mean=None,
    state_mean_weight=0.0,
    validate_args=False,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None
)
```
SNAPER stands for Squared Norm Along Principal component ESJD Rate:
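$$
\mathrm{SNAPER} = \mathbb{E}\left[\frac{\left(\left((x' - \mathbb{E}[x'])^\top p\right)^2 - \left((x - \mathbb{E}[x])^\top p\right)^2\right)^2}{\text{trajectory\_length}}\right],
$$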
where `x` is the previous chain state, `x'` is the next chain state, and `p` is a unit vector (the `direction` argument). Both inner expectations, `E[x']` and `E[x]`, are with respect to the chain's stationary distribution. In practice, they are replaced by empirical means across chains, so computing this criterion requires at least 2 chains unless `state_mean` and `state_mean_weight` are set. The outer expectation is computed by the caller (e.g. in the `GradientBasedTrajectoryLengthAdaptation` kernel).
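The formula above can be sketched in NumPy for a single batch of chains. This is a minimal illustration under stated assumptions, not the TFP implementation: states are assumed to be arrays of shape `[num_chains, dim]`, the inner expectations are replaced by per-coordinate means across chains (axis 0), and `snaper_terms` is a hypothetical helper that omits `accept_prob`, `state_mean`, and the sharding arguments the library function handles. The final average plays the role of the outer expectation that the caller computes.

```python
import numpy as np

def snaper_terms(previous_state, proposed_state, trajectory_length, direction):
    """Per-chain SNAPER terms; states have shape [num_chains, dim].

    Hypothetical helper for illustration only: it is not the TFP function and
    ignores accept_prob, state_mean and the sharding arguments.
    """
    x = np.asarray(previous_state, dtype=float)
    x_next = np.asarray(proposed_state, dtype=float)
    p = np.asarray(direction, dtype=float)
    p = p / np.linalg.norm(p)  # the criterion assumes p is a unit vector
    # Inner expectations: empirical means across chains (axis 0).
    x_centered = x - x.mean(axis=0)
    x_next_centered = x_next - x_next.mean(axis=0)
    # Squared projections onto the direction, one value per chain.
    prev_sq = (x_centered @ p) ** 2
    next_sq = (x_next_centered @ p) ** 2
    return (next_sq - prev_sq) ** 2 / trajectory_length

# Outer expectation: average the per-chain terms (left to the caller in TFP).
prev = np.array([[1.0, 0.0], [-1.0, 0.0]])
prop = np.array([[3.0, 0.0], [-3.0, 0.0]])
terms = snaper_terms(prev, prop, trajectory_length=2.0, direction=[1.0, 0.0])
snaper = terms.mean()
```

Here each chain's squared projection grows from 1 to 9, so every per-chain term is `(9 - 1)**2 / 2 = 32`.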
SNAPER can be thought of as the standard expected squared jump distance (ESJD) criterion, except that the jump distance is computed in the space of squared projections onto the direction `p`.
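To make the ESJD analogy concrete, the sketch below compares the ordinary squared jump along `p`, `((x' - x)^T p)^2`, with the jump in squared projections for chains whose centered state flips sign. Both helper names are hypothetical, and the states are assumed to be already centered. The flip moves each state a long way along `p`, but it leaves the squared projection unchanged, so the squared-projection jump is zero.

```python
import numpy as np

def projected_esjd(x_centered, x_next_centered, p):
    """Ordinary squared jump distance along p, one value per chain."""
    return ((x_next_centered - x_centered) @ p) ** 2

def squared_projection_jump(x_centered, x_next_centered, p):
    """Squared jump in the space of squared projections onto p."""
    return ((x_next_centered @ p) ** 2 - (x_centered @ p) ** 2) ** 2

p = np.array([1.0, 0.0])                    # unit direction
x = np.array([[2.0, 0.5], [-2.0, -0.5]])    # centered previous states
x_next = -x                                 # each chain flips sign

esjd = projected_esjd(x, x_next, p)               # large: the state moved far along p
sq_jump = squared_projection_jump(x, x_next, p)   # zero: squared projection unchanged
```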
The `direction` vector is typically chosen to be an approximation to the first principal component of the state covariance matrix.
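One simple way to obtain such a direction, assuming you already have a batch of approximately stationary samples, is to take the leading eigenvector of their sample covariance. This is only a sketch of that idea with a hypothetical helper; the TFP adaptation machinery may estimate the direction differently.

```python
import numpy as np

def leading_principal_component(samples):
    """Unit eigenvector of the sample covariance with the largest eigenvalue.

    `samples` has shape [num_samples, dim]. Illustrative helper, not a TFP API.
    """
    cov = np.cov(np.asarray(samples, dtype=float), rowvar=False)
    # eigh handles symmetric matrices and returns eigenvalues in ascending order.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    # Columns are unit-norm eigenvectors. The sign is arbitrary, which is
    # harmless here because SNAPER squares the projection onto the direction.
    return eigenvectors[:, -1]

samples = np.array([[3.0, 0.0], [-3.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
direction = leading_principal_component(samples)
```

The samples vary most along the first coordinate (variance 6 versus 2/3), so the returned direction is `[1, 0]` up to sign.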
`state_mean` and `state_mean_weight` can be used to supplement the empirical means as follows:
E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
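A small sketch of this blended mean, assuming states of shape `[num_chains, dim]` with `x.mean()` taken across chains (axis 0); `blended_mean` is a hypothetical helper. It also illustrates why a single chain needs `state_mean`: with one chain and `state_mean_weight=0`, the empirical mean equals the state itself, so every centered state is zero and the squared projections carry no signal.

```python
import numpy as np

def blended_mean(x, state_mean=None, state_mean_weight=0.0):
    """(1 - w) * empirical mean across chains + w * state_mean."""
    empirical = np.asarray(x, dtype=float).mean(axis=0)  # mean across chains
    if state_mean is None:
        return empirical
    w = state_mean_weight
    return (1.0 - w) * empirical + w * np.asarray(state_mean, dtype=float)

single_chain = np.array([[2.0, 4.0]])  # shape [1, dim]
# Without a prior mean, centering a single chain yields all zeros.
centered_no_prior = single_chain - blended_mean(single_chain)
# Blending in state_mean keeps a nonzero centered state.
centered_with_prior = single_chain - blended_mean(
    single_chain, state_mean=[0.0, 0.0], state_mean_weight=0.5)
```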
| Returns | |
|---|---|
| `snaper` | The value of the SNAPER criterion. |
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>