tfp.substrates.jax.stats.histogram

Count how often x falls in intervals defined by edges.

Given edges = [c0, ..., cK], defining intervals I0 = [c0, c1), I1 = [c1, c2), ..., I_{K-1} = [c_{K-1}, cK], this function counts how often x falls into each interval.
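The interval rule above can be illustrated with a minimal pure-Python sketch. It assumes edges is a sorted 1-D list, and find_interval is a hypothetical helper rather than part of the TFP API. Every interval is half-open except the last, which also contains its right edge cK.

```python
import bisect


def find_interval(value, edges):
    """Return k such that value lies in I_k, or None if value is outside [c0, cK].

    Hypothetical helper illustrating the interval rule; not a TFP function.
    """
    num_intervals = len(edges) - 1
    if value < edges[0] or value > edges[-1]:
        return None
    if value == edges[-1]:
        # The last interval I_{K-1} = [c_{K-1}, cK] is closed on the right.
        return num_intervals - 1
    # bisect_right gives the index of the first edge strictly greater than value,
    # so value lies in [edges[k], edges[k + 1]) with k = that index - 1.
    return bisect.bisect_right(edges, value) - 1


edges = [0.0, 0.5, 1.0, 1.5, 2.0]
print([find_interval(v, edges) for v in [0.0, 0.49, 0.5, 1.99, 2.0, 2.5]])
# [0, 0, 1, 3, 3, None]
```

A value equal to an inner edge such as 0.5 belongs to the interval that starts there, while cK itself is counted in the last interval.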

Values of x outside [c0, cK] cause errors. To count such values instead, set extend_lower_interval and/or extend_upper_interval to True.
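One way to model the two extension flags, sketched in pure Python under the assumption that extending an interval is equivalent to replacing c0 with -inf or cK with +inf. This is not TFP's implementation: find_interval is a hypothetical helper, and raising ValueError is only this sketch's stand-in for the errors described above.

```python
import bisect
import math


def find_interval(value, edges, extend_lower_interval=False,
                  extend_upper_interval=False):
    """Return k such that value lies in I_k (hypothetical helper, not TFP API).

    Out-of-range values raise ValueError, standing in for the errors that the
    documentation describes.
    """
    edges = list(edges)  # copy so the caller's list is not modified
    if extend_lower_interval:
        edges[0] = -math.inf  # I0 becomes (-inf, c1)
    if extend_upper_interval:
        edges[-1] = math.inf  # I_{K-1} becomes [c_{K-1}, +inf)
    if not edges[0] <= value <= edges[-1]:
        raise ValueError(f"{value} is outside [{edges[0]}, {edges[-1]}]")
    # Half-open intervals, except that the last one also contains edges[-1].
    return min(bisect.bisect_right(edges, value) - 1, len(edges) - 2)


edges = [0.0, 0.5, 1.0, 1.5, 2.0]
print(find_interval(-3.0, edges, extend_lower_interval=True))  # 0
print(find_interval(9.0, edges, extend_upper_interval=True))   # 3
try:
    find_interval(9.0, edges)
except ValueError as err:
    print("error:", err)  # error: 9.0 is outside [0.0, 2.0]
```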

Args

x Numeric N-D Tensor with N > 0. If axis is not None, x must have a statically known number of dimensions. The axis kwarg determines which dimensions index iid samples; the remaining dimensions of x index "events", and a separate histogram is computed for each event.
edges Tensor of the same dtype as x. The first dimension indexes the interval edges. edges must either be 1-D or have edges.shape[1:] equal to the dimensions of x excluding axis. If rank(edges) > 1, edges[k] is a Tensor of shape edges.shape[1:] holding the kth edge for each corresponding dimension of x.
axis Optional 0-D or 1-D integer Tensor with constant values: the axis or axes of x that index iid samples. Default value: None (treat every dimension as a sample dimension).
weights Optional Tensor of the same dtype and shape as x. If provided, each value in x increments its bin by the corresponding weight instead of by 1.
extend_lower_interval Python bool. If True, extend the lowest interval I0 to (-inf, c1).
extend_upper_interval Python bool. If True, extend the highest interval I_{K-1} to [c_{K-1}, +inf).
dtype The output type (int32 or int64). Default value: x.dtype.
name A Python string name to prepend to created ops. Default value: 'histogram'.
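The effect of weights can be illustrated with a pure-Python sketch of the axis=None case, in which x is flattened to a list. histogram_1d is a hypothetical name, and the sketch only mirrors the documented semantics, not how TFP computes them. Without weights each value adds 1 to its bin; with weights it adds its weight.

```python
import bisect


def histogram_1d(x, edges, weights=None):
    """Hypothetical pure-Python histogram over a flat list (axis=None case).

    With weights=None each value adds 1 to its bin; otherwise it adds its weight.
    Values outside [edges[0], edges[-1]] raise ValueError in this sketch.
    """
    if weights is None:
        weights = [1] * len(x)
    if len(weights) != len(x):
        raise ValueError("x and weights must have the same shape")
    counts = [0] * (len(edges) - 1)
    for value, weight in zip(x, weights):
        if not edges[0] <= value <= edges[-1]:
            raise ValueError(f"{value} is outside [{edges[0]}, {edges[-1]}]")
        # Half-open bins, except that the last bin also contains edges[-1].
        k = min(bisect.bisect_right(edges, value) - 1, len(counts) - 1)
        counts[k] += weight
    return counts


edges = [0.0, 0.5, 1.0, 1.5, 2.0]
x = [0.1, 0.2, 0.7, 1.2, 2.0]
print(histogram_1d(x, edges))                                      # [2, 1, 1, 1]
print(histogram_1d(x, edges, weights=[0.5, 0.5, 2.0, 1.0, 3.0]))  # [1.0, 2.0, 1.0, 3.0]
```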

Returns

counts Tensor of type dtype. With ~axis = [i for i in range(x.ndim) if i not in axis], counts.shape = [edges.shape[0] - 1] + x.shape[~axis], that is, one entry per interval for each event. With I a multi-index into ~axis, counts[k][I] is the number of samples of event I that fell into the kth interval of edges or, if weights is not None, the sum of the weights of those samples.
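To make the output layout concrete, here is a pure-Python sketch of the axis=0 case for a 2-D nested list, where x[i][j] is sample i of event j. histogram_axis0 is a hypothetical name and this is not TFP's implementation; it only shows that the interval index comes first, so the result has K = len(edges) - 1 rows and one column per event.

```python
import bisect


def histogram_axis0(x, edges):
    """Hypothetical sketch of histogram(x, edges, axis=0) for a 2-D nested list.

    x[i][j] is sample i of event j. Returns counts with counts[k][j] equal to the
    number of samples of event j in interval k, so the result has shape
    [len(edges) - 1, num_events].
    """
    num_bins = len(edges) - 1
    num_events = len(x[0])
    counts = [[0] * num_events for _ in range(num_bins)]
    for row in x:
        for j, value in enumerate(row):
            if not edges[0] <= value <= edges[-1]:
                raise ValueError(f"{value} is outside [{edges[0]}, {edges[-1]}]")
            # Half-open bins, except that the last bin also contains edges[-1].
            k = min(bisect.bisect_right(edges, value) - 1, num_bins - 1)
            counts[k][j] += 1
    return counts


# Two events: column 0 lies in [0, 1], column 1 lies in [1, 2].
x = [[0.1, 1.1], [0.6, 1.6], [0.2, 1.2], [0.9, 2.0]]
edges = [0.0, 0.5, 1.0, 1.5, 2.0]
print(histogram_axis0(x, edges))  # [[2, 0], [2, 0], [0, 2], [0, 2]]
```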

Raises

ValueError If x and weights do not have the same shape.

Examples

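import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

# x.shape = [1000, 2]
# x[:, 0] ~ Uniform(0, 1), x[:, 1] ~ Uniform(1, 2).
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
x = jnp.stack([jax.random.uniform(key0, (1000,)),
               jax.random.uniform(key1, (1000,), minval=1., maxval=2.)],
              axis=-1)

# edges ==> bins [0, 0.5), [0.5, 1.0), [1.0, 1.5), [1.5, 2.0].
edges = jnp.array([0., 0.5, 1.0, 1.5, 2.0])

# axis=None: all 2000 values are pooled into a single histogram.
tfp.stats.histogram(x, edges)
# ==> approximately [500, 500, 500, 500]

# axis=0: one histogram per column of x, so counts.shape = [4, 2].
tfp.stats.histogram(x, edges, axis=0)
# ==> approximately [[500, 0], [500, 0], [0, 500], [0, 500]]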