
# tfp.experimental.mcmc.snaper_criterion


The SNAPER criterion from [1].

SNAPER stands for Squared Norm Along Principal component ESJD Rate:

```
SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 /
           trajectory_length],
```

where `x` is the previous chain state, `x'` is the next chain state, and `p` is a unit vector (the `direction` argument). Both expectations are with respect to the chain's stationary distribution. In practice, the inner expectation is replaced by the empirical mean across chains, so computing this criterion requires that at least 2 chains are present unless `state_mean` and `state_mean_weight` are set. The outer expectation is computed by the caller (e.g. in the `GradientBasedTrajectoryLengthAdaptation` kernel).
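The per-chain quantity inside the outer expectation can be sketched in pure Python as follows. This is a minimal illustration, assuming each state is a flat list of floats (one list per chain), that the inner expectations are replaced by empirical means across chains (hence the 2-chain check), and that `direction` is already unit-norm. The helper names `mean_across_chains`, `project`, and `snaper_terms` are hypothetical. This is not the TFP implementation, which also handles nested states, `accept_prob`, `state_mean`, and sharding.

```python
from typing import List

Vector = List[float]


def mean_across_chains(states: List[Vector]) -> Vector:
    """Empirical mean over the chain axis, standing in for the inner E[.]."""
    num_chains = len(states)
    if num_chains < 2:
        raise ValueError("at least 2 chains are needed for the empirical mean")
    dim = len(states[0])
    return [sum(s[i] for s in states) / num_chains for i in range(dim)]


def project(x: Vector, mean: Vector, direction: Vector) -> float:
    """Compute (x - mean)^T p for a unit vector p."""
    return sum((xi - mi) * pi for xi, mi, pi in zip(x, mean, direction))


def snaper_terms(previous: List[Vector], proposed: List[Vector],
                 direction: Vector, trajectory_length: float) -> List[float]:
    """Hypothetical helper: one SNAPER term per chain, before the outer mean."""
    prev_mean = mean_across_chains(previous)   # estimate of E[x]
    prop_mean = mean_across_chains(proposed)   # estimate of E[x']
    terms = []
    for x, x_new in zip(previous, proposed):
        # Jump measured in the space of squared projections onto `direction`.
        jump = (project(x_new, prop_mean, direction) ** 2
                - project(x, prev_mean, direction) ** 2)
        terms.append(jump ** 2 / trajectory_length)
    return terms


# Two chains in 2 dimensions, projected onto the first coordinate axis.
previous = [[0.0, 1.0], [2.0, -1.0]]
proposed = [[3.0, 0.0], [-1.0, 2.0]]
direction = [1.0, 0.0]
terms = snaper_terms(previous, proposed, direction, trajectory_length=2.0)
# The caller is responsible for the outer expectation; a plain mean is used here.
snaper_estimate = sum(terms) / len(terms)
```

Here both chains move from a squared projection of 1 to a squared projection of 4, so each term is `(4 - 1)**2 / 2 = 4.5`.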

This can be thought of as the standard expected squared jump distance (ESJD) criterion, except that the jump distance is computed in the space of squared projections onto a vector.

The `direction` vector is typically chosen to be an approximation to the first principal component of the state covariance matrix.
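This page does not say how that approximation is obtained. One common way to approximate the leading principal component is power iteration on the sample covariance of the chain states, sketched below in pure Python. The function names are hypothetical and this is an assumed technique for illustration, not necessarily what TFP's adaptation kernels do internally.

```python
import math
from typing import List

Vector = List[float]


def sample_covariance(states: List[Vector]) -> List[Vector]:
    """Unbiased sample covariance of a list of flat state vectors."""
    n = len(states)
    if n < 2:
        raise ValueError("need at least 2 states to estimate a covariance")
    dim = len(states[0])
    mean = [sum(s[i] for s in states) / n for i in range(dim)]
    cov = [[0.0] * dim for _ in range(dim)]
    for s in states:
        centered = [s[i] - mean[i] for i in range(dim)]
        for i in range(dim):
            for j in range(dim):
                cov[i][j] += centered[i] * centered[j] / (n - 1)
    return cov


def first_principal_component(states: List[Vector], num_iters: int = 100) -> Vector:
    """Approximate the top eigenvector of the sample covariance by power iteration."""
    cov = sample_covariance(states)
    dim = len(cov)
    # Arbitrary start; it must not be orthogonal to the top eigenvector.
    v = [1.0] * dim
    for _ in range(num_iters):
        w = [sum(cov[i][j] * v[j] for j in range(dim)) for i in range(dim)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        v = [wi / norm for wi in w]  # renormalize so the result is a unit vector
    return v


# States spread mostly along the first coordinate.
states = [[-2.0, 0.1], [2.0, -0.1], [-1.0, -0.1], [1.0, 0.1]]
p = first_principal_component(states)
```

Because SNAPER squares the projections `(x - E[x])^T p`, the sign of the returned vector does not affect the criterion, so the sign ambiguity inherent in eigenvectors is harmless here.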

`state_mean` and `state_mean_weight` can be used to supplement the empirical means as follows:

```
E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean
```
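The blending rule above can be written as a small pure-Python function, assuming flat list-valued states with one list per chain and the mean taken across chains. The name `blended_state_mean` is hypothetical. With a single chain, the empirical mean equals that chain's own state, so the centered projections would vanish; putting weight on `state_mean` is what makes the single-chain case meaningful.

```python
from typing import List

Vector = List[float]


def blended_state_mean(states: List[Vector], state_mean: Vector,
                       state_mean_weight: float) -> Vector:
    """Mix the across-chain empirical mean with an externally supplied estimate."""
    n = len(states)
    dim = len(states[0])
    empirical = [sum(s[i] for s in states) / n for i in range(dim)]
    return [(1.0 - state_mean_weight) * e + state_mean_weight * m
            for e, m in zip(empirical, state_mean)]


# Single chain: rely entirely on the supplied estimate.
single_chain = blended_state_mean([[4.0, 0.0]], state_mean=[1.0, 2.0],
                                  state_mean_weight=1.0)
# Two chains: empirical mean 1.0 blended equally with the estimate 3.0.
two_chains = blended_state_mean([[0.0], [2.0]], state_mean=[3.0],
                                state_mean_weight=0.5)
```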

#### Args

| Argument | Description |
|---|---|
| `previous_state` | (Possibly nested) floating point `Tensor`. The previous state of the HMC chain. |
| `proposed_state` | (Possibly nested) floating point `Tensor`. The proposed state of the HMC chain. |
| `accept_prob` | Floating `Tensor`. Probability of accepting the proposed state. |
| `trajectory_length` | Floating `Tensor`. Mean trajectory length. In the formula above, the squared jump is divided by this value. |
| `direction` | (Possibly nested) floating point `Tensor`. A unit vector onto which the centered state is projected before computing ESJD. Typically this is chosen to be an approximation to the first principal component of the state covariance matrix. |
| `state_mean` | Optional (possibly nested) floating point `Tensor`. The estimated state mean. |
| `state_mean_weight` | Floating point `Tensor`. The weight given to `state_mean`. |
| `validate_args` | Whether to perform non-static argument validation. |
| `experimental_shard_axis_names` | A structure of string names indicating how members of the state are sharded. |
| `experimental_reduce_chain_axis_names` | A string or list of string names indicating which named chain axes to reduce over when computing the criterion. |

#### Returns

| Value | Description |
|---|---|
| `snaper` | The value of the SNAPER criterion. |

#### References

[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>
