Outputs deterministic pseudorandom values from a gamma distribution.
tf.random.stateless_gamma(
    shape,
    seed,
    alpha,
    beta=None,
    dtype=tf.dtypes.float32,
    name=None
)
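A minimal call that fills in every argument of the signature might look like the sketch below. It assumes TensorFlow 2.x in eager mode; the seed, parameter values, name, and the choice of tf.float64 are illustrative only.

```python
import tensorflow as tf

samples = tf.random.stateless_gamma(
    shape=[4, 3],           # full output shape
    seed=[1, 2],            # shape-[2] integer seed
    alpha=[1.0, 2.0, 3.0],  # one concentration per trailing column
    beta=2.0,               # scalar inverse scale (rate), broadcast to every column
    dtype=tf.float64,       # floating-point output dtype
    name="gamma_samples",   # illustrative op name
)
print(samples.shape, samples.dtype)
```

Here alpha (shape [3]) and beta (a scalar) broadcast to [3], which matches the trailing dimension of shape.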
The generated values follow a gamma distribution with specified concentration (alpha) and inverse scale (beta) parameters.
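Because beta is an inverse scale (a rate), a Gamma(alpha, beta) variable has mean alpha / beta and variance alpha / beta**2. The sketch below checks this empirically and, for comparison, draws the same distribution with NumPy, whose Generator.gamma takes a scale argument, so the equivalent call uses scale = 1 / beta. It assumes TensorFlow 2.x in eager mode; the parameter values and sample size are illustrative.

```python
import numpy as np
import tensorflow as tf

alpha, beta = 2.0, 4.0  # illustrative concentration and rate (inverse scale)
n = 200_000

tf_samples = tf.random.stateless_gamma(
    [n], seed=[7, 8], alpha=alpha, beta=beta).numpy()

# NumPy parameterizes the gamma by scale, so the equivalent draw uses scale = 1 / beta.
np_samples = np.random.default_rng(0).gamma(shape=alpha, scale=1.0 / beta, size=n)

# Both should be close to the analytic moments:
# mean = alpha / beta, variance = alpha / beta**2.
for label, s in [("tf", tf_samples), ("numpy", np_samples)]:
    print(label, s.mean(), alpha / beta, s.var(), alpha / beta**2)
```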
This is a stateless version of tf.random.gamma: if run twice with the same seeds and shapes, it will produce the same pseudorandom numbers. The output is consistent across multiple runs on the same hardware (and between CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU hardware.
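The reproducibility guarantee can be observed directly: two calls with identical arguments return identical values, while changing only the seed changes the output. This sketch assumes TensorFlow 2.x in eager mode and only checks repeatability within a single process; the seeds and alpha value are arbitrary.

```python
import numpy as np
import tensorflow as tf

def draw(seed):
    # Identical shape and alpha in every call; only the seed varies.
    return tf.random.stateless_gamma([5], seed=seed, alpha=1.5).numpy()

a = draw([12, 34])
b = draw([12, 34])  # same seed and shape: same numbers
c = draw([12, 35])  # different seed: (almost surely) different numbers

print(np.array_equal(a, b), np.array_equal(a, c))
```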
There is a slight difference in how the shape parameter is interpreted by stateless_gamma and gamma. In gamma, shape is always prepended to the shape obtained by broadcasting alpha with beta. In stateless_gamma, shape is the full output shape: it must encompass the shapes of both alpha and beta, which must broadcast together to match the trailing dimensions of shape.
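The two conventions can be contrasted with a small pure-Python sketch. The helpers broadcast_shape, gamma_output_shape, and stateless_gamma_output_shape are hypothetical and exist only for illustration; they encode the rule as stated above (with NumPy-style broadcasting), not TensorFlow's actual validation code.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """NumPy-style broadcast of two shapes given as tuples of ints."""
    result = []
    # Compare dimensions right to left, padding the shorter shape with 1s.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))

def gamma_output_shape(shape, alpha_shape, beta_shape):
    """Hypothetical model of tf.random.gamma: shape is prepended."""
    return tuple(shape) + broadcast_shape(alpha_shape, beta_shape)

def stateless_gamma_output_shape(shape, alpha_shape, beta_shape):
    """Hypothetical model of tf.random.stateless_gamma: shape is the full
    output shape, and the broadcast parameter shape must equal its
    trailing dimensions."""
    params = broadcast_shape(alpha_shape, beta_shape)
    shape = tuple(shape)
    if len(params) > len(shape) or shape[len(shape) - len(params):] != params:
        raise ValueError(f"{params} does not match the trailing dims of {shape}")
    return shape

# alpha of shape [3, 1] and beta of shape [1, 2] broadcast to [3, 2].
print(gamma_output_shape([30], (3, 1), (1, 2)))                  # (30, 3, 2)
print(stateless_gamma_output_shape([30, 3, 2], (3, 1), (1, 2)))  # (30, 3, 2)
```

Both calls produce the same output shape, but gamma receives only the sample dimensions [30], while stateless_gamma must be given the full [30, 3, 2].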
The samples are differentiable with respect to alpha and beta. The derivatives are computed using the implicit reparameterization approach described by Figurnov et al. (2018).
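As a rough guide to what these derivatives are, the math below assumes the standard rate parameterization and writes a Gamma(alpha, beta) sample z as a unit-rate sample z̃ divided by beta, which is what makes beta an inverse scale. The beta derivative then follows directly. The alpha derivative is the implicit reparameterization gradient, obtained by differentiating F(z̃; α) = u with the underlying uniform u held fixed, where F and p denote the CDF and density of Gamma(α, 1). This summarizes the math only and is not a description of TensorFlow's kernel.

```latex
\begin{aligned}
z &= \frac{\tilde z}{\beta}, \qquad \tilde z \sim \mathrm{Gamma}(\alpha, 1), \\
\frac{\partial z}{\partial \beta} &= -\frac{\tilde z}{\beta^{2}} = -\frac{z}{\beta}, \\
\frac{\partial \tilde z}{\partial \alpha} &= -\frac{\partial F(\tilde z; \alpha) / \partial \alpha}{p(\tilde z; \alpha)},
\qquad
\frac{\partial z}{\partial \alpha} = \frac{1}{\beta}\,\frac{\partial \tilde z}{\partial \alpha}
\end{aligned}
```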
Example:
import tensorflow as tf

samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
# the samples drawn from each distribution

samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
# represents the 7x5 samples drawn from each of the two distributions

alpha = tf.constant([[1.], [3.], [5.]])
beta = tf.constant([[3., 4.]])
samples = tf.random.stateless_gamma(
    [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
# samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
# unbiased stochastic derivatives of the loss function
alpha.shape == dloss_dalpha.shape  # True
beta.shape == dloss_dbeta.shape  # True
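For a loss whose gradient is easy to write down, the beta derivative can be checked against a closed form. Writing a Gamma(alpha, beta) sample as a unit-rate sample divided by beta gives dz/dbeta = -z / beta per sample, so for the sum of the samples the gradient should be -sum(z) / beta. This sketch assumes TensorFlow 2.x in eager mode; the loss, parameter values, and sample count are illustrative.

```python
import numpy as np
import tensorflow as tf

alpha = tf.constant(3.0)
beta = tf.constant(2.0)

with tf.GradientTape() as tape:
    tape.watch(beta)
    z = tf.random.stateless_gamma([1000], seed=[12, 34], alpha=alpha, beta=beta)
    loss = tf.reduce_sum(z)

dloss_dbeta = tape.gradient(loss, beta)

# Each sample contributes dz/dbeta = -z / beta, so the total is -sum(z) / beta.
expected = -tf.reduce_sum(z) / beta
print(float(dloss_dbeta), float(expected))
```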