tfp.experimental.mcmc.snaper_criterion

The SNAPER criterion from [1].

SNAPER stands for Squared Norm Along Principal component ESJD Rate:

SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 /
           trajectory_length],

where x is the previous chain state, x' is the next chain state, and p is a unit vector (the direction argument). Both expectations are with respect to the chain's stationary distribution. In practice, the inner expectation is replaced by the empirical mean across chains, so computing this criterion requires that at least 2 chains are present unless state_mean and state_mean_weight are set. The outer expectation is computed by the caller (e.g. in the GradientBasedTrajectoryLengthAdaptation kernel).
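As a rough illustration of the formula, the following plain-Python sketch computes the per-chain terms inside the outer expectation, using empirical means across chains for the inner expectations. It is not the library implementation. The helper names (`dot`, `column_mean`, `snaper_per_chain`) are hypothetical, each chain state is assumed to be a flat list of floats, the second projection uses the previous state x as defined above, and `accept_prob` is ignored for simplicity.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def column_mean(states):
    """Empirical mean across chains; `states` is a list of per-chain vectors."""
    n = len(states)
    return [sum(col) / n for col in zip(*states)]


def snaper_per_chain(previous, proposed, direction, trajectory_length):
    """Per-chain SNAPER terms (hypothetical helper, not the TFP API).

    Assumes `direction` is a unit vector and that at least 2 chains are
    present, since the inner expectations are empirical means over chains.
    """
    if len(previous) < 2:
        raise ValueError("Need at least 2 chains to form empirical means.")
    prev_mean = column_mean(previous)
    prop_mean = column_mean(proposed)
    terms = []
    for x, x_new in zip(previous, proposed):
        # Project the centered states onto the direction vector.
        prev_proj = dot([a - m for a, m in zip(x, prev_mean)], direction)
        prop_proj = dot([a - m for a, m in zip(x_new, prop_mean)], direction)
        # Squared jump in the space of squared projections, as a rate.
        terms.append((prop_proj**2 - prev_proj**2) ** 2 / trajectory_length)
    return terms


previous = [[0.0, 0.0], [2.0, 0.0]]
proposed = [[-1.0, 0.0], [3.0, 0.0]]
per_chain = snaper_per_chain(previous, proposed, [1.0, 0.0], trajectory_length=2.0)
print(per_chain)  # [4.5, 4.5]
```

Averaging the returned list gives a Monte Carlo estimate of the outer expectation, which in practice is left to the caller.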

This can be thought of as the standard expected squared jump distance (ESJD) criterion, except that the jump distance is computed in the space of squared projections onto a vector.
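The relationship to ESJD can be written out as follows. This is a restatement under the assumption that the chain is stationary, so the previous and next states share the same mean; the symbols f and \mu are introduced here only for illustration.

```latex
\begin{aligned}
\mathrm{ESJD} &= \mathbb{E}\left[\lVert x' - x \rVert^2\right], \\
f(x) &= \left((x - \mu)^\top p\right)^2, \qquad \mu = \mathbb{E}[x] = \mathbb{E}[x'], \\
\mathrm{SNAPER} &= \mathbb{E}\left[\frac{\left(f(x') - f(x)\right)^2}{\text{trajectory\_length}}\right].
\end{aligned}
```

In other words, the jump is measured between the scalars f(x') and f(x) instead of between the states themselves, and the result is divided by the trajectory length to give a rate.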

The direction vector is typically chosen to be an approximation to the first principal component of the state covariance matrix.
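One simple way to obtain such a direction from a batch of samples is power iteration on the empirical covariance. The sketch below is only an illustration of that idea, not a description of how any particular kernel estimates the direction. The function name `first_principal_component` and its `num_iters` parameter are hypothetical, and samples are assumed to be flat lists of floats.

```python
import math


def first_principal_component(samples, num_iters=100):
    """Approximate the top eigenvector of the sample covariance.

    Hypothetical helper using plain power iteration. The uniform starting
    vector is an assumption; it fails if it happens to be orthogonal to
    the principal component.
    """
    n = len(samples)
    dim = len(samples[0])
    mean = [sum(col) / n for col in zip(*samples)]
    centered = [[a - m for a, m in zip(row, mean)] for row in samples]

    v = [1.0 / math.sqrt(dim)] * dim
    for _ in range(num_iters):
        # Compute cov @ v as (1/n) * sum_i (c_i . v) * c_i without
        # materializing the covariance matrix.
        w = [0.0] * dim
        for c in centered:
            s = sum(a * b for a, b in zip(c, v))
            for j in range(dim):
                w[j] += s * c[j] / n
        norm = math.sqrt(sum(a * a for a in w))
        if norm == 0.0:
            break  # Degenerate case: no variance along v.
        v = [a / norm for a in w]  # Renormalize to keep a unit vector.
    return v


# Samples spread mostly along the first coordinate.
samples = [[-2.0, 0.1], [-1.0, -0.1], [1.0, 0.1], [2.0, -0.1]]
direction = first_principal_component(samples)
```

The result is a unit vector, which matches the requirement on the direction argument. Its sign is arbitrary, which does not matter here because the criterion only uses squared projections.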

state_mean and state_mean_weight can be used to supplement the empirical means as follows:

E[x] ~= (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
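The blend above can be sketched in a few lines of plain Python. The function name `blended_mean` is hypothetical, and states are again assumed to be flat lists of floats, one per chain.

```python
def blended_mean(states, state_mean, state_mean_weight):
    """Convex combination of the across-chain mean and a supplied estimate.

    Hypothetical helper: a weight of 0 uses only the empirical mean across
    chains, and a weight of 1 uses only the supplied `state_mean`.
    """
    n = len(states)
    empirical = [sum(col) / n for col in zip(*states)]
    return [
        (1.0 - state_mean_weight) * e + state_mean_weight * m
        for e, m in zip(empirical, state_mean)
    ]


states = [[0.0, 0.0], [2.0, 4.0]]  # Two chains; empirical mean is [1.0, 2.0].
print(blended_mean(states, [3.0, 6.0], 0.5))  # [2.0, 4.0]
```

With a nonzero weight the estimate no longer relies solely on the across-chain mean, which is consistent with the note above that fewer than 2 chains can be used when state_mean and state_mean_weight are set.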

Args

previous_state: (Possibly nested) floating point Tensor. The previous state of the HMC chain.
proposed_state: (Possibly nested) floating point Tensor. The proposed state of the HMC chain.
accept_prob: Floating Tensor. Probability of accepting the proposed state.
trajectory_length: Floating Tensor. Mean trajectory length (not used in this criterion).
direction: (Possibly nested) floating point Tensor. A unit vector onto which the centered state should be projected before computing ESJD. Typically this is chosen to be an approximation to the first principal component of the state covariance matrix.
state_mean: Optional (Possibly nested) floating point Tensor. The estimated state mean.
state_mean_weight: Floating point Tensor. The weight of the state_mean.
validate_args: Whether to perform non-static argument validation.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names: A string or list of string names indicating which named chain axes to reduce over when computing the criterion.

Returns

snaper: The value of the SNAPER criterion.

References

[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>