Masking and padding with Keras


Setup

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np

import tensorflow as tf

from tensorflow.keras import layers

Padding sequence data

When processing sequence data, it is very common for individual samples to have different lengths. Consider the following example (text tokenized as words):

[
  ["The", "weather", "will", "be", "nice", "tomorrow"],
  ["How", "are", "you", "doing", "today"],
  ["Hello", "world", "!"]
]

After vocabulary lookup, the data might be vectorized as integers, e.g.:

[
  [83, 91, 1, 645, 1253, 927],
  [73, 8, 3215, 55, 927],
  [711, 632, 71]
]

The data is a nested list where individual samples have length 6, 5, and 3, respectively. Since the input data for a deep learning model must be a single tensor (here of shape (batch_size, 6) for the integer token IDs, or (batch_size, 6, vocab_size) if they were one-hot encoded), samples that are shorter than the longest item need to be padded with some placeholder value. (Alternatively, one might also truncate long samples before padding short samples.)

Keras provides an API to easily truncate and pad sequences to a common length: tf.keras.preprocessing.sequence.pad_sequences.

raw_inputs = [
  [83, 91, 1, 645, 1253, 927],
  [73, 8, 3215, 55, 927],
  [711, 632, 71]
]

# By default, this pads with 0s; use the "value" parameter
# to pad with a different value.
# You can choose "pre" padding (at the beginning, the default)
# or "post" padding (at the end).
# We recommend using "post" padding when working with RNN layers
# (in order to be able to use the
# cuDNN implementation of the layers).
padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,
                                                              padding='post')

print(padded_inputs)
[[  83   91    1  645 1253  927]
 [  73    8 3215   55  927    0]
 [ 711  632   71    0    0    0]]
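Because pad_sequences also truncates, passing maxlen caps every sample at a fixed length. The sketch below reuses the raw_inputs from above with an illustrative maxlen=4: truncating='pre' (the default) drops timesteps from the start of a long sample, while truncating='post' drops them from the end.

```python
import numpy as np
import tensorflow as tf

raw_inputs = [
  [83, 91, 1, 645, 1253, 927],
  [73, 8, 3215, 55, 927],
  [711, 632, 71]
]

# Keep the LAST 4 timesteps of long samples (truncating='pre' is the default),
# and pad short samples at the end.
keep_last = tf.keras.preprocessing.sequence.pad_sequences(
    raw_inputs, maxlen=4, padding='post')

# Keep the FIRST 4 timesteps of long samples instead.
keep_first = tf.keras.preprocessing.sequence.pad_sequences(
    raw_inputs, maxlen=4, padding='post', truncating='post')

print(keep_last)
print(keep_first)
```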

Masking

Now that all samples have a uniform length, the model must be informed that some part of the data is actually padding and should be ignored. That mechanism is masking.

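There are three ways to introduce input masks in Keras models:

  • Add a Masking layer.
  • Configure an Embedding layer with mask_zero=True.
  • Pass a mask argument manually when calling layers that support this argument (e.g. RNN layers).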

Mask-generating layers: Embedding and Masking

Under the hood, these layers will create a mask tensor (2D tensor with shape (batch, sequence_length)), and attach it to the tensor output returned by the Masking or Embedding layer.

embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
masked_output = embedding(padded_inputs)

print(masked_output._keras_mask)
tf.Tensor(
[[ True  True  True  True  True  True]
 [ True  True  True  True  True False]
 [ True  True  True False False False]], shape=(3, 6), dtype=bool)

masking_layer = layers.Masking()
# Simulate the embedding lookup by expanding the 2D input to 3D,
# with embedding dimension of 10.
unmasked_embedding = tf.cast(
    tf.tile(tf.expand_dims(padded_inputs, axis=-1), [1, 1, 10]),
    tf.float32)

masked_embedding = masking_layer(unmasked_embedding)
print(masked_embedding._keras_mask)
tf.Tensor(
[[ True  True  True  True  True  True]
 [ True  True  True  True  True False]
 [ True  True  True False False False]], shape=(3, 6), dtype=bool)

As you can see from the printed result, the mask is a 2D boolean tensor with shape (batch_size, sequence_length), where each individual False entry indicates that the corresponding timestep should be ignored during processing.
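The masks printed above follow simple rules, which the sketch below checks by calling each layer's compute_mask() directly on the padded data: Embedding with mask_zero=True keeps a timestep when its token ID is not 0, and Masking keeps a timestep when at least one of its features differs from mask_value (0.0 by default). The 10-feature float input is the same illustrative expansion of the token IDs used above.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

padded = np.array([[83, 91, 1, 645, 1253, 927],
                   [73, 8, 3215, 55, 927, 0],
                   [711, 632, 71, 0, 0, 0]], dtype="int32")

# Embedding(mask_zero=True): True wherever the token ID is nonzero.
emb_mask = layers.Embedding(
    input_dim=5000, output_dim=16, mask_zero=True).compute_mask(padded)

# Masking: True wherever ANY feature differs from mask_value (0.0 here).
features = np.repeat(padded[..., np.newaxis], 10, axis=-1).astype("float32")
masking_mask = layers.Masking(mask_value=0.0).compute_mask(features)

# Both reduce to "token != 0" for this data.
expected = padded != 0
print(np.array_equal(np.asarray(emb_mask), expected))
print(np.array_equal(np.asarray(masking_mask), expected))
```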

Mask propagation in the Functional API and Sequential API

When using the Functional API or the Sequential API, a mask generated by an Embedding or Masking layer will be propagated through the network to any layer that is capable of using it (for example, RNN layers). Keras will automatically fetch the mask corresponding to an input and pass it to any layer that knows how to use it.

Note that in the call method of a subclassed model or layer, masks aren't automatically propagated, so you will need to manually pass a mask argument to any layer that needs one. See the section below for details.

For instance, in the following Sequential model, the LSTM layer will automatically receive a mask, which means it will ignore padded values:

model = tf.keras.Sequential([
  layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True),
  layers.LSTM(32),
])

This is also the case for the following Functional API model:

inputs = tf.keras.Input(shape=(None,), dtype='int32')
x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)
outputs = layers.LSTM(32)(x)

model = tf.keras.Model(inputs, outputs)
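One way to see that the mask actually reaches the LSTM is to compare a sequence with its post-padded version. Assuming padding ID 0 and post-padding, the masked LSTM carries its state through the padded timesteps unchanged, so both calls below should produce the same output up to floating-point tolerance. The model is the untrained Functional model above; the token IDs are illustrative.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(None,), dtype='int32')
x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)
outputs = layers.LSTM(32)(x)
model = tf.keras.Model(inputs, outputs)

# The same three tokens, without and with three trailing padding IDs.
short = np.array([[711, 632, 71]], dtype="int32")
padded = np.array([[711, 632, 71, 0, 0, 0]], dtype="int32")

out_short = np.asarray(model(short))
out_padded = np.asarray(model(padded))

# The mask makes the LSTM skip the padded steps, so the final outputs match.
print(np.allclose(out_short, out_padded, atol=1e-5))
```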

Passing mask tensors directly to layers

Layers that can handle masks (such as the LSTM layer) have a mask argument in their __call__ method.

Meanwhile, layers that produce a mask (e.g. Embedding) expose a compute_mask(inputs, mask=None) method which you can call.

Thus, you can do something like this:

class MyLayer(layers.Layer):
  
  def __init__(self, **kwargs):
    super(MyLayer, self).__init__(**kwargs)
    self.embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
    self.lstm = layers.LSTM(32)
    
  def call(self, inputs):
    x = self.embedding(inputs)
    # Note that you could also prepare a `mask` tensor manually.
    # It only needs to be a boolean tensor
    # with the right shape, i.e. (batch_size, timesteps).
    mask = self.embedding.compute_mask(inputs)
    output = self.lstm(x, mask=mask)  # The layer will ignore the masked values
    return output

layer = MyLayer()
x = np.random.random((32, 10)) * 100
x = x.astype('int32')
layer(x)
<tf.Tensor: id=4730, shape=(32, 32), dtype=float32, numpy=
array([[-0.00149354, -0.00657718,  0.0043684 , ...,  0.01915387,
         0.00254279,  0.00201567],
       [-0.00874859,  0.00249364,  0.00269479, ..., -0.01414887,
         0.00511035, -0.00541363],
       [-0.00457095, -0.0097013 , -0.00557693, ...,  0.00384533,
         0.00664415,  0.00333986],
       ...,
       [-0.00762534, -0.00543655,  0.0005238 , ...,  0.01187737,
         0.00214507, -0.00063268],
       [ 0.00428915, -0.00258686,  0.00012214, ...,  0.0064177 ,
         0.00800534,  0.00203928],
       [-0.01474019, -0.00349469, -0.00311312, ..., -0.0064069 ,
         0.00472621,  0.005593  ]], dtype=float32)>

Supporting masking in your custom layers

Sometimes you may need to write layers that generate a mask (like Embedding), or layers that need to modify the current mask.

For instance, any layer that produces a tensor with a different time dimension than its input, such as a Concatenate layer that concatenates on the time dimension, will need to modify the current mask so that downstream layers will be able to properly take masked timesteps into account.

To do this, your layer should implement the layer.compute_mask() method, which produces a new mask given the input and the current mask.

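Most layers don't modify the time dimension, so they don't need to modify the current mask, although they may still need to pass it on to the next layer. For a custom layer, this pass-through is opt-in: set self.supports_masking = True in its constructor, and the default compute_mask() will then return the current mask unchanged.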
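A sketch of a layer that keeps the time dimension and passes the mask through, assuming TF 2.x Keras semantics where a custom layer propagates its mask only after setting supports_masking = True. The Scale layer and its factor parameter are hypothetical, chosen only for illustration; the layer does not override compute_mask(), so the default implementation hands back the incoming mask.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class Scale(layers.Layer):
  """Hypothetical layer that multiplies its input by a constant factor."""

  def __init__(self, factor=2.0, **kwargs):
    super(Scale, self).__init__(**kwargs)
    self.factor = factor
    # The time dimension is unchanged, so passing the incoming mask through
    # as-is is safe: opt in to the default pass-through behavior.
    self.supports_masking = True

  def call(self, inputs):
    return inputs * self.factor


x = tf.random.normal((2, 4, 3))
mask = tf.constant([[True, True, False, False],
                    [True, True, True, False]])

layer = Scale()
y = layer(x)
# Pass the mask positionally, which works regardless of how the second
# parameter of compute_mask() is named.
out_mask = layer.compute_mask(x, mask)
print(np.array_equal(np.asarray(out_mask), np.asarray(mask)))
```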

Here is an example of a TemporalSplit layer that needs to modify the current mask.

class TemporalSplit(tf.keras.layers.Layer):
  """Split the input tensor into 2 tensors along the time dimension."""

  def call(self, inputs):
    # Expect the input to be 3D and mask to be 2D, split the input tensor into 2
    # subtensors along the time axis (axis 1).
    return tf.split(inputs, 2, axis=1)
    
  def compute_mask(self, inputs, mask=None):
    # Also split the mask into 2 if it is present.
    if mask is None:
      return None
    return tf.split(mask, 2, axis=1)

first_half, second_half = TemporalSplit()(masked_embedding)
print(first_half._keras_mask)
print(second_half._keras_mask)
tf.Tensor(
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]], shape=(3, 3), dtype=bool)
tf.Tensor(
[[ True  True  True]
 [ True  True False]
 [False False False]], shape=(3, 3), dtype=bool)
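Concatenating along the time axis, mentioned earlier, is the opposite case: the masks have to be joined rather than split. The hypothetical TemporalConcat layer below is a simplified illustration of that mask logic, not a replacement for the built-in Concatenate layer; it assumes that every timestep of an input arriving without a mask is valid.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class TemporalConcat(layers.Layer):
  """Hypothetical layer: concatenate a list of 3D tensors along time (axis 1)."""

  def call(self, inputs):
    return tf.concat(inputs, axis=1)

  def compute_mask(self, inputs, mask=None):
    # `mask` is None or a list with one entry (possibly None) per input.
    if mask is None or all(m is None for m in mask):
      return None
    # Treat every timestep of an unmasked input as valid.
    masks = [tf.fill(tf.shape(x)[:2], True) if m is None else m
             for x, m in zip(inputs, mask)]
    return tf.concat(masks, axis=1)


a = tf.zeros((2, 3, 4))
b = tf.zeros((2, 2, 4))
mask_a = tf.constant([[True, True, False],
                      [True, False, False]])

layer = TemporalConcat()
y = layer([a, b])  # shape (2, 5, 4)
new_mask = layer.compute_mask([a, b], [mask_a, None])
print(new_mask)
```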

Here is another example of a CustomEmbedding layer that is capable of generating a mask from input values:

class CustomEmbedding(tf.keras.layers.Layer):
  
  def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
    super(CustomEmbedding, self).__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.mask_zero = mask_zero
    
  def build(self, input_shape):
    self.embeddings = self.add_weight(
      shape=(self.input_dim, self.output_dim),
      initializer='random_normal',
      dtype='float32')
    
  def call(self, inputs):
    return tf.nn.embedding_lookup(self.embeddings, inputs)
  
  def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
      return None
    return tf.not_equal(inputs, 0)
  
  
layer = CustomEmbedding(10, 32, mask_zero=True)
x = np.random.random((3, 10)) * 9
x = x.astype('int32')

y = layer(x)
mask = layer.compute_mask(x)

print(mask)
tf.Tensor(
[[ True  True  True  True  True  True  True False  True  True]
 [ True  True  True  True  True  True  True  True  True  True]
 [False  True False  True  True  True False  True False  True]], shape=(3, 10), dtype=bool)

Writing layers that need mask information

Some layers are mask consumers: they accept a mask argument in call and use it to determine whether to skip certain time steps.

To write such a layer, you can simply add a mask=None argument in your call signature. The mask associated with the inputs will be passed to your layer whenever it is available.

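class MaskConsumer(tf.keras.layers.Layer):
  
  def call(self, inputs, mask=None):
    # `mask` is the boolean (batch_size, timesteps) mask associated with
    # `inputs`, or None when no mask is available.
    ...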
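For a concrete consumer, here is a hypothetical MaskedAverage layer (an illustrative sketch, not a built-in Keras layer) that averages over the time axis while ignoring masked timesteps, falling back to a plain mean when no mask is available. Because it averages the time axis away, it also overrides compute_mask() to return None, so no stale mask is passed downstream.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class MaskedAverage(layers.Layer):
  """Hypothetical mask consumer: mean over time that skips masked timesteps."""

  def call(self, inputs, mask=None):
    if mask is None:
      return tf.reduce_mean(inputs, axis=1)
    # (batch, timesteps) -> (batch, timesteps, 1) so the mask broadcasts
    # across the feature axis.
    weights = tf.expand_dims(tf.cast(mask, inputs.dtype), axis=-1)
    total = tf.reduce_sum(inputs * weights, axis=1)
    # Guard against division by zero for a sample whose steps are all masked.
    count = tf.maximum(tf.reduce_sum(weights, axis=1), 1.0)
    return total / count

  def compute_mask(self, inputs, mask=None):
    # The time axis is averaged away, so there is no mask to pass on.
    return None


x = tf.constant([[[1.0], [3.0], [100.0]],
                 [[2.0], [4.0], [6.0]]])
mask = tf.constant([[True, True, False],
                    [True, True, True]])

# Sample 1 averages 1 and 3 (100 is masked); sample 2 averages 2, 4 and 6.
pooled = MaskedAverage()(x, mask=mask)
print(np.asarray(pooled))
```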

Recap

That is all you need to know about masking in Keras. To recap:

  • "Masking" is how layers are able to know when to skip / ignore certain timesteps in sequence inputs.
  • Some layers are mask-generators: Embedding can generate a mask from input values (if mask_zero=True), and so can the Masking layer.
  • Some layers are mask-consumers: they expose a mask argument in their __call__ method. This is the case for RNN layers.
  • In the Functional API and Sequential API, mask information is propagated automatically.
  • When writing subclassed models or when using layers in a standalone way, pass the mask arguments to layers manually.
  • You can easily write layers that modify the current mask, that generate a new mask, or that consume the mask associated with the inputs.