Perform a scan with an associative binary operation, in parallel.
```python
tfp.substrates.jax.math.scan_associative(
    fn, elems, max_num_levels=48, validate_args=False, name=None
)
```
The associative scan operation computes the cumulative sum, or all-prefix sum, of a set of elements under an associative binary operation [1]. For example, using the ordinary addition operator `fn = lambda a, b: a + b`, this is equivalent to the ordinary cumulative sum `tf.math.cumsum` along axis 0. This method supports the general case of arbitrary associative binary operations operating on `Tensor`s or structures of `Tensor`s:
```
scan_associative(fn, elems) = tf.stack([
    elems[0],
    fn(elems[0], elems[1]),
    fn(elems[0], fn(elems[1], elems[2])),
    ...
    fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1])))),
], axis=0)
```
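To illustrate the semantics only (not how the library computes them), the same result can be obtained sequentially with a left fold. Under associativity, the left-nested `fn(fn(a, b), c)` and the right-nested `fn(a, fn(b, c))` above agree. This sketch uses plain Python lists in place of `Tensor`s, and the helper name `scan_reference` is hypothetical.

```python
import itertools
import operator


def scan_reference(fn, elems):
    """Sequential prefix scan: result[k] combines elems[0..k] with fn.

    Hypothetical helper for illustration only. Plain Python lists stand in
    for Tensors, and no batching or parallelism is involved.
    """
    # itertools.accumulate applies fn left to right:
    # [e0, fn(e0, e1), fn(fn(e0, e1), e2), ...]
    return list(itertools.accumulate(elems, fn))


# With addition this is an ordinary cumulative sum.
print(scan_reference(operator.add, [0, 1, 2, 3]))  # [0, 1, 3, 6]

# String concatenation is associative but not commutative, so it also
# shows the order in which elements are combined.
print(scan_reference(operator.add, ["a", "b", "c"]))  # ['a', 'ab', 'abc']
```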
The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation would loop over all `N` elements, this method requires only a logarithmic number (`2 * ceil(log_2 N)`) of sequential steps, and can thus yield substantial speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is at most `2 * (N / 2 + N / 4 + ... + 1) = 2N - 2` (taking `N` to be a power of two), i.e., roughly twice the `N - 1` invocations a naive approach needs.
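The decomposition can be sketched in plain Python as a recursive odd/even scheme in the spirit of [1]. It combines adjacent pairs, scans the half-length sequence of pair results recursively, and then fills in the remaining prefixes with one more combine per element. This is a minimal sketch under stated assumptions, not TFP's implementation. Lists stand in for batched `Tensor`s, each list comprehension marks a step whose `fn` calls are mutually independent and could run in parallel, and the names `parallel_scan` and `CountingOp` are hypothetical. The sketch counts `fn` calls so they can be compared with the `2N - 2` bound.

```python
import itertools
import operator


def parallel_scan(fn, elems):
    """Recursive odd/even prefix scan (illustrative sketch, not TFP's code).

    Each list comprehension below is one step whose fn calls do not depend
    on each other and could therefore run in parallel.
    """
    n = len(elems)
    if n < 2:
        return list(elems)
    # Step 1: combine adjacent pairs (e0, e1), (e2, e3), ...
    pairs = [fn(elems[2 * i], elems[2 * i + 1]) for i in range(n // 2)]
    # Recurse: odd[i] is the prefix ending at index 2 * i + 1.
    odd = parallel_scan(fn, pairs)
    # Step 2: prefixes ending at even indices 2, 4, ... extend an odd prefix.
    num_even = (n - 1) // 2
    even = [elems[0]] + [fn(odd[i], elems[2 * i + 2]) for i in range(num_even)]
    # Interleave: even positions come from `even`, odd positions from `odd`.
    return [even[k // 2] if k % 2 == 0 else odd[k // 2] for k in range(n)]


class CountingOp:
    """Wraps a binary op and counts how many times it is called."""

    def __init__(self, op):
        self.op = op
        self.calls = 0

    def __call__(self, a, b):
        self.calls += 1
        return self.op(a, b)


for n in (1, 2, 5, 8, 100):
    add = CountingOp(operator.add)
    result = parallel_scan(add, list(range(n)))
    # Same answer as a sequential cumulative sum.
    assert result == list(itertools.accumulate(range(n)))
    # Print N, the observed number of fn calls, and the 2N - 2 bound.
    print(n, add.calls, 2 * n - 2)
```

Each recursion level contributes two sequential steps (the pair combine and the fill-in), and there are about `log_2 N` levels, which is where the logarithmic step count comes from.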
[1] Blelloch, Guy E. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.
Args | |
---|---|
`fn` | Python callable implementing an associative binary operation with signature `r = fn(a, b)`. This must satisfy associativity: `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are (possibly nested structures of) `Tensor`(s), matching `elems`. Each `Tensor` has a leading batch dimension in place of `elem_length`; `fn` is expected to map over this dimension. The result `r` has the same shape (and structure) as the two inputs `a` and `b`.
`elems` | A (possibly nested structure of) `Tensor`(s), each with leading dimension `elem_length`. Note that `elem_length` determines the number of recursive steps required to perform the scan: if, in graph mode, this is not statically available, then ops will be created to handle any `elem_length` up to the maximum dimension of a `Tensor`.
`max_num_levels` | Python `int`. The size of the first dimension of the tensors in `elems` must be less than `2**(max_num_levels + 1)`. The default value is sufficiently large for most needs. Lowering this value can reduce graph-building time when `scan_associative` is used with inputs of unknown shape. Default value: `48`.
`validate_args` | Python `bool`. When `True`, runtime checks for invalid inputs are performed. This may carry a performance cost. Default value: `False`.
`name` | Python `str` name prefixed to ops created by this function.
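To illustrate an `fn` that operates on a structure of values, consider the first-order linear recurrence `x[t] = a[t] * x[t-1] + b[t]` with `x[-1] = 0`. Each step can be represented as a pair `(a, b)`, the affine map `x -> a * x + b`. Composing two steps yields another such pair, and function composition is associative, so the recurrence can be evaluated by a scan over pairs. This is a hypothetical example in plain Python: `compose_affine` is an illustrative name, tuples stand in for a structure of two `Tensor`s, and the sequential `itertools.accumulate` stands in for `scan_associative`. Because the combine uses only elementwise arithmetic, the same `fn` would also map over a leading batch dimension if `a` and `b` were `Tensor`s.

```python
import itertools


def compose_affine(first, second):
    """Combine two affine steps: apply `first`, then `second`.

    Each step is a pair (a, b) representing x -> a * x + b, so
    second(first(x)) = a2 * (a1 * x + b1) + b2
                     = (a1 * a2) * x + (a2 * b1 + b2).
    """
    a1, b1 = first
    a2, b2 = second
    # Only elementwise * and + are used, so with Tensor components this
    # would map over a leading batch dimension unchanged.
    return (a1 * a2, a2 * b1 + b2)


a = [2, 3, 1, 4]
b = [1, 0, 5, 2]

# Prefix k is the composition of steps 0..k. Applying it to x[-1] = 0
# gives x[k], which is just the accumulated b component.
prefixes = list(itertools.accumulate(zip(a, b), compose_affine))
xs = [b_acc for _, b_acc in prefixes]

# Direct sequential evaluation of the recurrence, for comparison.
x, expected = 0, []
for a_t, b_t in zip(a, b):
    x = a_t * x + b_t
    expected.append(x)

print(xs)  # [1, 3, 8, 34]
```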
Returns | |
---|---|
`result` | A (possibly nested structure of) `Tensor`(s) of the same shape and structure as `elems`, in which the `k`th element is the result of recursively applying `fn` to combine the first `k` elements of `elems`. For example, given `elems = [a, b, c, ...]`, the result would be `[a, fn(a, b), fn(fn(a, b), c), ...]`.
Examples
```python
import operator

import jax
from tensorflow_probability.python.internal.backend import numpy as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax

# Example 1: Partial sums of numbers.
tfp.math.scan_associative(operator.add, tf.range(0, 4))
# ==> [ 0, 1, 3, 6]

# Example 2: Partial products of random matrices.
dist = tfp.distributions.Normal(loc=0., scale=1.)
# Sampling in the JAX substrate requires an explicit PRNG key.
matrices = dist.sample(sample_shape=[100, 2, 2], seed=jax.random.PRNGKey(0))
tfp.math.scan_associative(tf.matmul, matrices)
```