tfp.experimental.stats.RunningPotentialScaleReduction

A running R-hat diagnostic.

Inherits From: AutoCompositeTensor

RunningPotentialScaleReduction computes Gelman and Rubin's (1992) potential scale reduction factor (also known as R-hat), a diagnostic for chain convergence [1].
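For reference, the classic batch form of the diagnostic can be sketched in NumPy as below. This is an illustrative sketch, not TFP's implementation: rhat_batch is a hypothetical helper that uses the common simplified estimator (pooled variance divided by the mean within-chain variance, then a square root). Implementations differ on the square root and on small-sample correction terms, so its numbers need not match the library exactly.

```python
import numpy as np

# Hypothetical helper: classic (batch) Gelman-Rubin R-hat, simplified form.
def rhat_batch(chains):
    """R-hat for an array of shape (m, n): m chains with n draws each."""
    chains = np.asarray(chains, dtype=np.float64)
    n = chains.shape[1]
    # W: mean of the within-chain sample variances.
    w = chains.var(axis=1, ddof=1).mean()
    # B / n: sample variance of the chain means.
    b_div_n = chains.mean(axis=1).var(ddof=1)
    # Pooled estimate of the target variance.
    var_plus = (n - 1) / n * w + b_div_n
    return np.sqrt(var_plus / w)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                        # 4 chains, same target
stuck = mixed + np.array([0.0, 0.0, 0.0, 5.0])[:, None]   # one chain offset by 5
print(rhat_batch(mixed))  # close to 1
print(rhat_batch(stuck))  # well above 1
```

Values near 1 suggest the chains agree; values well above 1 indicate that at least one chain has not mixed with the others.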

If multiple independent R-hat computations are desired across a latent state, use a (possibly nested) collection for the initialization parameters independent_chain_ndims and shape. Subsequent chain states used to update the streaming R-hat should then have the same structure.
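A minimal sketch of what "parallel structures" means, assuming plain Python dicts and lists as containers (TFP itself works with tf.nest structures, and map_parallel and num_chains are hypothetical helpers): each leaf of the shape structure is paired with the independent_chain_ndims leaf at the same position, and mismatched structures are rejected.

```python
import math

# Hypothetical helper: walk two structures in lockstep.
def map_parallel(fn, shapes, ndims):
    """Apply fn to matching leaves of two parallel structures.

    Containers are dicts and lists; a leaf of `shapes` is a tuple (one sample
    shape) and the matching leaf of `ndims` is an int.
    """
    if isinstance(shapes, dict):
        if not isinstance(ndims, dict) or shapes.keys() != ndims.keys():
            raise ValueError("shape and ndims structures differ")
        return {k: map_parallel(fn, shapes[k], ndims[k]) for k in shapes}
    if isinstance(shapes, list):
        if not isinstance(ndims, list) or len(shapes) != len(ndims):
            raise ValueError("shape and ndims structures differ")
        return [map_parallel(fn, s, d) for s, d in zip(shapes, ndims)]
    return fn(shapes, ndims)

# Hypothetical helper: the leading `ndims` dimensions index independent chains.
def num_chains(shape, ndims):
    if ndims < 1:
        raise ValueError("independent_chain_ndims must be at least 1")
    return math.prod(shape[:ndims])

shapes = {"loc": (4, 3), "scale": [(4, 2, 5), (4,)]}
ndims = {"loc": 1, "scale": [2, 1]}
print(map_parallel(num_chains, shapes, ndims))
# {'loc': 4, 'scale': [8, 4]}
```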

RunningPotentialScaleReduction also assumes that incoming samples have shape [Ci1, Ci2, ..., CiD] + A. Dimensions 0 through D - 1 index the Ci1 x ... x CiD independent chains to be tested for convergence to the same target. The remaining dimensions, A, represent the event shape and can therefore have any shape (even empty, which implies scalar samples). The number of independent chain dimensions, D, is set by the independent_chain_ndims parameter at initialization.
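The split between chain dimensions and event dimensions can be illustrated with NumPy. This sketch (split_chain_dims is a hypothetical helper) treats each incoming sample as one draw from every chain and flattens the D leading chain dimensions into a single axis of Ci1 x ... x CiD chains, leaving the event shape A untouched.

```python
import numpy as np

# Hypothetical helper: separate chain dimensions from event dimensions.
def split_chain_dims(sample, independent_chain_ndims):
    """Flatten the leading chain dimensions of one sample into a single axis."""
    sample = np.asarray(sample)
    chain_shape = sample.shape[:independent_chain_ndims]
    event_shape = sample.shape[independent_chain_ndims:]
    total_chains = int(np.prod(chain_shape))
    return sample.reshape((total_chains,) + event_shape), chain_shape, event_shape

sample = np.zeros((2, 3, 5))  # C1=2, C2=3 chain dims; event shape A=(5,)
flat, chain_shape, event_shape = split_chain_dims(sample, independent_chain_ndims=2)
print(flat.shape, chain_shape, event_shape)  # (6, 5) (2, 3) (5,)

scalar = np.zeros((4,))  # 4 chains with an empty event shape: scalar samples
print(split_chain_dims(scalar, 1)[0].shape)  # (4,)
```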

RunningPotentialScaleReduction is intended for general-purpose streaming R-hat computation. For a specialized version tailored to streaming over MCMC samples, see PotentialScaleReductionReducer in tfp.experimental.mcmc.

References

[1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation Using Multiple Sequences. Statistical Science, 7(4):457-472, 1992.

Args
chain_variances A RunningVariance or nested structure of RunningVariances, giving the variance estimates for the variables of interest.
independent_chain_ndims A Python int or structure of Python ints parallel to chain_variances, giving the number of leading dimensions in chain_variances that index the independent chains over which the potential scale reduction factor should be computed. Must be at least 1.
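To show the kind of per-chain state that chain_variances carries, here is a sketch of a running mean and variance accumulator based on Welford's algorithm. RunningMeanVar is a hypothetical stand-in for TFP's RunningVariance, not its actual implementation; like the update method documented on this page, its update returns a new state rather than mutating in place.

```python
from typing import NamedTuple

import numpy as np

class RunningMeanVar(NamedTuple):
    """Hypothetical stand-in for a per-chain running variance accumulator."""
    num_samples: int
    mean: np.ndarray
    sum_sq_dev: np.ndarray  # running sum of squared deviations from the mean

    @classmethod
    def from_shape(cls, shape):
        return cls(0, np.zeros(shape), np.zeros(shape))

    def update(self, x):
        # Welford's update, elementwise; returns a new state instead of mutating.
        x = np.asarray(x, dtype=np.float64)
        n = self.num_samples + 1
        delta = x - self.mean
        mean = self.mean + delta / n
        sum_sq_dev = self.sum_sq_dev + delta * (x - mean)
        return RunningMeanVar(n, mean, sum_sq_dev)

    def variance(self, ddof=0):
        return self.sum_sq_dev / (self.num_samples - ddof)

rng = np.random.default_rng(1)
draws = rng.normal(size=(500, 3))  # 500 draws for each of 3 chains
state = RunningMeanVar.from_shape((3,))
for x in draws:
    state = state.update(x)
print(np.allclose(state.mean, draws.mean(axis=0)))                      # True
print(np.allclose(state.variance(ddof=1), draws.var(axis=0, ddof=1)))  # True
```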

Methods

from_example

Starts an empty RunningPotentialScaleReduction from metadata (shape and dtype) taken from an example Tensor.

Args
example A Tensor. The RunningPotentialScaleReduction will accept samples of the same dtype and broadcast-compatible shape as the example.
independent_chain_ndims An integer, or an integer-typed Tensor, with value >= 1 giving the number of leading dimensions holding independent chain results to be tested for convergence. Using a collection implies that future samples must have that exact structure.

Returns
state RunningPotentialScaleReduction representing a stream with no inputs. By convention, the supplied example is used only for initialization and is not counted as a sample.
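The convention that the example fixes metadata but is not counted can be sketched as follows, assuming a hypothetical helper empty_state_like that stores a sample count plus zero-filled statistics with the example's shape and dtype.

```python
import numpy as np

# Hypothetical helper: an empty accumulator shaped like `example`.
def empty_state_like(example):
    """Take shape and dtype from `example` without counting it as a sample."""
    example = np.asarray(example)
    return {
        "num_samples": 0,  # the example fixes metadata only; it is not counted
        "mean": np.zeros_like(example),
        "sum_sq_dev": np.zeros_like(example),
    }

state = empty_state_like(np.ones((4, 3), dtype=np.float32))
print(state["num_samples"], state["mean"].shape, state["mean"].dtype)
# 0 (4, 3) float32
```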

from_shape

Starts an empty RunningPotentialScaleReduction from metadata (an explicit shape and dtype).

Args
shape A Python tuple or TensorShape representing the shape of incoming samples. Using a collection implies that future samples must have that exact structure. Supplying the shape is useful when the RunningPotentialScaleReduction will be carried by a tf.while_loop, because it prevents broadcasting from changing the shape across loop iterations.
independent_chain_ndims An integer, or an integer-typed Tensor, with value >= 1 giving the number of leading dimensions holding independent chain results to be tested for convergence. Using a collection implies that future samples must have that exact structure.
dtype The dtype of incoming samples and of the resulting statistics. Defaults to tf.float32. Integer dtypes are cast to corresponding floating-point dtypes (e.g., tf.int32 is cast to tf.float32), because intermediate calculations perform floating-point division.

Returns
state RunningPotentialScaleReduction representing a stream with no inputs.
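A sketch of the from_shape contract, using hypothetical helpers float_dtype_for and empty_state_from_shape. The int32 to float32 mapping follows the text above; mapping 64-bit integers to float64 is an assumption of this sketch, not documented behavior.

```python
import numpy as np

# Hypothetical helper: choose the floating dtype used for the statistics.
def float_dtype_for(dtype):
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.floating):
        return dtype
    if np.issubdtype(dtype, np.integer):
        # int32 -> float32 follows the docs; 64-bit ints -> float64 is assumed.
        return np.dtype(np.float64) if dtype.itemsize > 4 else np.dtype(np.float32)
    raise TypeError(f"unsupported dtype {dtype}")

# Hypothetical helper: an empty accumulator with a fixed, explicit shape.
def empty_state_from_shape(shape, dtype=np.float32):
    dtype = float_dtype_for(dtype)
    return {
        "num_samples": 0,
        "mean": np.zeros(shape, dtype=dtype),
        "sum_sq_dev": np.zeros(shape, dtype=dtype),
    }

state = empty_state_from_shape((4, 3), dtype=np.int32)
print(state["mean"].dtype, state["mean"].shape)  # float32 (4, 3)
```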

potential_scale_reduction

Computes the potential scale reduction for samples accumulated so far.

Returns
rhat An estimate of R-hat.
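The estimate depends only on per-chain summaries: the number of samples n, the per-chain means, and the per-chain sample variances, which is why a streaming accumulator never needs to retain past samples. The sketch below (rhat_from_moments is a hypothetical helper) computes R-hat from those summaries alone, using the common simplified Gelman-Rubin estimator with the mean within-chain variance W and the between-chain variance B/n. It is not guaranteed to match TFP's numerics.

```python
import numpy as np

# Hypothetical helper: R-hat from per-chain summary statistics only.
def rhat_from_moments(num_samples, chain_means, chain_variances):
    """chain_variances use the n - 1 denominator.

    The leading axis indexes chains; trailing axes are event dimensions and
    each event element gets its own R-hat.
    """
    n = float(num_samples)
    w = np.mean(chain_variances, axis=0)           # within-chain variance W
    b_div_n = np.var(chain_means, axis=0, ddof=1)  # between-chain variance B / n
    var_plus = (n - 1.0) / n * w + b_div_n
    return np.sqrt(var_plus / w)

means = np.array([[0.0, 0.0], [0.1, 3.0], [-0.1, -3.0]])  # 3 chains, event shape (2,)
variances = np.ones((3, 2))
print(rhat_from_moments(100, means, variances))  # approximately [1.0, 3.16]
```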

tree_flatten

tree_unflatten

update

Update the RunningPotentialScaleReduction with a new sample.

Args
new_sample Incoming Tensor sample or (possibly nested) collection of Tensors with shape and dtype compatible with those used to form the RunningPotentialScaleReduction.

Returns
state RunningPotentialScaleReduction updated to include the new sample.
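Putting update and potential_scale_reduction together, the sketch below streams samples through a hypothetical immutable StreamingRHat class (supporting a single leading chain dimension only) and checks that the streaming estimate agrees with a batch computation over the stored draws. The method names echo this page, but the class and its simplified estimator are illustrative assumptions rather than TFP's code.

```python
from typing import NamedTuple

import numpy as np

class StreamingRHat(NamedTuple):
    """Hypothetical immutable streaming R-hat for samples of shape [chains] + event."""
    num_samples: int
    mean: np.ndarray        # per-chain running mean
    sum_sq_dev: np.ndarray  # per-chain running sum of squared deviations

    @classmethod
    def from_shape(cls, shape):
        return cls(0, np.zeros(shape), np.zeros(shape))

    def update(self, new_sample):
        # Welford step applied elementwise; returns a new state (no mutation).
        x = np.asarray(new_sample, dtype=np.float64)
        n = self.num_samples + 1
        delta = x - self.mean
        mean = self.mean + delta / n
        return StreamingRHat(n, mean, self.sum_sq_dev + delta * (x - mean))

    def potential_scale_reduction(self):
        n = float(self.num_samples)
        w = np.mean(self.sum_sq_dev / (n - 1.0), axis=0)  # within-chain W
        b_div_n = np.var(self.mean, axis=0, ddof=1)        # between-chain B / n
        return np.sqrt(((n - 1.0) / n * w + b_div_n) / w)

# Batch reference on an array of shape (n, chains) + event: draws come first.
def rhat_batch(draws):
    n = draws.shape[0]
    w = draws.var(axis=0, ddof=1).mean(axis=0)
    b_div_n = draws.mean(axis=0).var(axis=0, ddof=1)
    return np.sqrt(((n - 1) / n * w + b_div_n) / w)

rng = np.random.default_rng(2)
draws = rng.normal(size=(200, 4, 3))  # 200 updates, 4 chains, event shape (3,)
state = StreamingRHat.from_shape((4, 3))
for x in draws:
    state = state.update(x)
print(np.allclose(state.potential_scale_reduction(), rhat_batch(draws)))  # True
```

Returning a fresh state from update, rather than mutating, is what lets such a state be threaded through a loop as an ordinary loop variable.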