Count how often x falls in intervals defined by edges.
tfp.substrates.jax.stats.histogram(
x,
edges,
axis=None,
weights=None,
extend_lower_interval=False,
extend_upper_interval=False,
dtype=None,
name=None
)
Given edges = [c0, ..., cK], defining intervals I0 = [c0, c1), I1 = [c1, c2), ..., I_{K-1} = [c_{K-1}, cK], this function counts how often x falls into each interval.
Values of x outside of the intervals cause errors. Consider setting extend_lower_interval and/or extend_upper_interval to handle such values.
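The interval convention above can be made concrete with a small pure-Python sketch. This is not the library's implementation: the helper name bin_index is hypothetical, and raising ValueError for out-of-range values is an assumption standing in for whatever error the library actually reports.

```python
from bisect import bisect_right


def bin_index(value, edges):
    """Return k such that value lies in I_k, for sorted edges [c0, ..., cK].

    Hypothetical helper that only illustrates the interval convention.
    """
    n_bins = len(edges) - 1
    if value < edges[0] or value > edges[-1]:
        # Assumption: ValueError stands in for the library's actual error.
        raise ValueError(f"{value} is outside [{edges[0]}, {edges[-1]}]")
    if value == edges[-1]:
        # The last interval is closed on the right, so cK belongs to I_{K-1}.
        return n_bins - 1
    # bisect_right counts the edges <= value; subtracting 1 gives the
    # half-open interval [c_k, c_{k+1}) that contains value.
    return bisect_right(edges, value) - 1


edges = [0.0, 0.5, 1.0, 1.5, 2.0]
indices = [bin_index(v, edges) for v in [0.0, 0.49, 0.5, 1.99, 2.0]]
print(indices)  # [0, 0, 1, 3, 3]
```

Note that 0.5 lands in I1 rather than I0 because the intervals are half-open, while 2.0 lands in the last interval because that one includes its right endpoint.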
Args |
x
|
Numeric N-D Tensor with N > 0. If axis is not None, must have a
statically known number of dimensions. The axis kwarg determines which
dimensions index iid samples. Other dimensions of x index "events" for
which we will compute different histograms.
|
edges
|
Tensor of same dtype as x. The first dimension indexes edges
of intervals. Must either be 1-D or have edges.shape[1:] the same
as the dimensions of x excluding axis.
If rank(edges) > 1, edges[k] designates a shape edges.shape[1:]
Tensor of interval edges for the corresponding dimensions of x.
|
axis
|
Optional 0-D or 1-D integer Tensor with constant
values. The axes in x that index iid samples.
Default value: None (treat every dimension as a sample dimension).
|
weights
|
Optional Tensor of same dtype and shape as x.
For each value in x, the count of its bin is incremented by
the corresponding weight instead of 1.
|
extend_lower_interval
|
Python bool. If True, extend the lowest
interval I0 to (-inf, c1).
|
extend_upper_interval
|
Python bool. If True, extend the upper
interval I_{K-1} to [c_{K-1}, +inf).
|
dtype
|
The output type (int32 or int64). Default value: x.dtype.
|
name
|
A Python string name to prepend to created ops.
Default value: 'histogram'
|
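How weights and the two extend_* flags interact can be sketched in pure Python for the 1-D case. This is an illustration under stated assumptions, not the library code: the function name histogram_1d is hypothetical, and ValueError for out-of-range values is a stand-in for the library's actual error.

```python
from bisect import bisect_right


def histogram_1d(xs, edges, weights=None,
                 extend_lower_interval=False, extend_upper_interval=False):
    """Hypothetical 1-D reference for the documented counting rule."""
    n_bins = len(edges) - 1
    if weights is None:
        weights = [1] * len(xs)  # Unweighted: each value adds 1 to its bin.
    if len(weights) != len(xs):
        raise ValueError("x and weights must have the same shape")
    counts = [0] * n_bins
    for value, weight in zip(xs, weights):
        if value < edges[0]:
            if not extend_lower_interval:
                # Assumption: stand-in for the library's out-of-range error.
                raise ValueError(f"{value} is below the lowest edge")
            k = 0  # I0 extended down to -inf.
        elif value > edges[-1]:
            if not extend_upper_interval:
                raise ValueError(f"{value} is above the highest edge")
            k = n_bins - 1  # I_{K-1} extended up to +inf.
        else:
            # Clamp so that value == cK lands in the last (closed) interval.
            k = min(bisect_right(edges, value) - 1, n_bins - 1)
        counts[k] += weight
    return counts


edges = [0.0, 1.0, 2.0]
xs = [-3.0, 0.5, 1.5, 2.0, 7.0]
plain = histogram_1d(xs, edges,
                     extend_lower_interval=True, extend_upper_interval=True)
weighted = histogram_1d(xs, edges, weights=[1.0, 2.0, 0.5, 0.5, 4.0],
                        extend_lower_interval=True, extend_upper_interval=True)
print(plain)     # [2, 3]
print(weighted)  # [3.0, 5.0]
```

With both flags left at False, the same call would fail on -3.0, which is the situation the flags are meant to handle.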
Returns |
counts
|
Tensor of type dtype and, with
~axis = [i for i in range(x.ndim) if i not in axis],
counts.shape = [edges.shape[0] - 1] + x.shape[~axis].
With I a multi-index into ~axis, counts[k][I] is the number of times
event(s) fell into the kth interval of edges or, with weights
non-None, the sum of the weight(s) corresponding to the event(s) in a bin.
|
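The layout of counts when axis=0 can be illustrated with a pure-Python sketch over a list of rows. The assumptions here are that x is a nested list of shape [n_samples, n_events], that edges is 1-D, that every value lies within [c0, cK], and that K + 1 edges give K bins, as the interval definition implies. The helper name histogram_axis0 is hypothetical.

```python
from bisect import bisect_right


def histogram_axis0(rows, edges):
    """Hypothetical sketch: one histogram per column of a 2-D nested list.

    Returns counts with counts[k][i] = number of samples of event i that
    fell into interval I_k, so bins index the first dimension.
    Assumes every value lies within [edges[0], edges[-1]].
    """
    n_bins = len(edges) - 1
    n_events = len(rows[0])
    counts = [[0] * n_events for _ in range(n_bins)]
    for row in rows:                     # Axis 0 indexes iid samples.
        for i, value in enumerate(row):  # Remaining axis indexes events.
            # Clamp so that value == cK lands in the last (closed) interval.
            k = min(bisect_right(edges, value) - 1, n_bins - 1)
            counts[k][i] += 1
    return counts


edges = [0.0, 0.5, 1.0, 1.5, 2.0]
rows = [[0.1, 1.2], [0.7, 1.9], [0.4, 1.6], [0.9, 2.0]]
counts = histogram_axis0(rows, edges)
print(counts)  # [[2, 0], [2, 0], [0, 1], [0, 3]]
```

The result has one row per interval and one column per event, matching the counts[k][I] indexing described above.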
Raises |
ValueError
|
if the shapes of x and weights are not the same.
|
Examples
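import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

# x.shape = [1000, 2]
# x[:, 0] ~ Uniform(0, 1), x[:, 1] ~ Uniform(1, 2).
# JAX random functions need an explicit key; split one key per column.
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
x = jnp.stack([jax.random.uniform(key0, [1000]),
               1 + jax.random.uniform(key1, [1000])],
              axis=-1)
# edges ==> bins [0, 0.5), [0.5, 1.0), [1.0, 1.5), [1.5, 2.0].
edges = [0., 0.5, 1.0, 1.5, 2.0]
# All dimensions are sample dimensions: one histogram over all 2000 values.
tfp.stats.histogram(x, edges)
# ==> approximately [500, 500, 500, 500]
# Only axis 0 indexes samples: one histogram per column, bins along axis 0.
tfp.stats.histogram(x, edges, axis=0)
# ==> approximately [[500, 0], [500, 0], [0, 500], [0, 500]]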