Returns a callable that adds a random uniform perturbation to the input.
tfp.substrates.jax.mcmc.random_walk_uniform_fn(
scale=1.0, name=None
)
For more details on random_walk_uniform_fn, see random_walk_normal_fn. scale may be a Tensor or a list of Tensors that should broadcast with the state parts of the current_state. The perturbation is sampled uniformly from the rectangle [-scale, scale] and added to the corresponding state part, so each proposed value lies within scale of its current value.
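The sampling rule described above can be sketched in NumPy. This is a hypothetical re-implementation for illustration only, not TFP's code: make_uniform_proposal and proposal_fn are made-up names. The sketch assumes that a single scale is reused for every state part, that a list or tuple of scales supplies exactly one entry per part, and that the seed is a plain int.

```python
import numpy as np


def make_uniform_proposal(scale=1.0):
    """Hypothetical NumPy analogue of random_walk_uniform_fn (not TFP's code)."""

    def proposal_fn(state_parts, seed):
        # Assumption: a single scale is reused for every state part, while a
        # list or tuple of scales must supply exactly one entry per part.
        scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
        if len(scales) == 1:
            scales = scales * len(state_parts)
        if len(scales) != len(state_parts):
            raise ValueError("`scale` must broadcast with `state_parts`.")
        rng = np.random.default_rng(seed)  # the int seed makes the draw reproducible
        proposed = []
        for part, s in zip(state_parts, scales):
            part = np.asarray(part, dtype=float)
            s = np.asarray(s, dtype=float)
            # Draw a perturbation uniformly from [-s, s] in every coordinate
            # (s broadcasts against the part's shape) and add it to the state.
            noise = rng.uniform(-s, s, size=part.shape)
            proposed.append(part + noise)
        return proposed

    return proposal_fn


# Usage: one scalar scale for the first part, a per-coordinate scale for the second.
propose = make_uniform_proposal(scale=[0.5, np.array([1.0, 2.0])])
new_state = propose([np.zeros(3), np.array([10.0, -10.0])], seed=42)
```

Because each part gets its own scale, a scale Tensor can widen or narrow the proposal independently along each coordinate of a state part.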
Args
scale | A Tensor or Python list of Tensors of any shapes and dtypes controlling the upper and lower bounds of the uniform proposal distribution.
name | Python str name prefixed to Ops created by this function. Default value: 'random_walk_uniform_fn'.
Returns
random_walk_uniform_fn | A callable accepting a Python list of Tensors representing the state parts of the current_state and an int representing the random seed used to generate the proposal. The callable returns a list of Tensors of the same type as the input, representing the proposal for the random walk Metropolis (RWM) algorithm.
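To show how such a callable serves as the proposal in RWM, here is a minimal NumPy sketch of one random walk Metropolis step. uniform_proposal, rwm_step, and log_prob are hypothetical names introduced for illustration, not TFP's API. The sketch assumes the target log density takes the state parts as positional arguments and that seeds are plain ints, matching the description of the returned callable.

```python
import numpy as np


def uniform_proposal(state_parts, seed, scale=1.0):
    # Hypothetical stand-in for the callable returned by
    # random_walk_uniform_fn: a list of arrays in, a same-length list out.
    rng = np.random.default_rng(seed)
    return [p + rng.uniform(-scale, scale, size=np.shape(p)) for p in state_parts]


def rwm_step(target_log_prob, state_parts, seed, proposal_fn=uniform_proposal):
    """One random walk Metropolis step (hypothetical helper, not TFP's API)."""
    rng = np.random.default_rng(seed)
    # Derive an int seed for the proposal from this step's seed.
    proposed = proposal_fn(state_parts, int(rng.integers(2**31)))
    # The uniform perturbation is symmetric, so the Hastings correction
    # cancels and the log acceptance ratio is the difference in target
    # log densities between the proposed and current states.
    log_accept_ratio = target_log_prob(*proposed) - target_log_prob(*state_parts)
    if np.log(rng.uniform()) < log_accept_ratio:
        return proposed, True
    return state_parts, False


# Usage: a short chain targeting a 2-D standard normal.
def log_prob(x):
    return -0.5 * np.sum(x ** 2)


state = [np.zeros(2)]
samples = []
for step in range(2000):
    state, _ = rwm_step(log_prob, state, seed=step)
    samples.append(state[0])
samples = np.stack(samples)
```

In TFP itself, the callable returned by random_walk_uniform_fn is typically passed as the new_state_fn argument of RandomWalkMetropolis, which then takes care of seeding and the accept/reject step.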