


ChEES rate criterion.

This is just like chees_criterion, but normalized by the trajectory length:

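$$\text{ChEES rate} = \frac{1}{4}\,\mathbb{E}\!\left[\frac{\left(\lVert x' - \mathbb{E}[x]\rVert^2 - \lVert x - \mathbb{E}[x]\rVert^2\right)^2}{\text{trajectory\_length}^2}\right]$$

where x is the previous state of the chain, x' is the proposed state, and ||.|| is the L2 norm.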
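To make the formula concrete, here is a minimal pure-Python sketch of a Monte Carlo estimate of the ChEES rate over a batch of chains. It is not the TensorFlow Probability implementation: the name `chees_rate_sketch`, the list-of-lists state layout, and estimating E[x] by the across-chain mean of the previous state are assumptions made for illustration. It also assumes that each chain's term is divided by that chain's `trajectory_length**2` and weighted by `accept_prob`, since a rejected proposal leaves the state unchanged and therefore contributes zero.

```python
def chees_rate_sketch(previous_state, proposed_state, accept_prob, trajectory_length):
    """Illustrative Monte Carlo estimate of the ChEES rate (not the library code).

    previous_state, proposed_state: lists indexed by chain, each a flat list of floats.
    accept_prob, trajectory_length: lists of floats indexed by chain.
    """
    num_chains = len(previous_state)
    if num_chains < 2:
        raise ValueError("need at least 2 chains to estimate E[x] across chains")
    dim = len(previous_state[0])

    # Assumption: E[x] is estimated by the across-chain mean of the previous state.
    mean = [sum(x[d] for x in previous_state) / num_chains for d in range(dim)]

    def centered_sq_norm(x):
        # ||x - E[x]||**2 using the estimated mean.
        return sum((x[d] - mean[d]) ** 2 for d in range(dim))

    total = 0.0
    for x, x_prime, a, length in zip(
        previous_state, proposed_state, accept_prob, trajectory_length
    ):
        diff = centered_sq_norm(x_prime) - centered_sq_norm(x)
        # A rejection keeps x' == x (diff == 0), so averaging over accept/reject
        # gives accept_prob times the term for the proposed state.
        total += a * 0.25 * diff ** 2 / length ** 2

    # The outer expectation is approximated by the mean over chains.
    return total / num_chains


print(chees_rate_sketch([[1.0], [-1.0]], [[2.0], [-1.0]], [1.0, 1.0], [1.0, 1.0]))
```

With two one-dimensional chains at 1 and -1, where only the first moves (to 2), the centered squared norms change by 3 and 0, so the estimate is 0.25 * 9 / 2 = 1.125 for unit trajectory lengths; doubling both trajectory lengths divides it by 4.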

Args:
previous_state: (Possibly nested) floating-point Tensor. The previous state of the HMC chain.
proposed_state: (Possibly nested) floating-point Tensor. The proposed state of the HMC chain.
accept_prob: Floating-point Tensor. Probability of accepting the proposed state.
trajectory_length: Floating-point Tensor. Trajectory length of the HMC proposal.
validate_args: Python bool. Whether to perform non-static argument validation.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names: A string or list of string names indicating which named chain axes to reduce over when computing the criterion.
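Because the states may be nested, the squared norm in the criterion can be read as running over all parts of the state jointly, which amounts to concatenating each chain's parts before taking the norm. The sketch below assumes a hypothetical layout in which each part is a list indexed by chain, holding that chain's flattened values; `flatten_per_chain` and `sq_norm` are illustrative helpers, not library functions.

```python
def flatten_per_chain(state_parts):
    """Concatenate each chain's values across all parts of a nested state.

    Hypothetical layout: state_parts is a list of parts; each part is a list
    indexed by chain, and each entry is a flat list of floats for that chain.
    """
    num_chains = len(state_parts[0])
    flat = [[] for _ in range(num_chains)]
    for part in state_parts:
        if len(part) != num_chains:
            raise ValueError("all parts must have the same number of chains")
        for c, values in enumerate(part):
            # Append this part's values to the chain's joint vector.
            flat[c].extend(values)
    return flat


def sq_norm(v):
    """Squared L2 norm of a flat list of floats."""
    return sum(x * x for x in v)


# Two chains; part 0 is a scalar per chain, part 1 is a 2-vector per chain.
parts = [[[1.0], [2.0]], [[3.0, 4.0], [0.0, 0.0]]]
print([sq_norm(v) for v in flatten_per_chain(parts)])
```

Here chain 0 becomes [1.0, 3.0, 4.0] with squared norm 26.0, and chain 1 becomes [2.0, 0.0, 0.0] with squared norm 4.0.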

Returns:
chees_rate: The value of the ChEES rate criterion.