Builds a trainable blockwise tf.linalg.LinearOperator.


This function returns a trainable blockwise LinearOperator. If operators is a flat list, it is interpreted as the blocks along the diagonal of the structure, and an instance of tf.linalg.LinearOperatorBlockDiag is returned. If operators is a nested list (a list of rows), a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned, with the block in row i, column j (i >= j) given by operators[i][j]. The operators list may contain LinearOperator instances, LinearOperator subclasses, or callables defining custom constructors (see example below). The dimensions of the blocks are given by block_dims; this argument may be omitted if operators contains only LinearOperator instances.
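To make the two layouts concrete, here is a minimal sketch that uses plain nested Python lists as hypothetical stand-ins for LinearOperators and places them into one dense matrix. The helper assemble_dense is illustrative only and is not part of TFP; it simply shows where operators[k] (flat case) and operators[i][j] (nested case) end up.

```python
# Sketch: plain nested lists stand in for LinearOperators so the block
# layout can be checked without TensorFlow. Not TFP's implementation.


def assemble_dense(blocks, block_dims, lower_triangular=False):
    """Place blocks into one dense square matrix of size sum(block_dims).

    lower_triangular=False: `blocks` is flat and block k sits on the diagonal
    (the LinearOperatorBlockDiag layout).
    lower_triangular=True: `blocks[i][j]` (j <= i) sits in block row i,
    block column j (the LinearOperatorBlockLowerTriangular layout).
    Positions not covered by any block are zero.
    """
    n = sum(block_dims)
    dense = [[0] * n for _ in range(n)]
    # offsets[k] is the first row/column index of block k.
    offsets = [sum(block_dims[:k]) for k in range(len(block_dims))]

    if lower_triangular:
        placed = [(i, j, blk) for i, row in enumerate(blocks)
                  for j, blk in enumerate(row)]
    else:
        placed = [(k, k, blk) for k, blk in enumerate(blocks)]

    for i, j, blk in placed:
        if j > i:
            raise ValueError(f"block ({i}, {j}) lies above the diagonal")
        # Block (i, j) must have block_dims[i] rows and block_dims[j] columns.
        if len(blk) != block_dims[i] or any(len(r) != block_dims[j] for r in blk):
            raise ValueError(
                f"block ({i}, {j}) must be {block_dims[i]}x{block_dims[j]}")
        for r, row_vals in enumerate(blk):
            for c, value in enumerate(row_vals):
                dense[offsets[i] + r][offsets[j] + c] = value
    return dense


# Block-diagonal layout with block_dims=[2, 1].
print(assemble_dense([[[1, 2], [3, 4]], [[5]]], [2, 1]))
# [[1, 2, 0], [3, 4, 0], [0, 0, 5]]

# Lower-triangular layout: row 0 holds one block, row 1 holds two.
print(assemble_dense([[[[1, 0], [2, 3]]], [[[4, 5]], [[6]]]], [2, 1],
                     lower_triangular=True))
# [[1, 0, 0], [2, 3, 0], [4, 5, 6]]
```

Everything above the block diagonal stays zero in the nested case, which is why each row i only needs to supply blocks for columns 0 through i.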

Args:
operators: A list or tuple containing LinearOperator subclasses, LinearOperator instances, and/or callables returning (init_fn, apply_fn) pairs. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance is returned. Otherwise, the list must be a list of rows nested one level deep, with the first row of length 1, the second row of length 2, and so on; the rows are interpreted as the rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned. Callables contained in the lists must take two arguments, shape (the shape of the parameter instantiating the LinearOperator) and dtype (the tf.DType of the LinearOperator), and return a further pair of callables representing a stateless trainable operator (see example below).
block_dims: List or tuple of integers giving the sizes of the blocks along one dimension of the (square) blockwise LinearOperator. If operators contains only LinearOperator instances, block_dims may be None, in which case the dimensions are inferred.
batch_shape: Batch shape of the LinearOperator.
dtype: tf.DType of the LinearOperator.
name: str, name for tf.name_scope.
seed: PRNG seed; see tfp.random.sanitize_seed for details.

Returns:
instance: A tf.linalg.LinearOperatorBlockDiag or tf.linalg.LinearOperatorBlockLowerTriangular instance parameterized by trainable tf.Variables.
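The nesting rule for operators can be checked mechanically. The sketch below is a hypothetical helper (not part of TFP) that applies the documented rule, with plain strings standing in for operator instances, subclasses, or constructor callables: a flat list selects the block-diagonal layout, and a list whose i-th row has length i + 1 selects the lower-triangular layout.

```python
# Sketch of the documented nesting rule for `operators`. Strings are
# hypothetical stand-ins for operators, subclasses, or constructor callables.


def classify_operators(operators):
    """Return (layout_name, num_blocks) for a flat or lower-triangular list."""
    if not any(isinstance(op, (list, tuple)) for op in operators):
        # Flat list: one block per entry, placed along the diagonal.
        return "LinearOperatorBlockDiag", len(operators)
    for i, row in enumerate(operators):
        # Row i of a lower-triangular layout holds blocks (i, 0), ..., (i, i).
        if not isinstance(row, (list, tuple)) or len(row) != i + 1:
            raise ValueError(f"row {i} must be a list or tuple of length {i + 1}")
    return "LinearOperatorBlockLowerTriangular", len(operators)


print(classify_operators(["tril", "scaled_identity"]))
# ('LinearOperatorBlockDiag', 2)
print(classify_operators([("custom_diag",), ("full_matrix", "tril")]))
# ('LinearOperatorBlockLowerTriangular', 2)
```

In both cases the number of blocks along the diagonal equals len(operators), which is also the length block_dims must have.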


To build a 5x5 trainable LinearOperatorBlockDiag given LinearOperator subclasses and block_dims:

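op = build_trainable_linear_operator_block(
  operators=(tf.linalg.LinearOperatorLowerTriangular, tf.linalg.LinearOperatorScaledIdentity),
  block_dims=[3, 2],
  dtype=tf.float32)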

If operators contains only LinearOperator instances, the block_dims argument is not necessary:

# Builds a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`.
op = build_trainable_linear_operator_block(
  operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
             tf.linalg.LinearOperatorFullMatrix([[4.]]),
             tf.linalg.LinearOperatorIdentity(2)))
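Inferring block_dims from instances amounts to reading the size of each block from its static shape. The sketch below uses a hypothetical StandInOperator class that models only a shape (batch dimensions followed by n, n); it is not TFP's inference code. The first block matches the diagonal operator above (batch shape (4,), size 3); the sizes 1 and 2 for the remaining blocks are assumptions chosen so that the total matches the 6x6 comment.

```python
from dataclasses import dataclass


@dataclass
class StandInOperator:
    """Hypothetical stand-in for a square LinearOperator instance.

    Only the static shape is modeled: batch dimensions followed by (n, n).
    """
    shape: tuple


def infer_block_dims(operators):
    """Infer block_dims from a flat list of operator instances (sketch)."""
    dims = []
    for op in operators:
        if not isinstance(op, StandInOperator):
            # Subclasses and constructor callables carry no size information.
            raise ValueError(
                "block_dims is required unless every entry is an instance")
        # The last dimension of a square operator is its block size.
        dims.append(op.shape[-1])
    return dims


dims = infer_block_dims([StandInOperator((4, 3, 3)),
                         StandInOperator((1, 1)),
                         StandInOperator((2, 2))])
print(dims, sum(dims))  # [3, 1, 2] 6
```

This is also why block_dims cannot be omitted when operators contains subclasses or callables: nothing in those entries says how large each block should be.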

A custom operator constructor may be specified as a callable taking arguments shape and dtype, and returning a pair of callables (init_fn, apply_fn) describing a parameterized operator, with the following signatures:

raw_parameters = init_fn(seed)
linear_operator = apply_fn(raw_parameters)
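The (init_fn, apply_fn) contract can be illustrated without TensorFlow. In this sketch, DiagStandIn and diag_constructor are hypothetical names, Python's random.Random plays the role of a seeded sampler, and a plain list plays the role of the raw parameters. The point is the separation: init_fn depends only on the seed, and apply_fn builds an operator from raw parameters without holding any state of its own.

```python
import random


class DiagStandIn:
    """Hypothetical stand-in for tf.linalg.LinearOperatorDiag."""

    def __init__(self, diag):
        self.diag = list(diag)

    def matvec(self, x):
        # A diagonal operator scales each coordinate independently.
        return [d * xi for d, xi in zip(self.diag, x)]


def diag_constructor(shape, dtype):
    """Custom constructor: takes (shape, dtype), returns (init_fn, apply_fn)."""

    def init_fn(seed):
        # Deterministic given the seed: draws raw parameters in [0, 2].
        rng = random.Random(seed)
        return [dtype(rng.uniform(0.0, 2.0)) for _ in range(shape[-1])]

    def apply_fn(raw_parameters):
        # Stateless: builds a fresh operator from the raw parameters.
        return DiagStandIn(raw_parameters)

    return init_fn, apply_fn


init_fn, apply_fn = diag_constructor(shape=(3,), dtype=float)
raw_parameters = init_fn(seed=42)
linear_operator = apply_fn(raw_parameters)
print(linear_operator.matvec([1.0, 1.0, 1.0]) == raw_parameters)  # True
```

Keeping the two steps separate means the raw parameters can be stored (for example, as trainable variables) independently of the operator object that is rebuilt from them.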

For example, to define a custom initialization for a diagonal operator:

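import functools
import tensorflow as tf

def diag_operator_with_uniform_initialization(shape, dtype):
  # init_fn(seed) draws the raw diagonal uniformly from [0, 2) with a
  # stateless sampler, so a given seed always yields the same values.
  init_fn = functools.partial(
      tf.random.stateless_uniform, shape, maxval=2., dtype=dtype)
  # apply_fn turns the raw parameters into a diagonal LinearOperator.
  apply_fn = lambda scale_diag: tf.linalg.LinearOperatorDiag(
      scale_diag, is_non_singular=True)
  return init_fn, apply_fn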

# Build an 8x8 `LinearOperatorBlockLowerTriangular`, with our custom diagonal
# operator in the upper left block, and `LinearOperator` subclasses in the
# lower two blocks.
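op = build_trainable_linear_operator_block(
  operators=((diag_operator_with_uniform_initialization,),
             (tf.linalg.LinearOperatorFullMatrix,
              tf.linalg.LinearOperatorLowerTriangular)),
  block_dims=[4, 4],
  dtype=tf.float64)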
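With block_dims=[4, 4], the three blocks of the 8x8 lower-triangular example occupy fixed index ranges of the full matrix. The hypothetical helper below (not part of TFP) only does the offset arithmetic. Per the example's comment, the custom diagonal operator fills block (0, 0) and the two LinearOperator subclasses fill the lower blocks (1, 0) and (1, 1).

```python
def block_slices(block_dims):
    """Map each lower-triangular block (i, j), j <= i, to its index ranges.

    Returns {(i, j): ((row_start, row_stop), (col_start, col_stop))} using
    half-open ranges into the full square matrix.
    """
    offsets = [0]
    for d in block_dims:
        offsets.append(offsets[-1] + d)  # running start index of each block
    return {
        (i, j): ((offsets[i], offsets[i + 1]), (offsets[j], offsets[j + 1]))
        for i in range(len(block_dims))
        for j in range(i + 1)
    }


for (i, j), (rows, cols) in block_slices([4, 4]).items():
    print(f"block ({i}, {j}): rows {rows}, cols {cols}")
# block (0, 0): rows (0, 4), cols (0, 4)
# block (1, 0): rows (4, 8), cols (0, 4)
# block (1, 1): rows (4, 8), cols (4, 8)
```

The upper-right 4x4 region (rows 0 to 4, columns 4 to 8) has no entry, so it is implicitly zero.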