`Reducer` that accumulates trace results at each sample.
Inherits From: `Reducer`
tfp.experimental.mcmc.TracingReducer(
trace_fn=_trace_state_and_kernel_results, size=None, name=None
)
Trace results are defined by an appropriate `trace_fn`, which accepts the current chain state and kernel results and returns the desired result.
At each sample, the traced values are added to a `TensorArray` that accumulates all results. By default, all kernel results are traced, but in the future the default will be changed so that no results are traced; code that relies on traced kernel results should therefore pass an explicit `trace_fn`.
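A `trace_fn` is just a callable of `(chain_state, kernel_results)` whose return value is what gets accumulated. The sketch below uses plain Python values as stand-ins for tensors. The functions `trace_state_and_kernel_results`, `trace_state_only`, and `trace_acceptance` are hypothetical illustrations (the first merely mirrors what the name of the library's private default suggests), and `KernelResults` is an assumed stand-in, not a TFP class.

```python
from collections import namedtuple

# Hypothetical stand-in for a kernel's results structure (not a TFP class).
KernelResults = namedtuple("KernelResults", ["log_accept_ratio", "is_accepted"])


# Hypothetical example trace_fns; each receives (chain_state, kernel_results).
def trace_state_and_kernel_results(chain_state, kernel_results):
    """Trace everything, mirroring what the default trace_fn's name suggests."""
    return chain_state, kernel_results


def trace_state_only(chain_state, kernel_results):
    """Trace only the chain state; kernel results are discarded."""
    del kernel_results  # unused
    return chain_state


def trace_acceptance(chain_state, kernel_results):
    """Trace the state plus a single diagnostic field."""
    return chain_state, kernel_results.is_accepted


state = 0.5
results = KernelResults(log_accept_ratio=-0.1, is_accepted=True)
print(trace_state_and_kernel_results(state, results))
print(trace_state_only(state, results))   # 0.5
print(trace_acceptance(state, results))   # (0.5, True)
```

A narrower `trace_fn` such as `trace_state_only` keeps less in memory, since only its return value is accumulated, and it does not depend on whatever the default happens to be.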
If wrapped in a `tfp.experimental.mcmc.WithReductions` transition kernel, `TracingReducer` will not accumulate the kernel results of `WithReductions` itself. Rather, the top-level kernel results will be those of the inner kernel of `WithReductions`.
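Which kernel results reach the reducer can be pictured with a toy wrapper. In the sketch below, `ToyInnerKernel`, `ToyWithReductions`, and `RecordingReducer` are hypothetical stand-ins, not TFP classes, and their method signatures only loosely follow the `initialize`/`one_step` shapes documented on this page. The point it illustrates is that the wrapper forwards its inner kernel's results to the reducer, never its own results structure.

```python
from collections import namedtuple

# Hypothetical wrapper results: the inner kernel's results plus reducer state.
WrapperResults = namedtuple("WrapperResults", ["inner_results", "reduction_results"])


class ToyInnerKernel:
    """Hypothetical inner kernel: adds 1 to the state and counts steps."""

    def bootstrap_results(self, state):
        return {"step": 0}

    def one_step(self, state, results):
        return state + 1, {"step": results["step"] + 1}


class RecordingReducer:
    """Hypothetical reducer that records whatever kernel results it is given."""

    def initialize(self, state, kernel_results=None):
        return []

    def one_step(self, new_state, reducer_state, previous_kernel_results):
        return reducer_state + [previous_kernel_results]


class ToyWithReductions:
    """Hypothetical analogue of a wrapper kernel that applies a reducer."""

    def __init__(self, inner_kernel, reducer):
        self.inner_kernel = inner_kernel
        self.reducer = reducer

    def bootstrap_results(self, state):
        inner = self.inner_kernel.bootstrap_results(state)
        return WrapperResults(inner, self.reducer.initialize(state, inner))

    def one_step(self, state, results):
        new_state, new_inner = self.inner_kernel.one_step(state, results.inner_results)
        # The reducer sees the *inner* kernel's results, not WrapperResults.
        new_reduction = self.reducer.one_step(
            new_state, results.reduction_results, new_inner)
        return new_state, WrapperResults(new_inner, new_reduction)


kernel = ToyWithReductions(ToyInnerKernel(), RecordingReducer())
state, results = 0, kernel.bootstrap_results(0)
for _ in range(3):
    state, results = kernel.one_step(state, results)
print(results.reduction_results)  # [{'step': 1}, {'step': 2}, {'step': 3}]
```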
As with all reducers, `TracingReducer` does not hold state information; rather, it stores supplied metadata. Intermediate calculations are held in a `TracingState` named tuple, which is returned by the `initialize` and `one_step` method calls.
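The split between a stateless reducer and the `TracingState` it returns can be modeled in plain Python. This is an illustrative sketch, not TFP's implementation: an immutable tuple stands in for the `TensorArray`, `ToyTracingReducer` is a hypothetical class, and only the `trace_state` field name is taken from this page.

```python
from collections import namedtuple

# Illustrative stand-in: only the `trace_state` field name comes from the docs.
TracingState = namedtuple("TracingState", ["trace_state"])


class ToyTracingReducer:
    """Hypothetical, tuple-based model of a tracing reducer.

    The reducer stores only metadata (trace_fn, size); all accumulated
    values live in the TracingState objects it returns.
    """

    def __init__(self, trace_fn, size=None):
        self.trace_fn = trace_fn
        self.size = size  # kept only as metadata in this sketch

    def initialize(self, initial_chain_state, initial_kernel_results=None):
        # The initial state is not a sample, so the trace starts empty.
        return TracingState(trace_state=())

    def one_step(self, new_chain_state, current_reducer_state,
                 previous_kernel_results):
        result = self.trace_fn(new_chain_state, previous_kernel_results)
        # Return a new state instead of mutating the old one.
        return TracingState(current_reducer_state.trace_state + (result,))

    def finalize(self, final_reducer_state):
        # Stand-in for stacking: collect the traced values in order.
        return list(final_reducer_state.trace_state)


reducer = ToyTracingReducer(trace_fn=lambda state, kr: state)
s0 = reducer.initialize(initial_chain_state=0.0)
s1 = reducer.one_step(1.0, s0, previous_kernel_results=None)
s2 = reducer.one_step(2.0, s1, previous_kernel_results=None)
print(reducer.finalize(s2))  # [1.0, 2.0]
print(s0.trace_state)        # () because earlier states are never mutated
```

Threading the state explicitly through each call, rather than mutating the reducer, is what lets the same reducer object be reused across chains or inside a traced loop.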
Attributes | |
---|---|
`name` | |
`parameters` | |
`size` | |
`trace_fn` | |
Methods
finalize
finalize(
final_reducer_state
)
Finalizes tracing by stacking the accumulated `TensorArray`.
Args | |
---|---|
`final_reducer_state` | `TracingState` that holds all desired traced results.
Returns | |
---|---|
`trace` | `Tensor` that represents stacked tracing results.
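When `trace_fn` returns a structure, such as a `(state, kernel_results)` pair, stacking is applied per component, giving each component a new leading sample axis. The sketch below mimics this with plain tuples and lists; `stack_traces` is a hypothetical helper, and in TensorFlow the per-component stacking itself would be done by something like `tf.TensorArray.stack()`.

```python
def stack_traces(traces):
    """Hypothetical helper: stack per-step results along a new leading axis.

    `traces` is a sequence of per-step results. Tuples are treated as
    structures and stacked component-wise (recursively); any other value
    is a leaf.
    """
    if not traces:
        raise ValueError("nothing to stack")
    first = traces[0]
    if isinstance(first, tuple):
        # Transpose: one stacked sequence per component of the structure.
        return tuple(stack_traces(list(parts)) for parts in zip(*traces))
    return list(traces)


# Three steps, each traced as (state, is_accepted).
per_step = [(0.1, True), (0.3, False), (0.3, True)]
print(stack_traces(per_step))  # ([0.1, 0.3, 0.3], [True, False, True])
```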
initialize
initialize(
initial_chain_state, initial_kernel_results=None
)
Initializes a `TracingState` using previously defined metadata.
Neither `initial_chain_state` nor `initial_kernel_results` counts as a sample, so neither is traced. This is a deliberate decision that ensures consistency across sampling procedures.
Args | |
---|---|
`initial_chain_state` | A (possibly nested) structure of `Tensor`s or Python `list`s of `Tensor`s representing the current state(s) of the Markov chain(s). It is used to infer the structure of future trace results.
`initial_kernel_results` | A (possibly nested) structure of `Tensor`s representing internal calculations made in a related `TransitionKernel`. It is used to infer the structure of future trace results.
Returns | |
---|---|
`state` | `TracingState` with an empty `TensorArray` in its `trace_state` field.
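The counting convention for `initialize` means that after `n` transitions the trace holds exactly `n` entries, none of which is the initial state. A toy deterministic chain makes this concrete; `toy_transition` and `run_and_trace` are hypothetical helpers, with `toy_transition` standing in for a `TransitionKernel.one_step` call and a plain list standing in for the `TensorArray`.

```python
def toy_transition(state):
    """Hypothetical deterministic 'kernel': halve the state."""
    return state / 2


def run_and_trace(initial_state, num_steps):
    """Hypothetical driver loop following the tracing convention above."""
    state = initial_state
    trace = []  # like initialize: the initial state is *not* traced
    for _ in range(num_steps):
        state = toy_transition(state)
        trace.append(state)  # like one_step: trace each new sample
    return trace


print(run_and_trace(8.0, num_steps=3))  # [4.0, 2.0, 1.0]
```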
one_step
one_step(
new_chain_state, current_reducer_state, previous_kernel_results
)
Updates `current_reducer_state` with a new trace result.
The trace result is computed by evaluating the `trace_fn` provided at instantiation with `new_chain_state` and `previous_kernel_results`.
Args | |
---|---|
`new_chain_state` | A (possibly nested) structure of incoming chain state(s) with shape and dtype compatible with those used to initialize the `TracingState`.
`current_reducer_state` | `TracingState` representing all previously traced results.
`previous_kernel_results` | A (possibly nested) structure of `Tensor`s representing internal calculations made in a related `TransitionKernel`.
Returns | |
---|---|
`new_reducer_state` | `TracingState` with updated trace. Its `trace_state` field holds a `TensorArray` that includes the newly computed trace result.