tfp.substrates.jax.mcmc.even_odd_swap_proposal_fn

Make a deterministic swap proposal function, alternating even/odd swaps.

This proposal function swaps deterministically, a swap_frequency fraction of the time, alternating between even and odd parity. This scheme was shown in [1] to mix better than random swap schemes.
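The even and odd pairings can be illustrated with a short NumPy sketch. This is not TFP's implementation: the helper name even_odd_permutation is hypothetical, and it only builds the permutation for a given parity, leaving out swap_frequency and batching.

```python
import numpy as np


def even_odd_permutation(num_replica, parity):
    """Hypothetical helper: permutation exchanging adjacent replica pairs.

    parity=0 pairs (0, 1), (2, 3), ...; parity=1 pairs (1, 2), (3, 4), ...
    Entry i is the index of the replica that slot i exchanges with.
    """
    swaps = np.arange(num_replica)
    start = parity % 2
    # Stop before an unpaired trailing replica so it maps to itself.
    end = start + 2 * ((num_replica - start) // 2)
    swaps[start:end:2] += 1      # left member of each pair points right
    swaps[start + 1:end:2] -= 1  # right member of each pair points left
    return swaps


print(even_odd_permutation(4, parity=0))  # [1 0 3 2]
print(even_odd_permutation(4, parity=1))  # [0 2 1 3]
print(even_odd_permutation(5, parity=1))  # [0 2 1 4 3]
```

With an odd number of replicas, one end replica has no partner at each parity and stays in place for that step.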

Contrast this with default_swap_proposal_fn, which swaps randomly with probability prob_swap.

swap_fn = even_odd_swap_proposal_fn(swap_frequency=1)

swap_fn(num_replica=4, step_count=0)
==> [1, 0, 3, 2]  # Swap 0 <--> 1 and 2 <--> 3, even parity.

swap_fn(num_replica=4, step_count=1)
==> [0, 2, 1, 3]  # Swap 1 <--> 2, odd parity.

Args

swap_frequency: Scalar Tensor in [0, 1] giving the frequency of swaps. Swaps occur, with alternating parity, every N steps, where N = 1 / swap_frequency.
name: Python str name given to ops created by this function. Default value: 'even_odd_swaps'.
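One reading of the swap_frequency rule is that a swap fires whenever step_count is a multiple of N and that the parity flips on each successive swap event. The sketch below encodes that reading. It is an assumption about the schedule, consistent with the swap_frequency=1 example above but not checked against TFP's source; swap_schedule is a hypothetical name.

```python
def swap_schedule(swap_frequency, num_steps):
    """Hypothetical schedule: (step, parity) pairs; parity None means no swap.

    Assumes N = round(1 / swap_frequency) and that the k-th swap event
    (k = step // N) uses parity k % 2 (0 = even pairs, 1 = odd pairs).
    """
    if swap_frequency <= 0:
        # A frequency of 0 means swaps never occur.
        return [(step, None) for step in range(num_steps)]
    period = max(1, round(1.0 / swap_frequency))
    schedule = []
    for step in range(num_steps):
        if step % period == 0:
            schedule.append((step, (step // period) % 2))
        else:
            schedule.append((step, None))
    return schedule


labels = {None: 'no swap', 0: 'even', 1: 'odd'}
for step, parity in swap_schedule(swap_frequency=0.5, num_steps=6):
    print(step, labels[parity])
# 0 even, 1 no swap, 2 odd, 3 no swap, 4 even, 5 no swap
```

Under this reading, alternating on swap events rather than on raw step parity matters: with N = 2, every swap step is even-numbered, so keying the parity on step_count % 2 would never produce an odd swap.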

Returns

default_swap_proposal_fn_: Python callable that takes a number of replicas num_replica (a Python integer), an integer Tensor batch_shape, a step_count, and a seed, and returns swaps, a shape [num_replica] + batch_shape Tensor. Axis 0 indexes "one-time swaps": each replica is exchanged at most once, so that (if rank(swaps) == 1) range(num_replica) == tf.gather(swaps, swaps).
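The one-time swap condition makes swaps its own inverse: gathering it with itself returns the identity. The NumPy check below tests that property and the [num_replica] + batch_shape layout. is_one_time_swap is a hypothetical helper, and np.take_along_axis along axis 0 stands in for tf.gather, extending the rank-1 condition to each batch member independently (that extension is an assumption).

```python
import numpy as np


def is_one_time_swap(swaps):
    """Hypothetical check that each batch member's swaps is an involution.

    `swaps` has shape [num_replica] + batch_shape. For every batch index b,
    swaps[swaps[i, b], b] must equal i, the batched analogue of
    tf.gather(swaps, swaps) == range(num_replica) in the rank-1 case.
    """
    swaps = np.asarray(swaps)
    num_replica = swaps.shape[0]
    # Identity along axis 0, broadcastable against the batch dimensions.
    identity = np.arange(num_replica).reshape(
        (num_replica,) + (1,) * (swaps.ndim - 1))
    round_trip = np.take_along_axis(swaps, swaps, axis=0)
    return bool(np.all(round_trip == identity))


# Rank 1: pairs are swapped once, so applying swaps twice is a no-op.
print(is_one_time_swap([1, 0, 3, 2]))  # True
# A 3-cycle moves replicas more than once and fails the check.
print(is_one_time_swap([1, 2, 0]))  # False
# batch_shape = [2]: each column is an independent swap proposal.
batched = np.stack([[1, 0, 3, 2], [0, 2, 1, 3]], axis=-1)
print(batched.shape, is_one_time_swap(batched))  # (4, 2) True
```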

References

[1]: S. Syed, A. Bouchard-Côté, G. Deligiannidis, A. Doucet. Non-Reversible Parallel Tempering: a Scalable Highly Parallel MCMC Scheme. https://arxiv.org/abs/1905.02939