Perform a scan with an associative binary operation, in parallel.
tfp.substrates.jax.math.scan_associative(
fn, elems, max_num_levels=48, axis=0, validate_args=False, name=None
)
The associative scan operation computes the cumulative sum, or all-prefix sum, of a set of elements under an associative binary operation [1]. For example, using the ordinary addition operator fn = lambda a, b: a + b, this is equivalent to the ordinary cumulative sum tf.math.cumsum along axis 0. This method supports the general case of arbitrary associative binary operations operating on Tensors or structures of Tensors:
scan_associative(fn, elems) = tf.stack([
elems[0],
fn(elems[0], elems[1]),
fn(elems[0], fn(elems[1], elems[2])),
...
fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1])))),
], axis=0)
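The definition above can be mirrored by a plain sequential loop. The following is a minimal sketch using only the Python standard library and ordinary lists rather than Tensors; sequential_scan is a hypothetical helper introduced here for illustration and is not part of TFP. It groups applications from the left, which gives the same result as the right-nested form shown above whenever fn is associative.

```python
import itertools
import operator


def sequential_scan(fn, elems):
    """Reference prefix scan over a Python list (hypothetical helper, not part of TFP)."""
    if not elems:
        return []
    results = [elems[0]]
    for elem in elems[1:]:
        # Left-nested grouping; equal to the right-nested form when fn is associative.
        results.append(fn(results[-1], elem))
    return results


sums = sequential_scan(operator.add, [0, 1, 2, 3])
print(sums)  # [0, 1, 3, 6]

# String concatenation is associative but not commutative, so order is preserved.
words = sequential_scan(operator.add, ["a", "b", "c"])
print(words)  # ['a', 'ab', 'abc']
```

For list inputs this matches itertools.accumulate from the standard library, which is a convenient cross-check for small cases.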
The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation would loop over all N elements, this method requires only a logarithmic number (2 * ceil(log_2 N)) of sequential steps, and can thus yield substantial performance speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is 2 * (N / 2 + N / 4 + ... + 1) = 2N - 2, i.e., approximately twice as many as a naive approach.
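To make the decomposition concrete, here is a rough sketch of one common odd-even formulation of a parallel prefix scan, written over Python lists. It is an illustration under stated assumptions, not the TFP implementation: associative_scan_sketch and counting_add are hypothetical names, and the list comprehensions stand in for what would be single vectorized operations on an accelerator. Each recursion level has two phases whose calls are mutually independent, which is where the roughly 2 * log_2 N sequential depth comes from.

```python
import operator


def associative_scan_sketch(fn, elems):
    """Illustrative odd-even prefix scan over a list (hypothetical, not TFP's code)."""
    n = len(elems)
    if n < 2:
        return list(elems)
    # Phase 1: combine adjacent pairs. These calls do not depend on one another.
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # Recurse on the half-length list; odd_prefixes[k] is the prefix ending at index 2k + 1.
    odd_prefixes = associative_scan_sketch(fn, pairs)
    # Phase 2: fill in the even positions. These calls are also independent.
    result = [elems[0]]
    for i in range(1, n):
        if i % 2 == 1:
            result.append(odd_prefixes[i // 2])
        else:
            result.append(fn(odd_prefixes[i // 2 - 1], elems[i]))
    return result


call_log = []


def counting_add(a, b):
    """Addition that records each invocation so the total work can be inspected."""
    call_log.append((a, b))
    return a + b


prefix_sums = associative_scan_sketch(counting_add, list(range(8)))
print(prefix_sums)    # [0, 1, 3, 6, 10, 15, 21, 28]
print(len(call_log))  # 11

letters = associative_scan_sketch(operator.add, list("abcde"))
print(letters)  # ['a', 'ab', 'abc', 'abcd', 'abcde']
```

For N = 8 this particular variant makes 11 calls, which is within the 2N - 2 = 14 figure quoted above; the exact count depends on how a given implementation handles the fill-in phase.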
[1] Blelloch, Guy E. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.
Examples
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
import operator
# Example 1: Partial sums of numbers.
tfp.math.scan_associative(operator.add, tf.range(0, 4))
# ==> [ 0, 1, 3, 6]
# Example 2: Partial products of random matrices.
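import jax  # Needed for the PRNG key; the JAX substrate requires an explicit seed when sampling.
dist = tfp.distributions.Normal(loc=0., scale=1.)
matrices = dist.sample(sample_shape=[100, 2, 2], seed=jax.random.PRNGKey(0))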
tfp.math.scan_associative(tf.matmul, matrices)
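Matrix multiplication is associative but not commutative, so the partial products must respect the order of the inputs. The sketch below illustrates this with small integer matrices and the standard library only, as a stand-in for the random-matrix example; matmul_2x2 and the sample matrices are hypothetical and chosen for easy hand-checking, and itertools.accumulate plays the role of a sequential scan.

```python
import itertools


def matmul_2x2(a, b):
    """Multiply two 2x2 matrices given as nested tuples (hypothetical helper)."""
    return (
        (a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]),
        (a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]),
    )


shear = ((1, 1), (0, 1))
swap = ((0, 1), (1, 0))
matrices = [shear, swap, shear]

# Partial products: M0, M0 @ M1, M0 @ M1 @ M2, in that order.
partial_products = list(itertools.accumulate(matrices, matmul_2x2))
for product in partial_products:
    print(product)
# ((1, 1), (0, 1))
# ((1, 1), (1, 0))
# ((1, 2), (1, 1))
```

Swapping the first two inputs changes the second partial product, which is why an associative scan may regroup applications of fn but must never reorder its arguments.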