tfp.experimental.stats.RunningPotentialScaleReduction

Holds metadata for and computes a running R-hat diagnostic statistic.

RunningPotentialScaleReduction computes Gelman and Rubin's (1992) potential scale reduction statistic (also known as R-hat), a diagnostic for assessing chain convergence [1].
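To make the statistic concrete, here is one simplified form of the estimator for M chains, each holding N scalar draws x_{mn}. This is a sketch for orientation only: the exact finite-sample corrections used by this class are not documented on this page and may differ. Gelman and Rubin's original estimator adds a B/(MN) term to \hat{V} and applies a degrees-of-freedom correction, and some texts report the square root of the ratio.

```latex
\begin{aligned}
\bar{x}_m &= \frac{1}{N}\sum_{n=1}^{N} x_{mn}, &
\bar{x} &= \frac{1}{M}\sum_{m=1}^{M} \bar{x}_m, \\
s_m^2 &= \frac{1}{N-1}\sum_{n=1}^{N} \left(x_{mn} - \bar{x}_m\right)^2, &
W &= \frac{1}{M}\sum_{m=1}^{M} s_m^2, \\
B &= \frac{N}{M-1}\sum_{m=1}^{M} \left(\bar{x}_m - \bar{x}\right)^2, &
\hat{V} &= \frac{N-1}{N}\,W + \frac{B}{N}, \\
\hat{R} &= \frac{\hat{V}}{W}. & &
\end{aligned}
```

W is the average within-chain variance and B/N is the variance of the chain means. When all chains sample the same target, \hat{R} approaches 1 as N grows; values well above 1 suggest the chains have not yet mixed.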

If multiple independent R-hat computations are desired across a latent state, use a (possibly nested) collection for the initialization parameters independent_chain_ndims and shape. Subsequent chain states used to update the streaming R-hat should mirror that structure exactly.
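The structure-matching requirement can be illustrated with a small pure-Python helper. map_matching and _is_shape are hypothetical names, not part of TFP; the sketch assumes nested dicts and lists as containers and tuples of ints as shape leaves, and it simply rejects any mismatch between the shape structure and the independent_chain_ndims structure. The same matching would apply to each new_sample passed to update. For real TensorFlow structures, tf.nest.map_structure handles the general case.

```python
from typing import Any, Callable


def _is_shape(x: Any) -> bool:
    # Hypothetical convention: a tuple of ints is a shape leaf, e.g. (4, 3).
    return isinstance(x, tuple) and all(isinstance(d, int) for d in x)


def map_matching(fn: Callable, shapes: Any, ndims: Any, path: str = "root") -> Any:
    """Apply fn(shape, ndims) at every leaf, requiring both structures to match."""
    if isinstance(shapes, dict):
        if not isinstance(ndims, dict) or shapes.keys() != ndims.keys():
            raise ValueError(f"structure mismatch at {path}")
        return {k: map_matching(fn, shapes[k], ndims[k], f"{path}.{k}") for k in shapes}
    if isinstance(shapes, list):
        if not isinstance(ndims, list) or len(shapes) != len(ndims):
            raise ValueError(f"structure mismatch at {path}")
        return [map_matching(fn, s, n, f"{path}[{i}]")
                for i, (s, n) in enumerate(zip(shapes, ndims))]
    if _is_shape(shapes) and isinstance(ndims, int):
        return fn(shapes, ndims)
    raise ValueError(f"structure mismatch at {path}")


# Two latent components, each with one leading chain dimension.
shapes = {"loc": (4,), "scale": (4, 3)}
ndims = {"loc": 1, "scale": 1}
event_shapes = map_matching(lambda s, d: s[d:], shapes, ndims)
# event_shapes == {"loc": (), "scale": (3,)}
```

Each leaf gets its own independent R-hat computation, so a structure mismatch is reported eagerly rather than silently pairing the wrong shape with the wrong chain count.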

RunningPotentialScaleReduction also assumes that incoming samples have shape [Ci1, Ci2, ..., CiD] + A. Dimensions 0 through D - 1 index the Ci1 x ... x CiD independent chains to be tested for convergence to the same target. The remaining dimensions, A, represent the event shape and hence can have any shape (even empty, which implies scalar samples). The number of independent chain dimensions, D, is set by the independent_chain_ndims parameter at initialization.
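The shape convention can be sketched as a split of the sample shape at position D. split_sample_shape is a hypothetical helper for illustration, not a TFP function; it returns the chain shape [Ci1, ..., CiD], the total number of chains, and the event shape A, under the assumption that shapes are plain Python sequences of ints.

```python
import math
from typing import Sequence, Tuple


def split_sample_shape(
    sample_shape: Sequence[int], independent_chain_ndims: int
) -> Tuple[Tuple[int, ...], int, Tuple[int, ...]]:
    """Split [Ci1, ..., CiD] + A into (chain_shape, num_chains, event_shape)."""
    sample_shape = tuple(sample_shape)
    if not 1 <= independent_chain_ndims <= len(sample_shape):
        raise ValueError("need 1 <= independent_chain_ndims <= len(sample_shape)")
    chain_shape = sample_shape[:independent_chain_ndims]  # Ci1, ..., CiD
    event_shape = sample_shape[independent_chain_ndims:]  # A, possibly empty
    return chain_shape, math.prod(chain_shape), event_shape


print(split_sample_shape((4, 3, 2), 2))  # -> ((4, 3), 12, (2,))
print(split_sample_shape((8,), 1))       # -> ((8,), 8, ())
```

In the first call, 12 chains are compared and one R-hat is produced per event position (2 values); in the second, 8 chains of scalar samples yield a single R-hat.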

RunningPotentialScaleReduction objects do not hold state information. That information, which includes intermediate calculations, is held in a RunningPotentialScaleReductionState returned by the initialize and update methods.
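The stateless-object pattern can be sketched in pure Python. ToyRunningRhat and ToyRhatState are hypothetical stand-ins, not the TFP API: the object keeps only configuration (the number of chains), while an immutable NamedTuple carries per-chain running means and sums of squared deviations, updated with Welford's algorithm. The sketch assumes scalar samples with one flat chain dimension, and finalize uses a simplified estimator, V = (N - 1)/N * W + B/N and R-hat = V / W, which may differ from the library's exact corrections.

```python
from typing import NamedTuple, Sequence, Tuple


class ToyRhatState(NamedTuple):
    """Hypothetical immutable state holding every running quantity."""
    count: int                # draws seen per chain (chains advance together)
    means: Tuple[float, ...]  # running mean of each chain
    m2s: Tuple[float, ...]    # running sum of squared deviations per chain


class ToyRunningRhat:
    """Holds configuration only; all running values live in ToyRhatState."""

    def __init__(self, num_chains: int):
        if num_chains < 2:
            raise ValueError("R-hat needs at least two chains")
        self.num_chains = num_chains

    def initialize(self) -> ToyRhatState:
        zeros = (0.0,) * self.num_chains
        return ToyRhatState(count=0, means=zeros, m2s=zeros)

    def update(self, state: ToyRhatState, new_sample: Sequence[float]) -> ToyRhatState:
        # new_sample holds one scalar draw per chain; returns a NEW state.
        if len(new_sample) != self.num_chains:
            raise ValueError("new_sample must have one value per chain")
        count = state.count + 1
        means, m2s = [], []
        for mean, m2, x in zip(state.means, state.m2s, new_sample):
            x = float(x)  # integer inputs are promoted to floats
            delta = x - mean
            mean += delta / count
            m2 += delta * (x - mean)  # Welford's update
            means.append(mean)
            m2s.append(m2)
        return ToyRhatState(count, tuple(means), tuple(m2s))

    def finalize(self, state: ToyRhatState) -> float:
        n, m = state.count, self.num_chains
        if n < 2:
            raise ValueError("need at least two draws per chain")
        w = sum(m2 / (n - 1) for m2 in state.m2s) / m  # within-chain variance
        grand_mean = sum(state.means) / m
        # B / N: sample variance of the chain means.
        b_over_n = sum((mu - grand_mean) ** 2 for mu in state.means) / (m - 1)
        v_hat = (n - 1) / n * w + b_over_n
        return v_hat / w


rhat = ToyRunningRhat(num_chains=2)
state = rhat.initialize()
for draw in [(0.0, 10.0), (1.0, 11.0), (0.0, 10.0), (1.0, 11.0)]:
    state = rhat.update(state, draw)
print(rhat.finalize(state))
```

Because the two chains have very different means, the printed value is far above 1. Keeping state outside the object means a caller can checkpoint, copy, or thread it through a loop without the statistic object itself changing.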

RunningPotentialScaleReduction is intended for general-purpose streaming R-hat computation. For a specialized version designed for streaming over MCMC samples, see RhatReducer in tfp.experimental.mcmc.

References

[1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation Using Multiple Sequences. Statistical Science, 7(4):457-472, 1992.

Args
shape Python tuple or TensorShape representing the shape of incoming samples. Using a collection implies that future samples will mirror that exact structure.
independent_chain_ndims Integer or integer-type Tensor with value >= 1 giving the number of leading dimensions holding independent chain results to be tested for convergence. Using a collection implies that future samples will mirror that exact structure.
dtype Dtype of incoming samples and the resulting statistics. By default, the dtype is tf.float32. Integer dtypes are cast to the corresponding float dtype (e.g. tf.int32 is cast to tf.float32), since intermediate calculations perform floating-point division.

Methods

finalize


Finalizes potential scale reduction computation for the state.

Args
state RunningPotentialScaleReductionState that represents the current state of running statistics.

Returns
rhat An estimate of the R-hat.

initialize


Initializes an empty RunningPotentialScaleReductionState.

Returns
state RunningPotentialScaleReductionState representing a stream of no inputs.

update


Updates the RunningPotentialScaleReductionState with a new sample.

Args
state RunningPotentialScaleReductionState that represents the current state of running statistics.
new_sample Incoming Tensor sample or (possibly nested) collection of Tensors with shape and dtype compatible with those used to form the RunningPotentialScaleReductionState.

Returns
state RunningPotentialScaleReductionState with updated calculations.