tfp.experimental.substrates.jax.bijectors.AutoregressiveNetwork


Masked Autoencoder for Distribution Estimation [Germain et al. (2015)][1].

tfp.experimental.substrates.jax.bijectors.AutoregressiveNetwork(
    params, event_shape=None, hidden_units=None, input_order='left-to-right',
    hidden_degrees='equal', activation=None, use_bias=True,
    kernel_initializer='glorot_uniform', bias_initializer='zeros',
    kernel_regularizer=None, bias_regularizer=None, kernel_constraint=None,
    bias_constraint=None, validate_args=False, **kwargs
)

An AutoregressiveNetwork takes as input a Tensor of shape [..., event_size] and returns a Tensor of shape [..., event_size, params].

The output satisfies the autoregressive property. That is, the layer is configured with some permutation ord of {0, ..., event_size-1} (i.e., an ordering of the input dimensions), and the output output[batch_idx, i, ...] for input dimension i depends only on inputs x[batch_idx, j] where ord(j) < ord(i).

The autoregressive property allows us to use output[batch_idx, i] to parameterize conditional distributions

  p(x[batch_idx, i] | x[batch_idx, j] for ord(j) < ord(i)),

which in turn give a tractable joint distribution over the input x[batch_idx]:

  p(x[batch_idx]) = prod_i p(x[batch_idx, ord(i)] | x[batch_idx, ord(0:i)])

For example, when params is 2, the output of the layer can parameterize the location and log-scale of an autoregressive Gaussian distribution.
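Concretely, with params=2 the two parameters can be recovered by splitting the output along its final axis. A minimal sketch (the import aliases, layer sizes, and toy shapes below are assumptions for illustration, not part of the API):

# Sketch: split the network output into per-dimension loc and log-scale.
# The aliases and shapes below are assumed; values are arbitrary.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

made = tfb.AutoregressiveNetwork(params=2, event_shape=[3],
                                 hidden_units=[16, 16])
x = np.random.randn(5, 3).astype(np.float32)      # shape [batch, event_size]
out = made(x)                                      # shape [5, 3, 2]
loc, log_scale = tf.unstack(out, num=2, axis=-1)   # each has shape [5, 3]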

Example

# Imports assumed by the examples on this page (standard TFP aliases).
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
tfk = tf.keras
tfkl = tf.keras.layers

# Generate data -- as in Figure 1 in [Papamakarios et al. (2017)][2].
n = 2000
x2 = np.random.randn(n).astype(dtype=np.float32) * 2.
x1 = np.random.randn(n).astype(dtype=np.float32) + (x2 * x2 / 4.)
data = np.stack([x1, x2], axis=-1)

# Density estimation with MADE.
made = tfb.AutoregressiveNetwork(params=2, hidden_units=[10, 10])

distribution = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.MaskedAutoregressiveFlow(made),
    event_shape=[2])

# Construct and fit model.
x_ = tfkl.Input(shape=(2,), dtype=tf.float32)
log_prob_ = distribution.log_prob(x_)
model = tfk.Model(x_, log_prob_)

model.compile(optimizer=tf.optimizers.Adam(),
              loss=lambda _, log_prob: -log_prob)

batch_size = 25
model.fit(x=data,
          y=np.zeros((n, 0), dtype=np.float32),
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=1,  # Usually `n // batch_size`.
          shuffle=True,
          verbose=True)

# Use the fitted distribution.
distribution.sample((3, 1))
distribution.log_prob(np.ones((3, 2), dtype=np.float32))

Examples: Handling Rank-2+ Tensors

AutoregressiveNetwork can be used as a building block to achieve different autoregressive structures over rank-2+ tensors. For example, suppose we want to build an autoregressive distribution over images with dimensions [width, height, channels], where channels = 3:

  1. We can parameterize a 'fully autoregressive' distribution, with cross-channel and within-pixel autoregressivity:
    r0    g0   b0     r0    g0   b0       r0   g0    b0
    ^   ^      ^         ^   ^   ^         ^      ^   ^
    |  /  ____/           \  |  /           \____  \  |
    | /__/                 \ | /                 \__\ |
    r1    g1   b1     r1 <- g1   b1       r1   g1 <- b1
                                         ^          |
                                          \_________/
as:
# Generate random images for training data.
images = np.random.uniform(size=(100, 8, 8, 3)).astype(np.float32)
n, width, height, channels = images.shape

# Reshape images to achieve desired autoregressivity.
event_shape = [height * width * channels]
reshaped_images = tf.reshape(images, [n] + event_shape)

# Density estimation with MADE.
made = tfb.AutoregressiveNetwork(params=2, event_shape=event_shape,
                                 hidden_units=[20, 20], activation='relu')
distribution = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.MaskedAutoregressiveFlow(made),
    event_shape=event_shape)

# Construct and fit model.
x_ = tfkl.Input(shape=event_shape, dtype=tf.float32)
log_prob_ = distribution.log_prob(x_)
model = tfk.Model(x_, log_prob_)

model.compile(optimizer=tf.optimizers.Adam(),
              loss=lambda _, log_prob: -log_prob)

batch_size = 10
model.fit(x=reshaped_images,
          y=np.zeros((n, 0), dtype=np.float32),
          batch_size=batch_size,
          epochs=10,
          steps_per_epoch=n // batch_size,
          shuffle=True,
          verbose=True)

# Use the fitted distribution.
distribution.sample((3, 1))
distribution.log_prob(np.ones((5, height * width * channels), dtype=np.float32))
  2. We can parameterize a distribution with neither cross-channel nor within-pixel autoregressivity:
    r0    g0   b0
    ^     ^    ^
    |     |    |
    |     |    |
    r1    g1   b1
as:
# Generate fake images.
images = np.random.choice([0, 1], size=(100, 8, 8, 3)).astype(np.float32)
n, width, height, channels = images.shape

# Reshape images to achieve desired autoregressivity.
reshaped_images = np.transpose(
    np.reshape(images, [n, width * height, channels]),
    axes=[0, 2, 1])

made = tfb.AutoregressiveNetwork(params=1, event_shape=[width * height],
                                 hidden_units=[20, 20], activation='relu')

# Density estimation with MADE.
#
# NOTE: Parameterize an autoregressive distribution over an event_shape of
# [channels, width * height], with univariate Bernoulli conditional
# distributions.
distribution = tfd.Autoregressive(
    lambda x: tfd.Independent(
        tfd.Bernoulli(logits=tf.unstack(made(x), axis=-1)[0],
                      dtype=tf.float32),
        reinterpreted_batch_ndims=2),
    sample0=tf.zeros([channels, width * height], dtype=tf.float32))

# Construct and fit model.
x_ = tfkl.Input(shape=(channels, width * height), dtype=tf.float32)
log_prob_ = distribution.log_prob(x_)
model = tfk.Model(x_, log_prob_)

model.compile(optimizer=tf.optimizers.Adam(),
              loss=lambda _, log_prob: -log_prob)

batch_size = 10
model.fit(x=reshaped_images,
          y=np.zeros((n, 0), dtype=np.float32),
          batch_size=batch_size,
          epochs=10,
          steps_per_epoch=n // batch_size,
          shuffle=True,
          verbose=True)

distribution.sample(7)
distribution.log_prob(np.ones((4, channels, width * height), dtype=np.float32))
Note that one set of weights is shared by the mapping from image to
distribution parameters for every channel -- i.e., the mapping
`made(reshaped_images[..., channel, :])` uses the same weights whether
`channel` is 0, 1, or 2.

To use separate weights for each channel, we could construct an
`AutoregressiveNetwork` and `TransformedDistribution` for each channel,
and combine them with a `tfd.Blockwise` distribution.
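A minimal sketch of that per-channel construction (assuming the standard aliases from the examples above; the layer sizes, the Normal base distribution, and the variable names are illustrative only):

# Sketch: one AutoregressiveNetwork + TransformedDistribution per channel,
# combined with tfd.Blockwise. Assumes the standard tfp aliases; sizes are
# illustrative and mirror the image example above.
import numpy as np
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
width, height, channels = 8, 8, 3   # matching the image example above

per_channel_dists = []
for _ in range(channels):
  made_c = tfb.AutoregressiveNetwork(params=2, event_shape=[width * height],
                                     hidden_units=[20, 20], activation='relu')
  per_channel_dists.append(
      tfd.TransformedDistribution(
          distribution=tfd.Normal(loc=0., scale=1.),
          bijector=tfb.MaskedAutoregressiveFlow(made_c),
          event_shape=[width * height]))

# Blockwise concatenates the per-channel events into a single event of size
# channels * width * height.
full_dist = tfd.Blockwise(per_channel_dists)
full_dist.log_prob(np.ones((4, channels * width * height), dtype=np.float32))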

References

[1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: Masked Autoencoder for Distribution Estimation. In International Conference on Machine Learning, 2015. https://arxiv.org/abs/1502.03509

[2]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked Autoregressive Flow for Density Estimation. In Neural Information Processing Systems, 2017. https://arxiv.org/abs/1705.07057

Arguments:

  • params: Python integer specifying the number of parameters to output per input.
  • event_shape: Python list-like of positive integers (or a single int), specifying the shape of the input to this layer, which is also the event_shape of the distribution parameterized by this layer. Currently only rank-1 shapes are supported. That is, event_shape must be a single integer. If not specified, the event shape is inferred when this layer is first called or built.
  • hidden_units: Python list-like of non-negative integers, specifying the number of units in each hidden layer.
  • input_order: Order of degrees to the input units: 'random', 'left-to-right', 'right-to-left', or an array of an explicit order. For example, 'left-to-right' builds an autoregressive model: p(x) = p(x1) p(x2 | x1) ... p(xD | x<D). Default: 'left-to-right'.
  • hidden_degrees: Method for assigning degrees to the hidden units: 'equal', 'random'. If 'equal', hidden units in each layer are allocated equally (up to a remainder term) to each degree. Default: 'equal'.
  • activation: An activation function. See tf.keras.layers.Dense. Default: None.
  • use_bias: Whether or not the dense layers constructed in this layer should have a bias term. See tf.keras.layers.Dense. Default: True.
  • kernel_initializer: Initializer for the Dense kernel weight matrices. Default: 'glorot_uniform'.
  • bias_initializer: Initializer for the Dense bias vectors. Default: 'zeros'.
  • kernel_regularizer: Regularizer function applied to the Dense kernel weight matrices. Default: None.
  • bias_regularizer: Regularizer function applied to the Dense bias weight vectors. Default: None.
  • kernel_constraint: Constraint function applied to the Dense kernel weight matrices. Default: None.
  • bias_constraint: Constraint function applied to the Dense bias weight vectors. Default: None.
  • validate_args: Python bool, default False. When True, layer parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
  • **kwargs: Additional keyword arguments passed to this layer (but not to the tf.keras.layers.Dense layers constructed by this layer).
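As a hedged illustration of a few of the arguments above (all values and shapes below are arbitrary), an explicit input_order can be passed and event_shape can be left to be inferred on first call:

# Sketch: explicit input ordering; event_shape is omitted, so it is inferred
# the first time the layer is called. All values here are arbitrary.
import numpy as np
import tensorflow_probability as tfp

tfb = tfp.bijectors

made = tfb.AutoregressiveNetwork(
    params=2,
    hidden_units=[32, 32],
    input_order='right-to-left',   # or 'random', or an explicit permutation
    hidden_degrees='equal',
    activation='relu')

x = np.random.randn(4, 5).astype(np.float32)
out = made(x)   # event_shape inferred as 5; `out` has shape [4, 5, 2]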

Attributes:

  • event_shape
  • params

Methods

build


build(
    input_shape
)

See tfkl.Layer.build.

call


call(
    x
)

See tfkl.Layer.call.

compute_output_shape


compute_output_shape(
    input_shape
)

See tfkl.Layer.compute_output_shape.
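A hedged sketch of these Keras-style methods in use (the shapes and sizes below are chosen arbitrarily):

# Sketch: the layer is built on first call; compute_output_shape reports the
# trailing `params` dimension. Shapes here are arbitrary.
import numpy as np
import tensorflow_probability as tfp

tfb = tfp.bijectors

made = tfb.AutoregressiveNetwork(params=3, event_shape=[4], hidden_units=[8])
out = made(np.zeros((2, 4), dtype=np.float32))  # `call`; `out` has shape [2, 4, 3]
made.compute_output_shape((None, 4))            # -> (None, 4, 3)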