
tfp.mcmc.random_walk_normal_fn


Returns a callable that adds a random normal perturbation to the input.

tfp.mcmc.random_walk_normal_fn(
    scale=1.0,
    name=None
)

This function returns a callable that takes two arguments: a Python list of Tensors (of any shapes and dtypes) representing the state parts of current_state, and a random seed. The scale argument must be a Tensor or a Python list of Tensors giving the scale of the generated proposal, and it must broadcast with the state parts of current_state. The callable adds a sample from a zero-mean normal distribution with the supplied scale to each state part and returns a list of Tensors with the same types as the state parts of current_state.
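A minimal NumPy sketch of the behavior described above, assuming floating-point state parts. It is an illustration only, not TFP's implementation: `random_walk_normal_fn_sketch` is a hypothetical name, and NumPy's `default_rng` stands in for TFP's seeded Tensor sampling.

```python
import numpy as np


def random_walk_normal_fn_sketch(scale=1.0):
    """Hypothetical NumPy analogue of the callable factory (not TFP code)."""

    def _fn(state_parts, seed):
        # A single scale is reused for every state part; a list gives one scale per part.
        scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
        if len(scales) == 1:
            scales = scales * len(state_parts)
        if len(scales) != len(state_parts):
            raise ValueError("`scale` must broadcast with `state_parts`.")
        rng = np.random.default_rng(seed)
        next_parts = []
        for part, s in zip(state_parts, scales):
            part = np.asarray(part)
            # Zero-mean normal noise with standard deviation `s`, broadcast against the part.
            noise = rng.standard_normal(part.shape) * np.asarray(s)
            # Assumes floating-point parts; cast back so each output keeps its input dtype.
            next_parts.append((part + noise).astype(part.dtype))
        return next_parts

    return _fn
```

For example, `random_walk_normal_fn_sketch(scale=0.5)([np.zeros(3)], seed=42)` returns a one-element list holding a length-3 array, and calling it again with the same seed reproduces the same draw.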

Args:

  • scale: a Tensor or Python list of Tensors of any shapes and dtypes controlling the scale of the normal proposal distribution.
  • name: Python str name prefixed to Ops created by this function. Default value: 'random_walk_normal_fn'.

Returns:

  • random_walk_normal_fn: A callable accepting a Python list of Tensors representing the state parts of the current_state and an int representing the random seed used to generate the proposal. The callable returns a list of Tensors of the same types as the input; this list is the proposal for the random walk Metropolis (RWM) algorithm.
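In TFP, the returned callable is typically passed as the `new_state_fn` argument of `tfp.mcmc.RandomWalkMetropolis`. The NumPy sketch below shows, under that assumption, how a random walk Metropolis loop might consume such a callable. `make_normal_proposal` and `rwm_sample` are hypothetical names, and the loop is an illustration rather than TFP's implementation.

```python
import numpy as np


def make_normal_proposal(scale):
    # Hypothetical stand-in for the callable returned by random_walk_normal_fn.
    def new_state_fn(state_parts, seed):
        rng = np.random.default_rng(seed)
        return [p + scale * rng.standard_normal(np.shape(p)) for p in state_parts]

    return new_state_fn


def rwm_sample(target_log_prob, init_parts, new_state_fn, num_steps, seed=0):
    """Random walk Metropolis driven by a (state_parts, seed) proposal callable."""
    rng = np.random.default_rng(seed)
    parts = [np.asarray(p, dtype=np.float64) for p in init_parts]
    log_p = target_log_prob(*parts)
    samples, accepted = [], 0
    for _ in range(num_steps):
        # Draw a fresh integer seed for each proposal, as the callable expects an int.
        proposal = new_state_fn(parts, int(rng.integers(2**31 - 1)))
        proposal_log_p = target_log_prob(*proposal)
        # Symmetric proposal: acceptance uses only the ratio of target densities.
        if np.log(rng.uniform()) < proposal_log_p - log_p:
            parts, log_p = proposal, proposal_log_p
            accepted += 1
        samples.append([p.copy() for p in parts])
    return samples, accepted / num_steps
```

Because the normal random walk proposal is symmetric, the Hastings correction cancels, which is why the acceptance test compares only the target log-densities.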