Counts the number of occurrences of each value in an integer array `arr`.
```python
tfp.substrates.jax.stats.count_integers(
    arr,
    weights=None,
    minlength=None,
    maxlength=None,
    axis=None,
    dtype=tf.int32,
    name=None
)
```
Works like `tf.math.bincount`, but provides an `axis` kwarg that specifies the dimensions to reduce over. With `~axis = [i for i in range(arr.ndim) if i not in axis]`, this function returns a `Tensor` of shape `[K] + arr.shape[~axis]`.
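The shape rule can be illustrated with a small NumPy emulation. This is a minimal sketch, assuming non-negative integers and no `weights`; `count_integers_sketch` is a hypothetical helper, not part of TFP, and it only mirrors the documented output shape `[K] + arr.shape[~axis]`, not the library's actual implementation.

```python
import numpy as np


def count_integers_sketch(arr, axis=None):
    """Hypothetical NumPy emulation of the documented axis semantics (no weights).

    Assumes `arr` holds non-negative integers; negative values are silently ignored here.
    """
    arr = np.asarray(arr)
    if axis is None:
        axis = list(range(arr.ndim))
    elif isinstance(axis, int):
        axis = [axis]
    axis = [a % arr.ndim for a in axis]
    # ~axis: the dimensions that are kept in the output.
    not_axis = [i for i in range(arr.ndim) if i not in axis]
    # K is computed from the whole array, as in the documentation.
    k = int(arr.max()) + 1 if arr.size > 0 else 0
    # Put kept dims first and reduced dims last, then flatten the reduced dims.
    moved = np.transpose(arr, not_axis + axis)
    kept_shape = moved.shape[:len(not_axis)]
    reduced_size = int(np.prod([arr.shape[a] for a in axis]))
    flat = moved.reshape(kept_shape + (reduced_size,))
    # Compare each value against bins 0..K-1 and count matches along the flattened axis.
    counts = (flat[..., np.newaxis] == np.arange(k)).sum(axis=-2)
    # Move the bin dimension to the front: shape [K] + arr.shape[~axis].
    return np.moveaxis(counts, -1, 0).astype(np.int32)


arr = np.array([[0, 1, 1],
                [2, 2, 2]])
print(count_integers_sketch(arr).shape)             # (3,)
print(count_integers_sketch(arr).tolist())          # [1, 2, 3]
print(count_integers_sketch(arr, axis=1).tolist())  # [[1, 0], [2, 0], [0, 3]]
```

Reducing over `axis=1` keeps dimension 0, so the result has shape `[3, 2]`: column `j` holds the bin counts for row `j` of `arr`.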
If `minlength` and `maxlength` are not given, `K = tf.reduce_max(arr) + 1` if `arr` is non-empty, and 0 otherwise.
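Because `K` is taken from the maximum of the entire `arr`, every output slice gets the same number of bins, even a slice whose own maximum is smaller. The sketch below is a hypothetical NumPy helper (not a TFP function) that computes the default `K` and contrasts it with a per-row `np.bincount`.

```python
import numpy as np


def default_num_bins(arr):
    """Hypothetical helper: the documented K when minlength and maxlength are not given."""
    arr = np.asarray(arr)
    # K = max(arr) + 1 for a non-empty array, and 0 otherwise.
    return int(arr.max()) + 1 if arr.size > 0 else 0


arr = np.array([[0, 3],
                [1, 1]])
print(default_num_bins(arr))                               # 4
print(default_num_bins(np.zeros((0, 5), dtype=np.int32)))  # 0
# A bincount over the second row alone would only produce 2 bins.
print(len(np.bincount(arr[1])))                            # 2
```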
If `weights` is not None, then index `i` of the output stores the sum of the values in `weights` at every index where the corresponding value in `arr` is `i`.
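The weighted rule can be spelled out directly. This is a minimal 1-D sketch in NumPy, assuming a non-negative integer `arr`, a `weights` array of the same shape, and no `minlength`/`maxlength`; `weighted_counts_sketch` is hypothetical. NumPy's own `np.bincount(arr, weights=...)` applies the same sum-of-weights rule, so it serves as a cross-check.

```python
import numpy as np


def weighted_counts_sketch(arr, weights):
    """Hypothetical 1-D illustration: output[i] = sum of weights where arr == i."""
    arr = np.asarray(arr)
    weights = np.asarray(weights)
    k = int(arr.max()) + 1 if arr.size > 0 else 0
    # Output dtype follows the weights, matching the documented return dtype.
    out = np.zeros(k, dtype=weights.dtype)
    for i in range(k):
        # Sum the weights at every position whose value in arr equals i.
        out[i] = weights[arr == i].sum()
    return out


arr = [0, 1, 1, 3]
weights = [0.5, 1.0, 2.0, 4.0]
print(weighted_counts_sketch(arr, weights).tolist())  # [0.5, 3.0, 0.0, 4.0]
print(np.bincount(arr, weights=weights).tolist())     # [0.5, 3.0, 0.0, 4.0]
```

Bin 1 receives `1.0 + 2.0 = 3.0` because `arr` holds the value 1 at two positions, and bin 2 stays at zero because 2 never occurs.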
Returns | |
---|---|
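A `Tensor` of bin values, with the same dtype as `weights`, or the given `dtype` when `weights` is None. It is a vector of length `K` when `axis` is None, and has shape `[K] + arr.shape[~axis]` otherwise. |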