Runs a Markov chain defined by the given TransitionKernel.
tfp.experimental.mcmc.sample_chain(
    kernel,
    num_results,
    current_state,
    previous_kernel_results=None,
    reducer=(),
    previous_reducer_state=None,
    trace_fn=_trace_everything,
    parallel_iterations=10,
    seed=None,
    name=None
)
This is meant as a (more) helpful frontend to the low-level TransitionKernel-based MCMC API, supporting several main features:
- Running a batch of multiple independent chains using SIMD parallelism
- Tracing the history of the chains, or not tracing it to save memory
- Computing reductions over chain history, whether it is also traced or not
- Warm (re-)start, including auxiliary state
This function samples from a Markov chain, starting at current_state, whose stationary distribution is governed by the supplied TransitionKernel instance (kernel).
The current_state can be represented as a single Tensor or a list of Tensors that collectively represent the current state.
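To make the kernel/state relationship concrete, here is a minimal pure-Python sketch, not the TFP API. A hypothetical metropolis_step function plays the role of one TransitionKernel step, targeting independent standard normals (an assumed toy target). It accepts either a single float or a list of floats as the state and returns the same structure, mirroring "a single Tensor or a list of Tensors".

```python
import math
import random


def target_log_prob(parts):
    # Hypothetical target: independent standard normals, one per state part.
    return sum(-0.5 * x * x for x in parts)


def metropolis_step(current_state, rng, scale=1.0):
    """One random-walk Metropolis step, a toy stand-in for a TransitionKernel.

    current_state may be a single float or a list of floats; the returned
    state has the same structure as the input.
    """
    is_list = isinstance(current_state, list)
    parts = list(current_state) if is_list else [current_state]
    proposal = [x + rng.gauss(0.0, scale) for x in parts]
    log_accept_ratio = target_log_prob(proposal) - target_log_prob(parts)
    # Accept with probability min(1, exp(log_accept_ratio)); this keeps the
    # target as the chain's stationary distribution.
    if rng.random() < math.exp(min(0.0, log_accept_ratio)):
        parts = proposal
    return parts if is_list else parts[0]


rng = random.Random(0)
x = 0.0  # single-part state
draws = []
for _ in range(5000):
    x = metropolis_step(x, rng)
    draws.append(x)
pair = metropolis_step([0.0, 1.0], rng)  # two-part state
```

The sample mean of draws should sit near 0, the mean of the assumed target.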
This function can sample from multiple chains in parallel. Whether or not there are multiple chains is dictated by how the kernel treats its inputs. Typically, the shape of the independent chains is the shape of the result of the target_log_prob_fn used by the kernel when applied to the given current_state.
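As a rough illustration of "one chain per log-prob result", the following pure-Python sketch (hypothetical names, not TFP code) uses a list as the batch dimension. A hypothetical batched_target_log_prob returns one log-density per entry, so the number of independent chains equals the length of that result, and each chain accepts or rejects its proposal independently. In TFP the batch would come from Tensor shapes and run with SIMD parallelism rather than a Python loop.

```python
import math
import random


def batched_target_log_prob(positions):
    # Hypothetical batched target (standard normal per chain): one
    # log-density per chain, so len(result) is the number of chains.
    return [-0.5 * x * x for x in positions]


def batched_metropolis_step(positions, rng, scale=1.0):
    proposals = [x + rng.gauss(0.0, scale) for x in positions]
    lp_old = batched_target_log_prob(positions)
    lp_new = batched_target_log_prob(proposals)
    new_positions = []
    for x, y, a, b in zip(positions, proposals, lp_old, lp_new):
        # Each chain accepts or rejects independently of the others.
        accept = rng.random() < math.exp(min(0.0, b - a))
        new_positions.append(y if accept else x)
    return new_positions


rng = random.Random(1)
chains = [-3.0, 0.0, 3.0, 10.0]  # four independent chains
num_chains = len(batched_target_log_prob(chains))
for _ in range(100):
    chains = batched_metropolis_step(chains, rng)
```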
This function can compute reductions over the samples in tandem with sampling, for example to return summary statistics without materializing all the samples. To request reductions, pass a Reducer object, or a nested structure of Reducer objects, as the reducer= argument.
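The idea of a streaming reduction can be sketched in plain Python. The StreamingMean class and its initialize / one_step / finalize method names below are hypothetical, chosen for illustration rather than copied from TFP's Reducer interface. The reducer keeps only a small running state (a count and a mean), so memory use does not grow with the number of draws.

```python
import random


class StreamingMean:
    """Hypothetical reducer: running mean without storing the samples."""

    def initialize(self):
        # Reducer state: (count, running mean).
        return (0, 0.0)

    def one_step(self, sample, state):
        count, mean = state
        count += 1
        # Incremental (Welford-style) mean update.
        mean += (sample - mean) / count
        return (count, mean)

    def finalize(self, state):
        return state[1]


reducer = StreamingMean()
state = reducer.initialize()
rng = random.Random(2)
for _ in range(1000):
    draw = rng.gauss(0.0, 1.0)  # stand-in for one chain draw
    state = reducer.one_step(draw, state)
result = reducer.finalize(state)
```

Because each draw is consumed and discarded, result matches the ordinary mean of the same 1000 draws up to floating-point rounding.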
In addition to the chain state, this function supports tracing of auxiliary variables used by the kernel, as well as intermediate values of any supplied reductions. The traced values are selected by specifying trace_fn. The trace_fn must be a callable accepting three arguments: the chain state, the kernel_results of the kernel, and the current results of the reductions, if any are supplied. The return value of trace_fn (which may be a Tensor or a nested structure of Tensors) is accumulated, such that each Tensor gains a new outermost dimension representing time in the chain history.
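The tracing mechanism can be sketched in pure Python as follows; all names are hypothetical and the "kernel results" are reduced to a single acceptance flag. At every step, trace_fn receives the state, the kernel results, and a running reducer value, and each field it returns is appended to a per-field list whose index plays the role of the new outermost time dimension.

```python
import math
import random


def step(x, rng):
    # Toy random-walk Metropolis step targeting a standard normal. Returns the
    # new state plus auxiliary "kernel results" (here, an acceptance flag).
    y = x + rng.gauss(0.0, 1.0)
    accepted = rng.random() < math.exp(min(0.0, 0.5 * (x * x - y * y)))
    return (y if accepted else x), {"is_accepted": accepted}


def trace_fn(state, kernel_results, reducer_state):
    # Select what to keep at every step.
    return {
        "x": state,
        "is_accepted": kernel_results["is_accepted"],
        "num_accepted": reducer_state,
    }


def run_chain(num_results, x, rng):
    trace = None
    num_accepted = 0  # running result of a simple "reducer"
    for _ in range(num_results):
        x, kernel_results = step(x, rng)
        num_accepted += int(kernel_results["is_accepted"])
        traced = trace_fn(x, kernel_results, num_accepted)
        if trace is None:
            trace = {key: [] for key in traced}
        for key, value in traced.items():
            # Each traced field grows along a leading "time" axis.
            trace[key].append(value)
    return x, trace


final_x, trace = run_chain(50, 0.0, random.Random(3))
```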
Since MCMC states are correlated, it is sometimes desirable to produce additional intermediate states and then discard them, ending up with a set of states with decreased autocorrelation. See [Owen (2017)][1]. Such 'thinning' is made possible by setting num_steps_between_results > 0 (an argument of tfp.mcmc.sample_chain; it does not appear in the signature above). The chain then takes num_steps_between_results extra steps between the steps that make it into the results or are shown to any supplied reductions. The extra steps are never materialized, and thus do not increase memory requirements.
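Thinning can be sketched with a hypothetical thinned_chain helper in plain Python. Following the "between" wording above, it assumes no extra steps before the first kept draw; TFP's exact step accounting may differ. To make the kept steps visible, the usage example substitutes a deterministic stand-in kernel whose state simply counts the steps taken.

```python
def thinned_chain(step, x, num_results, num_steps_between_results, rng):
    """Return num_results kept draws, skipping extra steps between them."""
    results = []
    for i in range(num_results):
        if i > 0:
            # Extra steps between kept draws: taken, but never stored.
            for _ in range(num_steps_between_results):
                x = step(x, rng)
        x = step(x, rng)
        results.append(x)
    return results


def count_steps(x, rng):
    # Deterministic stand-in kernel: the state is the number of steps taken.
    return x + 1


kept = thinned_chain(count_steps, 0, num_results=5,
                     num_steps_between_results=2, rng=None)
# kept == [1, 4, 7, 10, 13]
```

Only the kept draws are stored, so memory grows with num_results, not with the total number of kernel steps.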
Args | |
---|---|
kernel | An instance of tfp.mcmc.TransitionKernel that implements one step of the Markov chain. |
num_results | Integer number of (non-discarded) Markov chain draws to compute. |
current_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |
previous_kernel_results | A Tensor or a nested collection of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
reducer | A (possibly nested) structure of Reducers to be evaluated on the kernel's samples. If no reducers are given (reducer=None), their states will not be passed to any supplied trace_fn. |
previous_reducer_state | A (possibly nested) structure of running states corresponding to the structure in reducer. Used to resume streaming reduction computations begun in a previous run. |
trace_fn | A callable that takes in the current chain state, the current auxiliary kernel state, and the current result of any reducers, and returns a Tensor or a nested collection of Tensors that is then traced. If None, nothing is traced. |
parallel_iterations | The number of iterations allowed to run in parallel. It must be a positive integer. See tf.while_loop for more details. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mcmc_sample_chain'). |
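Warm restart (current_state plus previous_reducer_state) can be illustrated with a pure-Python sketch; run_chain and its arguments are hypothetical stand-ins, not TFP's. Here a single random.Random object carries the randomness across both calls, which is an assumption of the sketch (TFP handles seeds and previous_kernel_results separately). Resuming from the final state and reducer state of a first run reproduces a single uninterrupted run.

```python
import math
import random


def step(x, rng):
    # Toy random-walk Metropolis step targeting a standard normal.
    y = x + rng.gauss(0.0, 1.0)
    if rng.random() < math.exp(min(0.0, 0.5 * (x * x - y * y))):
        return y
    return x


def run_chain(num_results, current_state, rng, previous_reducer_state=None):
    # Reducer state is (count, running mean); start fresh unless resuming.
    if previous_reducer_state is None:
        count, mean = 0, 0.0
    else:
        count, mean = previous_reducer_state
    x = current_state
    for _ in range(num_results):
        x = step(x, rng)
        count += 1
        mean += (x - mean) / count
    return x, (count, mean)


# One uninterrupted run of 1000 steps.
x_full, state_full = run_chain(1000, 0.0, random.Random(4))

# The same 1000 steps split into two runs, resuming from the first.
rng = random.Random(4)
x_a, state_a = run_chain(500, 0.0, rng)
x_b, state_b = run_chain(500, x_a, rng, previous_reducer_state=state_a)
```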