tfp.substrates.jax.math.scan_associative

Perform a scan with an associative binary operation, in parallel.

The associative scan operation computes the cumulative sum, or all-prefix-sums, of a sequence of elements under an associative binary operation [1]. For example, with the ordinary addition operator fn = lambda a, b: a + b, it is equivalent to the ordinary cumulative sum tf.math.cumsum along axis 0. This method supports the general case of arbitrary associative binary operations on Tensors or structures of Tensors:

scan_associative(fn, elems) = tf.stack([
  elems[0],
  fn(elems[0], elems[1]),
  fn(elems[0], fn(elems[1], elems[2])),
  ...
  fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1])))),
], axis=0)
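The stacked formula nests fn to the right, while the result description further down folds to the left. For an associative fn the two agree. The sketch below checks this with a naive sequential reference. It uses NumPy as a stand-in so it runs without TFP, and `nested_prefix` and `sequential_scan` are hypothetical helper names, not part of the TFP API.

```python
import functools
import itertools

import numpy as np


def nested_prefix(fn, elems, k):
    """Output k in the right-nested form fn(e[0], fn(e[1], ... fn(e[k-1], e[k])))."""
    # Fold from the right: start from e[k] and wrap each earlier element around it.
    return functools.reduce(lambda acc, x: fn(x, acc), reversed(elems[:k]), elems[k])


def sequential_scan(fn, elems):
    """Naive reference: a left-to-right running fold (N - 1 calls, all sequential)."""
    return np.stack(list(itertools.accumulate(elems, fn)), axis=0)


print(sequential_scan(np.add, np.arange(4)))  # [0 1 3 6]

# Non-commutative but associative op: matrix products keep their left-to-right order.
rng = np.random.default_rng(0)
mats = rng.standard_normal((5, 2, 2))
scanned = sequential_scan(np.matmul, mats)
print(all(np.allclose(scanned[k], nested_prefix(np.matmul, mats, k)) for k in range(5)))  # True
```

The comparison uses np.allclose rather than exact equality because different parenthesizations round floating-point products differently.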

The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation loops over all N elements, this method requires only a logarithmic number (2 * ceil(log_2 N)) of sequential steps, and can therefore yield substantial speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is 2 * (N / 2 + N / 4 + ... + 1) = 2N - 2 for N a power of two, i.e., approximately twice the N - 1 invocations of a naive approach.
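The logarithmic depth comes from an odd/even recursion: combine adjacent pairs in one batched call, scan the half-length result recursively, then fill in the remaining positions with one more batched call. Below is a minimal NumPy sketch of that classic scheme, assuming fn acts elementwise over a leading batch axis. It is not necessarily TFP's exact implementation, and `parallel_scan` and `counting` are hypothetical names.

```python
import numpy as np


def parallel_scan(fn, elems):
    """Odd/even recursive prefix scan along axis 0 (illustrative sketch).

    `fn` receives batched arrays and must act elementwise over the leading axis,
    so each call below stands for many independent applications run in parallel.
    """
    n = elems.shape[0]
    if n < 2:
        return elems
    # Level down: combine pairs (e0, e1), (e2, e3), ... in one batched call.
    reduced = fn(elems[0:-1:2], elems[1::2])
    # Recurse on the half-length sequence; odd[i] is the prefix ending at elems[2i + 1].
    odd = parallel_scan(fn, reduced)
    # Level up: even position 2j (j >= 1) is prefix(2j - 1) combined with elems[2j].
    even = fn(odd[:-1] if n % 2 == 0 else odd, elems[2::2])
    # Interleave the pieces back into a length-n result.
    out = np.empty_like(elems)
    out[0] = elems[0]
    out[1::2] = odd
    out[2::2] = even
    return out


def counting(fn, stats):
    """Wrap fn to record batched calls (sequential steps) and element-level applications."""
    def wrapped(a, b):
        stats["steps"] += 1
        stats["applications"] += a.shape[0]
        return fn(a, b)
    return wrapped


stats = {"steps": 0, "applications": 0}
print(parallel_scan(counting(np.add, stats), np.arange(8)))  # [ 0  1  3  6 10 15 21 28]
print(stats)  # {'steps': 6, 'applications': 11}

# Order is preserved for non-commutative ops such as matrix multiplication.
rng = np.random.default_rng(0)
mats = rng.standard_normal((7, 2, 2))
prefix, expected = np.eye(2), []
for m in mats:
    prefix = prefix @ m
    expected.append(prefix)
print(np.allclose(parallel_scan(np.matmul, mats), np.stack(expected)))  # True
```

For N = 8 the counter reports 6 batched calls, matching 2 * ceil(log_2 N), and 11 element-level applications. That is slightly under the 2N - 2 = 14 figure because, at each level, this variant's fill-in step has one fewer pair to combine than its pairing step.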

[1] Blelloch, Guy E. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.

Args

fn: Python callable implementing an associative binary operation with signature r = fn(a, b). It must satisfy associativity: fn(a, fn(b, c)) == fn(fn(a, b), c). The inputs and result are (possibly nested structures of) Tensor(s) matching elems. Each Tensor has a batch dimension in place of elem_length, and fn is expected to map over this dimension. The result r has the same shape (and structure) as the two inputs a and b.
elems: A (possibly nested structure of) Tensor(s), each with dimension elem_length along axis. Note that elem_length determines the number of recursive steps required to perform the scan: if, in graph mode, it is not statically available, ops will be created to handle any elem_length up to the maximum dimension of a Tensor.
max_num_levels: Python int. The axis of the tensors in elems must have size less than 2**(max_num_levels + 1). The default is large enough for most needs; lowering it can reduce graph-building time when scan_associative is used with inputs of unknown shape. Default value: 48.
axis: Tensor int axis along which to perform the scan.
validate_args: Python bool. When True, runtime checks for invalid inputs are performed; this may carry a performance cost. Default value: False.
name: Python str name prefixed to ops created by this function.
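The contract for fn (associativity, nested structures, mapping over a leading batch dimension) is easiest to see in a common use case: the linear recurrence x_k = a_k * x_{k-1} + b_k, whose steps compose as affine maps. In this minimal NumPy sketch, `affine_compose` is a hypothetical name and a plain loop stands in for the scan so the code runs without TFP. Under the documented contract, passing the same function with the tuple (a, b) as elems to scan_associative should produce the same prefixes in parallel.

```python
import numpy as np


def affine_compose(left, right):
    """Associative fn on a tuple structure: compose batches of maps x -> a * x + b.

    Applies `left` first, then `right`. Every operation is elementwise, so the
    function maps over the leading batch dimension of a and b automatically.
    """
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=6)  # per-step decay factors
b = rng.standard_normal(6)         # per-step inputs

# Associativity check on batched inputs (batch dimension of size 6).
x, y, z = (a, b), (a[::-1], b[::-1]), (a**2, b + 1.0)
lhs = affine_compose(x, affine_compose(y, z))
rhs = affine_compose(affine_compose(x, y), z)
print(all(np.allclose(u, v) for u, v in zip(lhs, rhs)))  # True

# Sequential stand-in for the scan: prefix k composes maps 0..k (batch size 1 each).
prefix = (a[0:1], b[0:1])
scanned_b = [prefix[1]]
for k in range(1, len(a)):
    prefix = affine_compose(prefix, (a[k:k + 1], b[k:k + 1]))
    scanned_b.append(prefix[1])
scanned_b = np.concatenate(scanned_b)

# The b component of prefix k equals x_k for the recurrence started at x_{-1} = 0.
x_prev, direct = 0.0, []
for ak, bk in zip(a, b):
    x_prev = ak * x_prev + bk
    direct.append(x_prev)
print(np.allclose(scanned_b, direct))  # True
```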

Returns

result: A (possibly nested structure of) Tensor(s) of the same shape and structure as elems, in which the kth element (counting from zero) is the result of recursively applying fn to combine the first k + 1 elements of elems. For example, given elems = [a, b, c, ...], the result would be [a, fn(a, b), fn(fn(a, b), c), ...].

Examples

import operator

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

# Example 1: Partial sums of numbers.

tfp.math.scan_associative(operator.add, jnp.arange(4))
# ==> [0, 1, 3, 6]

# Example 2: Partial products of random matrices.

dist = tfp.distributions.Normal(loc=0., scale=1.)
# The JAX substrate needs an explicit PRNG key to draw samples.
matrices = dist.sample(sample_shape=[100, 2, 2], seed=jax.random.PRNGKey(0))
tfp.math.scan_associative(jnp.matmul, matrices)
# ==> shape [100, 2, 2]; entry k is matrices[0] @ matrices[1] @ ... @ matrices[k]