tfp.substrates.numpy.random.sanitize_seed

Map various PRNG seed flavors to a seed Tensor.

This function implements TFP's standard PRNG seeding semantics. See https://github.com/tensorflow/probability/blob/main/PRNGS.md for details.

Operationally, sanitize_seed maps any seed flavor to a "stateless-compatible" seed. Under TensorFlow and NumPy this means:

  • If the seed argument is an int or None, we use tf.random.uniform to statefully draw a pair of unbounded int32s and wrap them into a Tensor.
  • If the seed argument is a stateless-compatible seed already, we just cast it to an int32[2] Tensor.
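The two cases above can be modelled in plain Python to show the dispatch on seed type. This is a minimal sketch, not TFP's implementation: `sanitize_seed_sketch`, `_to_int32`, and `_STATEFUL_RNGS` are hypothetical names, a tuple of two Python ints stands in for the int32[2] Tensor, a persistent `random.Random` per int seed stands in for TensorFlow's stateful sampler state, the "cast" is modelled as two's-complement wrapping, and salting is omitted.

```python
import random
from typing import Dict, Optional, Sequence, Tuple, Union

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

# Hypothetical stand-in for TF's stateful sampler state: one persistent
# generator per int seed (None gets a single OS-entropy-seeded generator).
_STATEFUL_RNGS: Dict[Optional[int], random.Random] = {}


def _to_int32(x: int) -> int:
    """Cast an int to signed int32, modelled here as two's-complement wrapping."""
    return (int(x) - INT32_MIN) % 2**32 + INT32_MIN


def sanitize_seed_sketch(seed: Union[None, int, Sequence[int]]) -> Tuple[int, int]:
    """Map a seed flavor to an int32[2]-like pair (salting omitted)."""
    if seed is None or isinstance(seed, int):
        # Stateful flavor: draw two unbounded int32s from a generator whose
        # state persists across calls, so repeated calls yield new pairs.
        if seed not in _STATEFUL_RNGS:
            _STATEFUL_RNGS[seed] = random.Random(seed)
        rng = _STATEFUL_RNGS[seed]
        return (rng.randint(INT32_MIN, INT32_MAX), rng.randint(INT32_MIN, INT32_MAX))
    # Stateless flavor: already two values, so just cast them to int32.
    if len(seed) != 2:
        raise ValueError(f"a stateless seed needs exactly 2 elements, got {len(seed)}")
    return (_to_int32(seed[0]), _to_int32(seed[1]))
```

With this model, calling `sanitize_seed_sketch(42)` twice returns two different pairs, while `sanitize_seed_sketch([1, 2])` always returns `(1, 2)`.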

Under JAX, this function accepts only outputs of jax.random.PRNGKey and is a no-op apart from the salting behavior described below.

Thus, any function that accepts a seed argument can be written in stateless-seed style internally and still acquire TFP's seed-type-directed switching between stateless and stateful behavior, simply by passing its input seed through sanitize_seed on entry.
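The "sanitize on entry" pattern from the paragraph above can be sketched as follows. Everything here is hypothetical and pure Python: `sanitize_seed_sketch` is a simplified stand-in for sanitize_seed, `stateless_uniform_sketch` is a toy stateless sampler whose output depends only on its seed pair, and `noisy_mean` is an illustrative seed-accepting function. The point is that `noisy_mean` contains no seed-type branching of its own.

```python
import random
from typing import Dict, Optional, Tuple

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1
_STATEFUL_RNGS: Dict[Optional[int], random.Random] = {}


def sanitize_seed_sketch(seed) -> Tuple[int, int]:
    """Hypothetical stand-in for sanitize_seed: int/None -> stateful draw,
    anything else is treated as an already-stateless pair."""
    if seed is None or isinstance(seed, int):
        if seed not in _STATEFUL_RNGS:
            _STATEFUL_RNGS[seed] = random.Random(seed)
        rng = _STATEFUL_RNGS[seed]
        return (rng.randint(INT32_MIN, INT32_MAX), rng.randint(INT32_MIN, INT32_MAX))
    return (int(seed[0]), int(seed[1]))


def stateless_uniform_sketch(n: int, seed: Tuple[int, int]) -> list:
    """Hypothetical stateless sampler: the output depends only on `seed`."""
    rng = random.Random(f"{seed[0]},{seed[1]}")  # fresh generator on every call
    return [rng.random() for _ in range(n)]


def noisy_mean(n: int, seed=None) -> float:
    """Example seed-accepting function written in stateless style internally."""
    seed = sanitize_seed_sketch(seed)  # the only place seed flavors are handled
    return sum(stateless_uniform_sketch(n, seed)) / n
```

Passing `seed=[3, 4]` twice reproduces the same result, while passing `seed=7` twice gives two different results, even though `noisy_mean` itself only ever sees a stateless pair.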

The sanitize_seed function also allows salting the seed: if a user accidentally passes the same stateful seed to two different calls to sanitize_seed with different salts, they will get independent randomness. In the future, we may micro-optimize by skipping salting for seeds that are already stateless, since using a stateless seed already requires the caller to keep seeds unique.
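Salting can be illustrated by extending the same kind of pure-Python model with a `salt` argument. This is a sketch under stated assumptions, not TFP's mixing scheme: the helper names are hypothetical, SHA-256 over `repr` is an arbitrary choice of hash, and ignoring the salt for `None` seeds (which already draw fresh entropy) is a design choice of this sketch.

```python
import hashlib
import random
from typing import Dict, Optional, Tuple

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1
_STATEFUL_RNGS: Dict[Optional[int], random.Random] = {}


def _mix(*parts) -> int:
    """Deterministically hash `parts` into a non-negative 64-bit int.
    SHA-256 is an illustrative choice, not necessarily what TFP uses."""
    digest = hashlib.sha256(repr(parts).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little")


def _to_int32(x: int) -> int:
    """Wrap an int into the signed int32 range."""
    return (x - INT32_MIN) % 2**32 + INT32_MIN


def sanitize_seed_sketch(seed, salt: Optional[str] = None) -> Tuple[int, int]:
    """Hypothetical model of sanitize_seed with salting."""
    if salt is not None and not isinstance(salt, str):
        raise TypeError("salt must be a Python str")
    if seed is None or isinstance(seed, int):
        if seed is not None and salt is not None:
            # Fold the salt into the int seed, so the same int with different
            # salts selects different stateful generators.
            seed = _mix(seed, salt)
        if seed not in _STATEFUL_RNGS:
            _STATEFUL_RNGS[seed] = random.Random(seed)
        rng = _STATEFUL_RNGS[seed]
        return (rng.randint(INT32_MIN, INT32_MAX), rng.randint(INT32_MIN, INT32_MAX))
    pair = (int(seed[0]), int(seed[1]))
    if salt is not None:
        # Stateless seed: replace the pair with a deterministic hash of
        # (pair, salt), split into two 32-bit halves.
        h = _mix(pair, salt)
        pair = (h & 0xFFFFFFFF, h >> 32)
    return (_to_int32(pair[0]), _to_int32(pair[1]))
```

Here the same stateless seed with the same salt always yields the same pair, and the same seed (stateless or an int) with different salts yields different pairs.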

Args
seed: An int32[2] Tensor or a Python list or tuple of 2 ints, which will be treated as stateless seeds; or a Python int or None, which will be treated as stateful seeds.
salt: An optional Python string, used to salt the seed as described above.
name: An optional Python string; name to add to TF ops created by this function.

Returns
seed: An int32[2] Tensor suitable for use as a stateless PRNG seed.