Returns a callable that adds a random normal perturbation to the input.
tfp.substrates.jax.mcmc.random_walk_normal_fn(
scale=1.0, name=None
)
This function returns a callable that accepts two arguments: a Python list of Tensors of any shapes and dtypes, representing the state parts of current_state, and a random seed. The supplied argument scale must be a Tensor or a Python list of Tensors representing the scale of the generated proposal. The scale argument must broadcast with the state parts of current_state.
The callable adds a sample from a zero-mean normal distribution with the supplied scales to each state part, and returns a list of Tensors with the same shapes and dtypes as the state parts of current_state.
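The proposal described above can be sketched in pure JAX. This is a minimal illustration of the documented behavior, not TFP's implementation: random_walk_normal_sketch is a hypothetical name, and the sketch assumes the seed is a jax.random key that is split into one independent key per state part. Each new part is the old part plus scale times standard normal noise, which is the same as sampling from a normal centered at the current state with standard deviation scale.

```python
import jax
import jax.numpy as jnp


def random_walk_normal_sketch(scale=1.0):
    """Hypothetical stand-in for random_walk_normal_fn (illustration only)."""

    def _fn(state_parts, seed):
        scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
        if len(scales) == 1:
            # Assumption: one scale is shared by all state parts.
            scales = scales * len(state_parts)
        if len(scales) != len(state_parts):
            raise ValueError("scale must broadcast with the state parts")
        # Split the seed so each state part gets independent noise.
        keys = jax.random.split(seed, len(state_parts))
        return [
            # Zero-mean normal noise with the part's shape and dtype, scaled.
            part + s * jax.random.normal(k, jnp.shape(part), dtype=part.dtype)
            for part, s, k in zip(state_parts, scales, keys)
        ]

    return _fn


# Usage: two state parts with different step sizes.
proposal_fn = random_walk_normal_sketch(scale=[0.1, 2.0])
state = [jnp.zeros(3), jnp.ones((2, 2))]
proposal = proposal_fn(state, jax.random.PRNGKey(0))
```

Because the noise is symmetric around the current state, a proposal of this form suits a random-walk Metropolis kernel; in TFP, the real callable is typically supplied as the new_state_fn argument of tfp.substrates.jax.mcmc.RandomWalkMetropolis.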