tfp.substrates.jax.mcmc.default_swap_proposal_fn

Make the default swap proposal function, which proposes swaps with probability P[swap] = prob_swap, for replica exchange Monte Carlo.

With probability prob_swap, propose a combination of replicas to swap; otherwise, propose no swaps. Each proposed swap exchanges a pair of adjacent replicas in Replica Exchange Monte Carlo. See also the review paper [1].
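The behavior described above can be illustrated with a minimal NumPy sketch. It assumes one particular scheme, which is not stated on this page: draw a parity (even or odd) uniformly, then with probability prob_swap swap every adjacent pair (i, i + 1) whose first index has that parity. Because those pairs are disjoint, every result is a "one-time swap". The function name adjacent_swap_sketch and its rng argument are hypothetical and are not part of the TFP API.

```python
import numpy as np


def adjacent_swap_sketch(num_replica, prob_swap, batch_shape=(), rng=None):
    # Hypothetical helper (not TFP): returns an int array of shape
    # [num_replica] + batch_shape whose axis 0 holds a one-time swap.
    rng = np.random.default_rng() if rng is None else rng
    batch_shape = tuple(batch_shape)
    # Every batch member starts at the identity permutation range(num_replica).
    identity = np.arange(num_replica).reshape(
        (num_replica,) + (1,) * len(batch_shape))
    swaps = np.broadcast_to(identity, (num_replica,) + batch_shape).copy()
    # Assumed scheme: parity 0 swaps pairs (0, 1), (2, 3), ...;
    # parity 1 swaps pairs (1, 2), (3, 4), ...
    parity = rng.integers(0, 2, size=batch_shape)
    # With probability prob_swap, a batch member gets any swaps at all.
    do_swap = rng.uniform(size=batch_shape) < prob_swap
    for i in range(num_replica - 1):
        # Pair (i, i + 1) swaps where swapping is on and i has the drawn parity.
        active = do_swap & (parity == i % 2)
        swaps[i] = np.where(active, i + 1, swaps[i])
        swaps[i + 1] = np.where(active, i, swaps[i + 1])
    return swaps
```

With num_replica=3 this sketch can only return [0, 1, 2], [1, 0, 2] (even parity) or [0, 2, 1] (odd parity), the same kinds of proposals as in this page's examples. Under this assumed scheme, num_replica=2 with an odd parity has no pair to swap, so the realized swap rate drops below prob_swap; whether TFP behaves the same way is not stated here.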

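```python
import jax
from tensorflow_probability.substrates import jax as tfp

swap_fn = tfp.mcmc.default_swap_proposal_fn(prob_swap=0.5)

# Proposals are random, so the outputs shown are illustrative only.
# TFP-on-JAX samplers need an explicit PRNG key passed as `seed`.
swap_fn(num_replica=3, seed=jax.random.PRNGKey(0))
# ==> [1, 0, 2]  # 1 swap, 0 <--> 1

swap_fn(num_replica=3, seed=jax.random.PRNGKey(1))
# ==> [0, 1, 2]  # 0 swaps

swap_fn(num_replica=3, batch_shape=[2], seed=jax.random.PRNGKey(2))
# ==> [[0, 1],
#      [2, 0],
#      [1, 2]]  # column 0: 1 <--> 2; column 1: 0 <--> 1
```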
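Read as a permutation, a proposal can be applied to per-replica states by gathering along axis 0: slot i receives the state currently held by replica swaps[i]. The sketch below assumes that interpretation (it matches the gather-based definition in the Returns description), reuses the proposal values from the examples, and uses hypothetical state arrays. It ignores the acceptance step that a full replica exchange kernel would apply.

```python
import numpy as np

# Hypothetical per-replica states: 3 replicas, each holding a 2-vector.
states = np.array([[0.0, 0.1],
                   [1.0, 1.1],
                   [2.0, 2.2]])

# Unbatched proposal from the example: swap replicas 0 and 1.
swaps = np.array([1, 0, 2])
proposed = states[swaps]  # gather along axis 0

# Batched proposal from the example, shape [num_replica] + batch_shape = [3, 2].
batch_swaps = np.array([[0, 1],
                        [2, 0],
                        [1, 2]])
# Hypothetical batched scalar states, shape [3, 2]; column b is batch member b.
batch_states = np.array([[10.0, 20.0],
                         [11.0, 21.0],
                         [12.0, 22.0]])
# Gather along axis 0 independently for each batch member:
# column 0 swaps replicas 1 <--> 2, column 1 swaps replicas 0 <--> 1.
batch_proposed = np.take_along_axis(batch_states, batch_swaps, axis=0)
```

Here proposed is [[1.0, 1.1], [0.0, 0.1], [2.0, 2.2]] and batch_proposed is [[10.0, 21.0], [12.0, 20.0], [11.0, 22.0]].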

Args

prob_swap: Scalar Tensor in [0, 1] giving the probability that any swaps will be generated.
name: Python str name given to ops created by this function. Default value: 'adjacent_swaps'.

Returns

default_swap_proposal_fn_: Python callable which takes a number of replicas (a Python integer), an integer Tensor batch_shape, an unused step_count, and a seed, and returns swaps, a Tensor of shape [num_replica] + batch_shape, where axis 0 indexes "one-time swaps", i.e., such that (if rank(swaps) == 1) range(num_replica) == tf.gather(swaps, swaps). In other words, each proposed permutation is its own inverse.
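The "one-time swap" condition can be checked directly, including for batched proposals, by gathering swaps with itself along axis 0 for each batch member. This is a minimal NumPy sketch; is_one_time_swap is a hypothetical helper, not part of TFP, and it assumes every entry is a valid index in [0, num_replica).

```python
import numpy as np


def is_one_time_swap(swaps):
    # Hypothetical helper: True if, for every batch member, gathering swaps
    # by itself along axis 0 gives range(num_replica).
    swaps = np.asarray(swaps)
    num_replica = swaps.shape[0]
    # Identity permutation, shaped to broadcast over any batch dimensions.
    identity = np.arange(num_replica).reshape(
        (num_replica,) + (1,) * (swaps.ndim - 1))
    # composed[i, ...] = swaps[swaps[i, ...], ...], per batch member.
    composed = np.take_along_axis(swaps, swaps, axis=0)
    return bool(np.all(composed == identity))
```

All three example outputs on this page pass this check, while a 3-cycle such as [1, 2, 0] fails it. The check covers only the involution property stated above; it does not verify that swapped replicas are adjacent.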

References

[1]: David J. Earl and Michael W. Deem. Parallel Tempering: Theory, Applications, and New Perspectives. https://arxiv.org/abs/physics/0508111