`Reducer` that computes a running R-hat diagnostic statistic.
Inherits From: Reducer
```python
tfp.experimental.mcmc.PotentialScaleReductionReducer(
    independent_chain_ndims=1, name=None
)
```
`PotentialScaleReductionReducer` assumes that all provided chain states include samples from multiple independent Markov chains, and that all of these chains are to be included in the same calculation. `PotentialScaleReductionReducer` also assumes that incoming samples have shape `[Ci1, Ci2, ..., CiD] + A`. Dimensions `0` through `D - 1` index the `Ci1 x ... x CiD` independent chains to be tested for convergence to the same target. The remaining dimensions, `A`, represent the event shape and can therefore have any shape (even empty, which implies scalar samples). The number of independent chain dimensions, `D`, is set by the `independent_chain_ndims` parameter at initialization.
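The shape convention above can be illustrated with a small pure-Python helper. `split_sample_shape` is a hypothetical name, not part of TFP; it only shows how the leading `independent_chain_ndims` dimensions index chains while the remaining dimensions form the event shape `A`.

```python
import math
from typing import Sequence, Tuple


def split_sample_shape(
    sample_shape: Sequence[int], independent_chain_ndims: int = 1
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
    """Hypothetical helper: split [Ci1, ..., CiD] + A into its two parts."""
    if independent_chain_ndims < 1:
        raise ValueError("independent_chain_ndims must be >= 1")
    shape = tuple(sample_shape)
    if independent_chain_ndims > len(shape):
        raise ValueError("sample rank is smaller than independent_chain_ndims")
    chain_shape = shape[:independent_chain_ndims]  # dimensions 0 .. D - 1
    event_shape = shape[independent_chain_ndims:]  # A, possibly empty
    num_chains = math.prod(chain_shape)            # Ci1 x ... x CiD
    return chain_shape, event_shape, num_chains


# 4 chains, each sample a 3 x 2 event.
print(split_sample_shape([4, 3, 2]))     # ((4,), (3, 2), 4)
# A 4 x 3 grid of chains, each sample a length-2 vector.
print(split_sample_shape([4, 3, 2], 2))  # ((4, 3), (2,), 12)
# 8 chains of scalar samples: the event shape is empty.
print(split_sample_shape([8]))           # ((8,), (), 8)
```

Increasing `independent_chain_ndims` moves dimensions from the event shape into the chain shape, which multiplies the number of chains being compared.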
As with all reducers, `PotentialScaleReductionReducer` does not hold state information; rather, it stores only the metadata supplied at construction, such as `independent_chain_ndims` and `name`. Intermediate calculations are held in a separate state object, which is returned by the `initialize` and `one_step` method calls.
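The stateless design can be sketched with a toy reducer that follows the same `initialize` / `one_step` / `finalize` shape. `RunningMeanReducer` and `MeanState` are hypothetical and far simpler than the R-hat reducer; the point is that the reducer object keeps only metadata, so one instance can drive several independent streams, each with its own state object.

```python
from typing import NamedTuple


class MeanState(NamedTuple):
    """Hypothetical state object: all running progress lives here."""
    count: int
    total: float


class RunningMeanReducer:
    """Hypothetical reducer: stores metadata only, never running values."""

    def __init__(self, name: str = "running_mean"):
        self.name = name  # metadata; no method below mutates self

    def initialize(self, initial_sample: float) -> MeanState:
        # The initial sample is accepted for interface parity but not counted.
        return MeanState(count=0, total=0.0)

    def one_step(self, sample: float, state: MeanState) -> MeanState:
        # Return a fresh state instead of mutating self or the old state.
        return MeanState(state.count + 1, state.total + sample)

    def finalize(self, state: MeanState) -> float:
        return state.total / state.count


reducer = RunningMeanReducer()
state_a = reducer.initialize(0.0)
state_b = reducer.initialize(0.0)
for x in [1.0, 2.0, 3.0]:
    state_a = reducer.one_step(x, state_a)
state_b = reducer.one_step(10.0, state_b)
print(reducer.finalize(state_a), reducer.finalize(state_b))  # 2.0 10.0
```

Returning a fresh state instead of mutating `self` keeps the reducer free of hidden side effects, which is what allows the same instance to be reused across streams.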
`PotentialScaleReductionReducer` is meant to fit into the larger Streaming MCMC framework. `RunningPotentialScaleReduction` in `tfp.experimental.stats` is better suited for more generic streaming R-hat needs, and its documentation gives more precise algorithmic details.
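For reference, a common finite-sample form of R-hat (the one used, as far as we know, by the batch function `tfp.mcmc.potential_scale_reduction`) is shown below, with $N$ samples per chain, $M$ chains, samples $x_{jt}$, chain means $\bar{x}_j$, and their average $\bar{x}$. That the streaming classes finalize with exactly these expressions is an assumption here; note that evaluating them only requires per-chain counts, means, and sums of squared residuals, which is what makes a streaming version possible.

$$
\begin{aligned}
W &= \frac{1}{M}\sum_{j=1}^{M} s_j^2,
\qquad s_j^2 = \frac{1}{N-1}\sum_{t=1}^{N}\left(x_{jt} - \bar{x}_j\right)^2, \\
\frac{B}{N} &= \frac{1}{M-1}\sum_{j=1}^{M}\left(\bar{x}_j - \bar{x}\right)^2, \\
\hat{\sigma}^2_{+} &= \frac{N-1}{N}\,W + \frac{B}{N}, \\
\hat{R} &= \frac{M+1}{M}\,\frac{\hat{\sigma}^2_{+}}{W} - \frac{N-1}{MN}.
\end{aligned}
$$

Values of $\hat{R}$ close to 1 indicate that the between-chain spread is consistent with the within-chain spread.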
Raises | |
---|---|
`ValueError` | If `independent_chain_ndims < 1`. This results in undefined intermediate variance calculations. |
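One way to see why at least one chain dimension is required: the between-chain term of R-hat is a sample variance across chain means, which divides by the number of chains minus one. With no chain dimension, presumably all samples form a single chain and that denominator is zero. The sketch below uses Python's `statistics.variance`, which raises `StatisticsError` for fewer than two data points; `check_independent_chain_ndims` and `between_chain_variance` are hypothetical helpers that mirror the documented `ValueError`, not TFP functions.

```python
import statistics


def check_independent_chain_ndims(independent_chain_ndims: int) -> None:
    """Hypothetical validator mirroring the documented ValueError."""
    if independent_chain_ndims < 1:
        raise ValueError(
            f"independent_chain_ndims must be >= 1, got {independent_chain_ndims}"
        )


def between_chain_variance(chain_means):
    """Unbiased variance of the chain means, i.e. the B / N term of R-hat."""
    # statistics.variance divides by (M - 1) and raises StatisticsError if M < 2.
    return statistics.variance(chain_means)


print(between_chain_variance([2.0, 3.0]))  # 0.5
try:
    between_chain_variance([2.0])  # a single chain: M - 1 == 0
except statistics.StatisticsError as err:
    print("undefined:", err)
```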
Attributes | |
---|---|
`independent_chain_ndims` | |
`name` | |
`parameters` | |
Methods
finalize
```python
finalize(
    final_reducer_state
)
```

Finalizes the R-hat calculation from the `final_reducer_state`.
Args | |
---|---|
`final_reducer_state` | `PotentialScaleReductionReducerState` that represents the final state of the running R-hat statistic. |
Returns | |
---|---|
`rhat` | An estimate of the R-hat. |
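A minimal sketch of what a finalize step could compute, assuming the state carries the per-chain sample count, running means, and running sums of squared residuals (`M2`) for scalar samples, and assuming the Brooks and Gelman form of R-hat used by `tfp.mcmc.potential_scale_reduction`. `RhatSummaries` and `finalize_rhat` are hypothetical; the real `PotentialScaleReductionReducerState` wraps an `rhat_state` whose internals are not described here.

```python
from typing import NamedTuple, Sequence


class RhatSummaries(NamedTuple):
    """Hypothetical per-chain summaries a streaming R-hat state could hold."""
    num_samples: int              # N, samples seen per chain
    chain_means: Sequence[float]  # running mean of each chain
    chain_m2: Sequence[float]     # running sum of squared residuals per chain


def finalize_rhat(state: RhatSummaries) -> float:
    n = state.num_samples
    m = len(state.chain_means)
    if n < 2 or m < 2:
        raise ValueError("need at least 2 samples per chain and 2 chains")
    # W: mean of the unbiased within-chain variances, M2_j / (N - 1).
    w = sum(m2 / (n - 1) for m2 in state.chain_m2) / m
    # B / N: unbiased variance of the chain means across chains.
    grand_mean = sum(state.chain_means) / m
    b_div_n = sum((x - grand_mean) ** 2 for x in state.chain_means) / (m - 1)
    # Pooled variance estimate, then the corrected R-hat (assumed form).
    sigma2_plus = (n - 1) / n * w + b_div_n
    return (m + 1) / m * sigma2_plus / w - (n - 1) / (m * n)


# Two chains of three scalar samples: [1, 2, 3] and [2, 3, 4].
state = RhatSummaries(num_samples=3, chain_means=[2.0, 3.0], chain_m2=[2.0, 2.0])
print(finalize_rhat(state))  # approximately 1.4167 (= 17 / 12)
```

Because finalize only reads summaries, it can be called at any point in the stream without disturbing later updates.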
initialize
```python
initialize(
    initial_chain_state, initial_kernel_results=None
)
```

Initializes an empty `PotentialScaleReductionReducerState`.
For calculation purposes, the `initial_chain_state` does not count as a sample. This is a deliberate decision that ensures consistency across sampling procedures (e.g. `tfp.mcmc.sample_chain` follows similar semantics).
Args | |
---|---|
`initial_chain_state` | A (possibly nested) structure of `Tensor`s or Python `list`s of `Tensor`s representing the current state(s) of the Markov chain(s). It is used to infer the shape and dtype of future samples. |
`initial_kernel_results` | A (possibly nested) structure of `Tensor`s representing internal calculations made in a related `TransitionKernel`. For streaming R-hat, this argument has no influence on the computation; hence, it is `None` by default. However, it is still accepted to fit the `Reducer` base class. |
Returns | |
---|---|
`state` | `PotentialScaleReductionReducerState` with an `rhat_state` field representing a stream of no inputs. |
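The convention that the initial state is not counted can be sketched as follows, assuming scalar samples from several chains held in a Python list. `RhatState`, `initialize`, and `run_chain` are hypothetical: `initialize` uses the initial state only to size the per-chain accumulators, and `run_chain` mimics a sampling driver whose outputs are the states produced by each transition, so the initial state never reaches the reducer.

```python
from typing import Callable, Iterator, List, NamedTuple


class RhatState(NamedTuple):
    """Hypothetical state for scalar samples from several chains."""
    count: int          # number of samples absorbed so far
    means: List[float]  # per-chain running means
    m2: List[float]     # per-chain running sums of squared residuals


def initialize(initial_chain_state: List[float]) -> RhatState:
    # The initial state only fixes the number of chains; it is not absorbed
    # as a sample, so the result represents a stream of no inputs.
    num_chains = len(initial_chain_state)
    return RhatState(count=0, means=[0.0] * num_chains, m2=[0.0] * num_chains)


def run_chain(
    initial_chain_state: List[float],
    step: Callable[[List[float]], List[float]],
    num_steps: int,
) -> Iterator[List[float]]:
    # Hypothetical driver: every yielded state is the output of one
    # transition, so the initial state itself is never yielded.
    state = initial_chain_state
    for _ in range(num_steps):
        state = step(state)
        yield state


init = [0.0, 10.0]
state = initialize(init)
# A deterministic stand-in for a transition kernel: add 1 to every chain.
draws = list(run_chain(init, lambda s: [x + 1.0 for x in s], num_steps=3))
print(state.count, len(draws), draws[0])  # 0 3 [1.0, 11.0]
```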
one_step
```python
one_step(
    new_chain_state, current_reducer_state, previous_kernel_results=None
)
```

Updates the `current_reducer_state` with a new chain state.
Args | |
---|---|
`new_chain_state` | A (possibly nested) structure of incoming chain state(s) with shape and dtype compatible with those used to initialize the `current_reducer_state`. |
`current_reducer_state` | `PotentialScaleReductionReducerState` representing the current state of the running R-hat statistic. |
`previous_kernel_results` | A (possibly nested) structure of `Tensor`s representing internal calculations made in a related `TransitionKernel`. For streaming R-hat, this argument has no influence on the computation; hence, it is `None` by default. However, it is still accepted to fit the `Reducer` base class. |
Returns | |
---|---|
`new_reducer_state` | `PotentialScaleReductionReducerState` with updated running statistics. |
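A per-chain running update of the kind a streaming R-hat needs can be sketched with Welford's algorithm, again assuming scalar samples in a Python list. `RhatState` and `one_step` are hypothetical; they mirror the argument order of the documented `one_step`, including an ignored `previous_kernel_results`, but are not TFP code. The final check against `statistics.variance` confirms that `m2 / (count - 1)` reproduces the unbiased within-chain variance.

```python
import statistics
from typing import List, NamedTuple


class RhatState(NamedTuple):
    """Hypothetical state for scalar samples from several chains."""
    count: int          # samples absorbed so far (same for every chain)
    means: List[float]  # per-chain running means
    m2: List[float]     # per-chain running sums of squared residuals


def one_step(new_chain_state, current_reducer_state, previous_kernel_results=None):
    """Welford update applied independently to each chain."""
    del previous_kernel_results  # accepted for interface parity, unused
    state = current_reducer_state
    if len(new_chain_state) != len(state.means):
        raise ValueError("new_chain_state does not match the state's chain count")
    n = state.count + 1
    means, m2 = [], []
    for x, mean, acc in zip(new_chain_state, state.means, state.m2):
        delta = x - mean               # residual against the old mean
        new_mean = mean + delta / n
        means.append(new_mean)
        m2.append(acc + delta * (x - new_mean))  # old residual x new residual
    return RhatState(count=n, means=means, m2=m2)


state = RhatState(count=0, means=[0.0, 0.0], m2=[0.0, 0.0])
for sample in ([1.0, 2.0], [2.0, 3.0], [3.0, 4.0]):
    state = one_step(sample, state)
print(state.count, state.means, state.m2)  # 3 [2.0, 3.0] [2.0, 2.0]
print(state.m2[0] / (state.count - 1), statistics.variance([1.0, 2.0, 3.0]))  # 1.0 1.0
```

Welford's update avoids storing past samples and is numerically steadier than accumulating raw sums of squares, which is why streaming variance estimators commonly use it.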