Perform a scan with an associative binary operation, in parallel.

```
tfp.math.scan_associative(
    fn, elems, max_num_levels=48, validate_args=False, name=None
)
```

The associative scan operation computes the cumulative sum, or
all-prefix sum, of a set of
elements under an associative binary operation [1]. For example, using the
ordinary addition operator `fn = lambda a, b: a + b`, this is equivalent to
the ordinary cumulative sum `tf.math.cumsum` along axis 0. This method
supports the general case of arbitrary associative binary operations operating
on `Tensor`s or structures of `Tensor`s:

```
scan_associative(fn, elems) = tf.stack([
    elems[0],
    fn(elems[0], elems[1]),
    fn(elems[0], fn(elems[1], elems[2])),
    ...
    fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1])))),
], axis=0)
```
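As a point of reference, the definition above can be sketched sequentially in plain Python. This is only an illustration, not the library's implementation: `sequential_scan` is a hypothetical helper, and it operates on Python lists rather than `Tensor`s. It nests the calls from the left (`fn(fn(a, b), c)`), which gives the same result as the right-nested form above whenever `fn` is associative.

```python
from itertools import accumulate
import operator


def sequential_scan(fn, elems):
    """Reference scan: the k-th output combines the first k elements with fn."""
    # accumulate applies fn(running_result, next_element) from left to right.
    return list(accumulate(elems, fn))


print(sequential_scan(operator.add, [0, 1, 2, 3]))   # [0, 1, 3, 6]
# String concatenation is associative but not commutative; order is preserved.
print(sequential_scan(operator.add, ["a", "b", "c"]))  # ['a', 'ab', 'abc']
```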

The associative structure allows the computation to be decomposed
and executed by parallel reduction. Where a naive sequential
implementation would loop over all `N` elements, this method requires
only a logarithmic number (`2 * ceil(log_2 N)`) of sequential steps, and
can thus yield substantial performance speedups from hardware-accelerated
vectorization. The total number of invocations of the binary operation
(including those performed in parallel) is
`2 * (N / 2 + N / 4 + ... + 1) = 2N - 2`, i.e., approximately twice as many
as a naive approach.
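The decomposition can be illustrated with a small recursive sketch in plain Python. It is an assumption-laden illustration of the general pair-and-recurse technique, not the library's actual code: `parallel_scan` and `CountingOp` are hypothetical names, and the list comprehension and loop inside each level stand in for what would be a single vectorized call to `fn` on accelerator hardware.

```python
import operator


def parallel_scan(fn, elems):
    """Scan by combining adjacent pairs, recursing, then filling the gaps."""
    n = len(elems)
    if n < 2:
        return list(elems)
    # Step 1 (one parallel step): combine adjacent pairs (e0,e1), (e2,e3), ...
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # Recurse on the half-length sequence; this yields the odd-index results.
    odd = parallel_scan(fn, pairs)
    result = [None] * n
    result[0] = elems[0]
    for j, value in enumerate(odd):
        result[2 * j + 1] = value
    # Step 2 (one parallel step): each remaining even index needs one combine.
    for i in range(2, n, 2):
        result[i] = fn(result[i - 1], elems[i])
    return result


class CountingOp:
    """Wraps a binary operation and counts how often it is invoked."""

    def __init__(self, fn):
        self.fn = fn
        self.calls = 0

    def __call__(self, a, b):
        self.calls += 1
        return self.fn(a, b)


add = CountingOp(operator.add)
print(parallel_scan(add, list(range(8))))  # [0, 1, 3, 6, 10, 15, 21, 28]
print(add.calls)                           # 11
```

For `N = 8` the recursion has three levels of two steps each, matching `2 * ceil(log_2 N) = 6`. This sketch issues 11 calls, slightly under the `2N - 2 = 14` figure, because it skips the combine that the first element at each level does not need; a vectorized implementation may not skip it.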

[1] Blelloch, Guy E. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.

#### Args:

* `fn`: Python callable implementing an associative binary operation with signature `r = fn(a, b)`. This must satisfy associativity: `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are (possibly nested structures of) `Tensor`(s), matching `elems`. Each `Tensor` has a leading batch dimension in place of `elem_length`; the `fn` is expected to map over this dimension. The result `r` has the same shape (and structure) as the two inputs `a` and `b`.
* `elems`: A (possibly nested structure of) `Tensor`(s), each with leading dimension `elem_length`. Note that `elem_length` determines the number of recursive steps required to perform the scan: if, in graph mode, this is not statically available, then ops will be created to handle any `elem_length` up to the maximum dimension of a `Tensor`.
* `max_num_levels`: Python `int`. The size of the first dimension of the tensors in `elems` must be less than `2**(max_num_levels + 1)`. The default value is sufficiently large for most needs. Lowering this value can reduce graph-building time when `scan_associative` is used with inputs of unknown shape. Default value: `48`.
* `validate_args`: Python `bool`. When `True`, runtime checks for invalid inputs are performed. This may carry a performance cost. Default value: `False`.
* `name`: Python `str` name prefixed to ops created by this function.
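Two of the constraints above, associativity of `fn` and the size bound implied by `max_num_levels`, can be sanity-checked before calling the function. The sketch below uses plain Python values rather than `Tensor`s, and both helpers (`looks_associative`, `min_num_levels`) are hypothetical names introduced here for illustration; they are not part of the library.

```python
import itertools
import operator


def looks_associative(fn, samples):
    """Spot-check fn(a, fn(b, c)) == fn(fn(a, b), c) on all sample triples."""
    return all(
        fn(a, fn(b, c)) == fn(fn(a, b), c)
        for a, b, c in itertools.product(samples, repeat=3)
    )


def min_num_levels(elem_length):
    """Smallest m satisfying elem_length < 2**(m + 1), for elem_length >= 1."""
    if elem_length < 1:
        raise ValueError("elem_length must be at least 1")
    return elem_length.bit_length() - 1


print(looks_associative(operator.add, [1, 2, 3]))  # True
print(looks_associative(operator.sub, [1, 2, 3]))  # False
print(min_num_levels(100))                         # 6, since 100 < 2**7
```

A passing spot check is necessary but not sufficient: it only covers the sampled values, and floating-point operations are often associative only up to rounding error.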

#### Returns:

* `result`: A (possibly nested structure of) `Tensor`(s) of the same shape and structure as `elems`, in which the `k`th element is the result of recursively applying `fn` to combine the first `k` elements of `elems`. For example, given `elems = [a, b, c, ...]`, the result would be `[a, fn(a, b), fn(fn(a, b), c), ...]`.

#### Examples

```
import operator

import tensorflow as tf
import tensorflow_probability as tfp

# Example 1: Partial sums of numbers.
tfp.math.scan_associative(operator.add, tf.range(0, 4))
# ==> [ 0, 1, 3, 6]

# Example 2: Partial products of random matrices.
dist = tfp.distributions.Normal(loc=0., scale=1.)
matrices = dist.sample(sample_shape=[100, 2, 2])
tfp.math.scan_associative(tf.matmul, matrices)
```