Windowed estimates of mean.
tfp.substrates.jax.stats.windowed_mean(
x, low_indices=None, high_indices=None, axis=0, name=None
)
Computes means among data in the `Tensor` `x` along the given windows:
result[i] = mean(x[low_indices[i]:high_indices[i]])
efficiently. Concretely, if K is the size of `low_indices` and `high_indices`, and N is the size of `x` along the given `axis`, the computation takes O(K + N) work, O(log(N)) depth (the length of the longest series of operations that are performed sequentially), and uses only O(1) TensorFlow kernel invocations. Each window is half-open: it includes `low_indices[i]` and excludes `high_indices[i]`, consistent with the default windows listed below.
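One way to reach O(K + N) work is a prefix sum: compute the cumulative sum of `x` once (O(N)), then obtain each window sum as the difference of two prefix sums (O(1) per window, O(K) total). The following is a minimal NumPy sketch of that idea for a 1-D `x`, assuming half-open windows (low inclusive, high exclusive). `windowed_mean_sketch` and `windowed_mean_naive` are hypothetical helpers for illustration, not part of TFP, and the sketch is not claimed to be the library's implementation. NumPy's `cumsum` runs sequentially, so it shows the O(K + N) arithmetic but not the O(log(N)) depth, which would require a parallel scan.

```python
import numpy as np


def windowed_mean_sketch(x, low_indices, high_indices):
    """Mean of x[low:high] for each (low, high) pair, using one prefix sum.

    Hypothetical helper for illustration; not the TFP implementation.
    Windows are half-open: low is inclusive, high is exclusive.
    Empty windows (high == low) divide by zero; this sketch does not guard them.
    """
    x = np.asarray(x, dtype=float)
    low = np.asarray(low_indices)
    high = np.asarray(high_indices)
    # cum[j] = sum(x[:j]); the leading zero gives cum[0] = 0, so cum has N + 1 entries.
    cum = np.concatenate([[0.0], np.cumsum(x)])
    # Each window sum is a difference of two prefix sums: O(1) per window.
    sums = cum[high] - cum[low]
    counts = high - low
    return sums / counts


def windowed_mean_naive(x, low_indices, high_indices):
    """Direct O(K * N) reference: average each slice separately."""
    x = np.asarray(x, dtype=float)
    return np.array([x[lo:hi].mean() for lo, hi in zip(low_indices, high_indices)])


x = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
low = [0, 2, 1, 4]
high = [1, 5, 8, 8]
# Windows [0, 1), [2, 5), [1, 8), [4, 8) give 3.0, 10/3, 4.0 and 5.5.
print(windowed_mean_sketch(x, low, high))
print(windowed_mean_naive(x, low, high))
```

Both helpers agree; the prefix-sum version touches each element of `x` once regardless of how many windows overlap it.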
This function can be useful for assessing the behavior over time of trailing-window estimators from some iterative process, such as the last half of an MCMC chain.
Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has rank `axis`, and `low_indices` and `high_indices` broadcast to shape `[M]`. Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. Batch shape in the indices is not currently supported.
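To make the shape contract concrete, here is a NumPy sketch that applies the same prefix-sum idea along an arbitrary `axis`, assuming 1-D index arrays of shape `[M]` (no batch shape) and half-open windows. `windowed_mean_along_axis` is a hypothetical helper, not the TFP function; it only illustrates how `Bx + [N] + E` becomes `Bx + [M] + E`.

```python
import numpy as np


def windowed_mean_along_axis(x, low_indices, high_indices, axis=0):
    """Hypothetical sketch of the shape contract: Bx + [N] + E -> Bx + [M] + E.

    Not the TFP implementation. low/high are 1-D index arrays of shape [M]
    (no batch shape), with low inclusive and high exclusive.
    """
    x = np.asarray(x, dtype=float)
    low = np.asarray(low_indices)
    high = np.asarray(high_indices)
    # Prefix sums along `axis` with a slice of zeros prepended, so that
    # cum[..., j, ...] = sum of x[..., :j, ...]; that axis grows from N to N + 1.
    zeros = np.zeros_like(np.take(x, [0], axis=axis))
    cum = np.concatenate([zeros, np.cumsum(x, axis=axis)], axis=axis)
    # Gather M high and M low prefix sums; `axis` now has length M.
    sums = np.take(cum, high, axis=axis) - np.take(cum, low, axis=axis)
    # Reshape counts to [1, ..., 1, M, 1, ..., 1] so they broadcast along `axis`.
    counts_shape = [1] * x.ndim
    counts_shape[axis] = low.size
    counts = (high - low).reshape(counts_shape)
    return sums / counts


# Bx = [2] (rank 1, matching axis=1), N = 10, E = [3]; M = 4 windows.
x = np.random.default_rng(0).normal(size=(2, 10, 3))
low = np.array([0, 2, 5, 5])
high = np.array([10, 4, 10, 6])
means = windowed_mean_along_axis(x, low, high, axis=1)
print(means.shape)  # (2, 4, 3)
```

The window `[0, 10)` reproduces the ordinary mean over the whole axis, and `[5, 6)` simply returns the slice at position 5.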
The default windows are
[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...
This corresponds to analyzing `x` as though it were streaming (for example, successive states of an MCMC sampler) and computing the mean of the last half of the data seen at each point.
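The listed default windows follow a simple pattern: at step `i` the window ends after element `i` (high = i + 1) and starts halfway through the data seen so far (low = high // 2). The NumPy sketch below, which reads that pattern off the list rather than from the library source, rebuilds those indices and computes the trailing last-half mean of a made-up stand-in for an MCMC chain.

```python
import numpy as np

N = 8
# Pattern of the default windows: high = i + 1 covers everything seen so far,
# low = high // 2 drops the first half.
high = np.arange(1, N + 1)
low = high // 2
print(list(zip(low.tolist(), high.tolist()))[:5])
# [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5)]

# Hypothetical data standing in for successive MCMC states.
chain = np.arange(N, dtype=float)
# Prefix sums with a leading zero: cum[j] = sum(chain[:j]).
cum = np.concatenate([[0.0], np.cumsum(chain)])
trailing_means = (cum[high] - cum[low]) / (high - low)
print(trailing_means)
# Values: 0.0, 1.0, 1.5, 2.5, 3.0, 4.0, 4.5, 5.5
```

Watching such a trailing mean settle over the steps is one informal way to see whether the later part of a chain has stabilized.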
Returns | |
---|---|
`means` | A numeric `Tensor` holding the windowed means of `x` along the `axis` dimension.