
tfp.experimental.mcmc.TracingReducer

Reducer that accumulates trace results at each sample.

Inherits From: Reducer

Trace results are defined by an appropriate trace_fn, which accepts the current chain state and kernel results and returns the desired result. At each sample, the traced values are added to a TensorArray that accumulates all results. By default, all kernel results are traced. In the future, the default will change so that no results are traced, so plan accordingly.
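The accumulation pattern described above can be modeled in plain Python with NumPy. This is a toy sketch, not TFP code: `toy_kernel_step`, `sample_and_trace`, and `state_only_trace_fn` are hypothetical names, a Python list stands in for the TensorArray, and `np.stack` stands in for stacking it at the end. It contrasts the current default (trace all kernel results) with a custom trace_fn that keeps only the chain state.

```python
import numpy as np


def default_trace_fn(chain_state, kernel_results):
    # Like the documented default: trace every kernel result.
    return kernel_results


def toy_kernel_step(state):
    # Hypothetical deterministic "kernel" for illustration only;
    # it is not an MCMC transition.
    new_state = state + 1.0
    kernel_results = np.array([2.0 * new_state])  # made-up diagnostic value
    return new_state, kernel_results


def sample_and_trace(initial_state, num_steps, trace_fn=default_trace_fn):
    state = initial_state
    traced = []  # stand-in for the accumulating TensorArray
    for _ in range(num_steps):
        state, kernel_results = toy_kernel_step(state)
        # The initial state is never passed to trace_fn; only new samples are.
        traced.append(trace_fn(state, kernel_results))
    # Stack per-sample results along a new leading axis.
    return np.stack(traced)


def state_only_trace_fn(chain_state, kernel_results):
    # Custom trace_fn that keeps only the chain state.
    return chain_state


all_results = sample_and_trace(0.0, num_steps=3)
states = sample_and_trace(0.0, num_steps=3, trace_fn=state_only_trace_fn)
print(all_results.shape)  # (3, 1)
print(states)             # [1. 2. 3.]
```

Passing an explicit trace_fn, as in the second call, keeps code stable if the default changes.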

If TracingReducer is wrapped in a tfp.experimental.mcmc.WithReductions TransitionKernel, it does not accumulate the kernel results of WithReductions itself. Instead, the top-level kernel results it receives are those of WithReductions' inner kernel.
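The toy sketch below illustrates which kernel results reach trace_fn when a wrapper kernel sits on top of an inner kernel. It does not use TFP: `ToyWrapperResults`, `toy_inner_one_step`, and `toy_wrapper_one_step` are hypothetical, and their field names are only loosely modeled on a WithReductions-style kernel.

```python
from collections import namedtuple

# Hypothetical wrapper results: the wrapper's own record holds the inner
# kernel's results plus the reducer's accumulated trace (names illustrative).
ToyWrapperResults = namedtuple(
    "ToyWrapperResults", ["inner_results", "reduction_results"])


def toy_inner_one_step(state):
    # Toy inner kernel: a deterministic update plus its own results record.
    return state + 1, {"accepted": True, "step": state + 1}


def toy_wrapper_one_step(state, previous_results, trace_fn):
    new_state, inner_results = toy_inner_one_step(state)
    # Only inner_results reaches trace_fn; the wrapper's ToyWrapperResults
    # record is never handed to the reducer.
    traced = previous_results.reduction_results + [trace_fn(new_state, inner_results)]
    return new_state, ToyWrapperResults(inner_results, traced)


def trace_kernel_results(chain_state, kernel_results):
    return kernel_results


state, results = 0, ToyWrapperResults(inner_results=None, reduction_results=[])
for _ in range(2):
    state, results = toy_wrapper_one_step(state, results, trace_kernel_results)
print(results.reduction_results)
# [{'accepted': True, 'step': 1}, {'accepted': True, 'step': 2}]
```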

As with all reducers, TracingReducer does not hold state information; rather, it stores supplied metadata. Intermediate calculations are held in a TracingState namedtuple, which is returned by the initialize and one_step methods.
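A minimal sketch of this stateless design, assuming a simplified reducer written in plain Python rather than TFP's class: `ToyTracingReducer` and `ToyTracingState` are hypothetical, and a tuple replaces the TensorArray. Because every call returns a new state object, one reducer instance can drive two independent chains.

```python
from collections import namedtuple

# Hypothetical stand-in for TracingState; the field name follows the doc's
# trace_state, but an immutable tuple replaces the TensorArray.
ToyTracingState = namedtuple("ToyTracingState", ["trace_state"])


class ToyTracingReducer:
    """Stores only metadata (trace_fn); all intermediate results live in the state."""

    def __init__(self, trace_fn):
        self.trace_fn = trace_fn

    def initialize(self, initial_chain_state, initial_kernel_results=None):
        # The initial state is not a sample, so nothing is traced yet.
        return ToyTracingState(trace_state=())

    def one_step(self, new_chain_state, current_reducer_state,
                 previous_kernel_results=None):
        new_result = self.trace_fn(new_chain_state, previous_kernel_results)
        # Return a fresh state instead of mutating the reducer.
        return ToyTracingState(current_reducer_state.trace_state + (new_result,))

    def finalize(self, final_reducer_state):
        return list(final_reducer_state.trace_state)


def trace_state_only(chain_state, kernel_results):
    return chain_state


reducer = ToyTracingReducer(trace_fn=trace_state_only)
chain_a = reducer.initialize(0.0)
chain_b = reducer.initialize(10.0)
for x in (1.0, 2.0):
    chain_a = reducer.one_step(x, chain_a)
chain_b = reducer.one_step(11.0, chain_b)
print(reducer.finalize(chain_a), reducer.finalize(chain_b))  # [1.0, 2.0] [11.0]
```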

Args
trace_fn A callable that takes in the current chain state and the previous kernel results and returns a Tensor or a nested collection of Tensors that is accumulated across samples.
size Integer or scalar Tensor denoting the size of the accumulated TensorArray. If this is None (the default), a dynamic-shaped TensorArray is used.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'tracing_reducer').

Attributes
name
parameters
size
trace_fn

Methods

finalize


Finalizes tracing by stacking the accumulated TensorArray.

Args
final_reducer_state TracingState that holds all desired traced results.

Returns
trace Tensor that represents stacked tracing results.
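When trace_fn returns a nested collection, finalizing amounts to stacking each leaf along a new leading (sample) axis. The helper below, `stack_traces`, is a hypothetical plain-NumPy sketch that handles dicts, plain tuples, and array-like leaves; it is not how TFP stacks its TensorArrays internally.

```python
import numpy as np


def stack_traces(per_step_traces):
    """Stack a list of identically structured traces along a new leading axis.

    Handles dicts, plain tuples, and array-like leaves; a simplified stand-in
    for stacking a nested structure of TensorArrays.
    """
    first = per_step_traces[0]
    if isinstance(first, dict):
        return {k: stack_traces([t[k] for t in per_step_traces]) for k in first}
    if isinstance(first, tuple):
        return tuple(stack_traces([t[i] for t in per_step_traces])
                     for i in range(len(first)))
    return np.stack(per_step_traces)


steps = [
    {"state": np.array([0.5, 1.5]), "accepted": True},
    {"state": np.array([0.7, 1.2]), "accepted": False},
    {"state": np.array([0.9, 1.1]), "accepted": True},
]
trace = stack_traces(steps)
print(trace["state"].shape)  # (3, 2)
print(trace["accepted"])     # [ True False  True]
```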

initialize


Initializes a TracingState using previously defined metadata.

Neither initial_chain_state nor initial_kernel_results counts as a sample, so neither is traced. This is a deliberate decision that keeps behavior consistent across sampling procedures.

Args
initial_chain_state A (possibly nested) structure of Tensors or Python lists of Tensors representing the current state(s) of the Markov chain(s). It is used to infer the structure of future trace results.
initial_kernel_results A (possibly nested) structure of Tensors representing internal calculations made in a related TransitionKernel. It is used to infer the structure of future trace results.

Returns
state TracingState with an empty TensorArray in its trace_state field.
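One way to picture initialize is sketched below, assuming trace_fn returns a flat dict: it evaluates trace_fn once on the initial state purely to learn the keys, shapes, and dtypes of future results, then returns empty buffers, so the initial state itself is never recorded. `toy_initialize` and `trace_state_and_log_prob` are hypothetical NumPy stand-ins, not TFP's implementation.

```python
import numpy as np


def toy_initialize(trace_fn, initial_chain_state, initial_kernel_results):
    # Evaluate trace_fn once only to learn the structure (keys, shapes, dtypes)
    # of future trace results; the value itself is discarded, not traced.
    template = trace_fn(initial_chain_state, initial_kernel_results)
    return {
        key: np.empty((0,) + np.shape(leaf), dtype=np.asarray(leaf).dtype)
        for key, leaf in template.items()
    }


def trace_state_and_log_prob(chain_state, kernel_results):
    return {"state": chain_state, "log_prob": kernel_results["log_prob"]}


empty = toy_initialize(trace_state_and_log_prob, np.zeros(3), {"log_prob": -1.5})
print(empty["state"].shape, empty["log_prob"].shape)  # (0, 3) (0,)
```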

one_step


Updates the current_reducer_state with a new trace result.

The trace result will be computed by evaluating the trace_fn provided at instantiation with the new_chain_state and previous_kernel_results.

Args
new_chain_state A (possibly nested) structure of incoming chain state(s) with shape and dtype compatible with those used to initialize the TracingState.
current_reducer_state TracingState representing all previously traced results.
previous_kernel_results A (possibly nested) structure of Tensors representing internal calculations made in a related TransitionKernel.

Returns
new_reducer_state TracingState with updated trace. Its trace_state field holds a TensorArray that includes the newly computed trace result.
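A plain-NumPy sketch of the update step, assuming a flat dict of arrays in place of the TensorArray-backed trace_state: `toy_one_step` and `trace_state_and_accepted` are hypothetical. The sketch evaluates trace_fn on the new chain state and kernel results, then appends the result along the leading sample axis, returning new arrays instead of mutating the old ones.

```python
import numpy as np


def toy_one_step(trace_fn, new_chain_state, current_trace, previous_kernel_results):
    # Compute this sample's trace result from the new state and kernel results.
    result = trace_fn(new_chain_state, previous_kernel_results)
    # Append along the leading (sample) axis; builds new arrays rather than
    # mutating current_trace.
    return {
        key: np.concatenate([current_trace[key], np.asarray(result[key])[np.newaxis]])
        for key in current_trace
    }


def trace_state_and_accepted(chain_state, kernel_results):
    return {"state": chain_state, "accepted": kernel_results["accepted"]}


trace = {"state": np.empty((0, 2)), "accepted": np.empty((0,), dtype=bool)}
trace = toy_one_step(trace_state_and_accepted, np.array([0.1, 0.2]), trace,
                     {"accepted": True})
trace = toy_one_step(trace_state_and_accepted, np.array([0.3, 0.4]), trace,
                     {"accepted": False})
print(trace["state"].shape)  # (2, 2)
print(trace["accepted"])     # [ True False]
```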