Image captioning with visual attention


Given an image like the example below, your goal is to generate a caption such as "a surfer riding on a wave".

A man surfing, from wikimedia

The model architecture used here is inspired by Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, but has been updated to use a 2-layer Transformer-decoder. To get the most out of this tutorial you should have some experience with text generation, seq2seq models & attention, or transformers.

The model architecture built in this tutorial is shown below. Features are extracted from the image, and passed to the cross-attention layers of the Transformer-decoder.

The model architecture

The transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and it uses cross-attention to attend to the image.

By inspecting the attention weights of the cross-attention layers, you will see which parts of the image the model is looking at as it generates words.

Prediction

This notebook is an end-to-end example. When you run the notebook, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images.

Setup

apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
pip uninstall -y tensorflow estimator keras
pip install -U tensorflow_text tensorflow tensorflow_datasets
pip install einops

This tutorial uses lots of imports, mostly for loading the dataset(s).


[Optional] Data handling

This section downloads a captions dataset and prepares it for training. It tokenizes the input text, and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.

Data ready for training

After those preprocessing steps, here are the datasets:

train_ds = load_dataset('train_cache')
test_ds = load_dataset('test_cache')
train_ds.element_spec
((TensorSpec(shape=(None, 7, 7, 576), dtype=tf.float32, name=None),
  TensorSpec(shape=(None, None), dtype=tf.int64, name=None)),
 TensorSpec(shape=(None, None), dtype=tf.int64, name=None))

The dataset now returns (input, label) pairs suitable for training with Keras. The inputs are (images, input_tokens) pairs. The images have already been processed with the feature-extractor model. At each location in input_tokens, the model looks at the text so far and tries to predict the next token, which is lined up at the same location in the labels.

for (inputs, ex_labels) in train_ds.take(1):
  (ex_img, ex_in_tok) = inputs

print(ex_img.shape)
print(ex_in_tok.shape)
print(ex_labels.shape)
(32, 7, 7, 576)
(32, 19)
(32, 19)

The input tokens and the labels are the same, just shifted by 1 step:

print(ex_in_tok[0].numpy())
print(ex_labels[0].numpy())
[  3   2  10  35   5   6 344  11   2  39   0   0   0   0   0   0   0   0
   0]
[  2  10  35   5   6 344  11   2  39   4   0   0   0   0   0   0   0   0
   0]
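The shift can be sketched in plain Python. This is a minimal illustration, not the tutorial's data pipeline (that lives in the optional data-handling section). The token IDs are copied from the printed example, and treating 3 as [START], 4 as [END], and 0 as padding is an assumption read off that output. Shifting each caption before padding means the input drops the [END] token and the label drops the [START] token; `shift_and_pad` is a hypothetical helper name.

```python
# Token IDs of one caption, copied from the printed example above.
# Assumption: 3 is [START], 4 is [END], 0 is padding.
caption = [3, 2, 10, 35, 5, 6, 344, 11, 2, 39, 4]

def shift_and_pad(tokens, length, pad_id=0):
  """Split one tokenized caption into (input, label), then pad both."""
  inputs = tokens[:-1]   # everything except the final [END]
  labels = tokens[1:]    # everything except the initial [START]

  def pad(seq):
    return seq + [pad_id] * (length - len(seq))

  return pad(inputs), pad(labels)

in_tok, label = shift_and_pad(caption, length=19)
print(in_tok)
print(label)
```

Each label entry is the input entry one step to its right, which is exactly the "predict the next token" target.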

A Transformer decoder model

This model assumes that the pretrained image encoder is sufficient, and just focuses on building the text decoder. This tutorial uses a 2-layer Transformer-decoder.

The implementations are almost identical to those in the Transformers tutorial. Refer back to it for more details.

The Transformer encoder and decoder.

The model will be implemented in three main parts:

  1. Input - The token embedding and positional encoding (SeqEmbedding).
  2. Decoder - A stack of transformer decoder layers (DecoderLayer) where each contains:
    1. A causal self-attention layer (CausalSelfAttention), where each output location can attend to the output so far.
    2. A cross attention layer (CrossAttention) where each output location can attend to the input image.
    3. A feed forward network (FeedForward) layer which further processes each output location independently.
  3. Output - A multiclass-classification over the output vocabulary.

Input

The input text has already been split up into tokens and converted to sequences of IDs.

Remember that, unlike a CNN or RNN, the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, the model just sees an unordered set, not a sequence. So in addition to a simple vector embedding for each token ID, the embedding layer also includes an embedding for each position in the sequence.

The SeqEmbedding layer defined below:

  • It looks up the embedding vector for each token.
  • It looks up an embedding vector for each sequence location.
  • It adds the two together.
  • It uses mask_zero=True to initialize the Keras masks for the model.
class SeqEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, max_length, depth):
    super().__init__()
    self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)

    self.token_embedding = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=depth,
        mask_zero=True)

    self.add = tf.keras.layers.Add()

  def call(self, seq):
    seq = self.token_embedding(seq) # (batch, seq, depth)

    x = tf.range(tf.shape(seq)[1])  # (seq)
    x = x[tf.newaxis, :]  # (1, seq)
    x = self.pos_embedding(x)  # (1, seq, depth)

    return self.add([seq,x])
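To see why the position embedding matters, here is a toy version of the same "add two lookups" idea in plain Python. The tiny depth-2 tables `token_table` and `pos_table` are made up for illustration; SeqEmbedding learns real tables of shape (vocab_size, depth) and (max_length, depth). Without the position term, reordering the input tokens only reorders the output vectors, so an order-invariant attention layer could not tell the two sequences apart.

```python
# Hypothetical toy embedding tables with depth=2.
token_table = {1: [1.0, 0.0], 2: [0.0, 1.0], 3: [1.0, 1.0]}
pos_table = [[0.1, 0.0], [0.2, 0.0], [0.3, 0.0]]  # one row per position

def embed(seq, use_positions=True):
  """Token embedding plus (optionally) position embedding, per position."""
  out = []
  for i, tok in enumerate(seq):
    vec = list(token_table[tok])
    if use_positions:
      vec = [a + b for a, b in zip(vec, pos_table[i])]
    out.append(vec)
  return out

a, b = [1, 2, 3], [3, 2, 1]  # same tokens, different order

# Without positions, both sequences produce the same collection of vectors.
no_pos_a = sorted(embed(a, use_positions=False))
no_pos_b = sorted(embed(b, use_positions=False))
print(no_pos_a == no_pos_b)  # True

# With positions, the collections differ, so the order is visible.
pos_a = sorted(embed(a))
pos_b = sorted(embed(b))
print(pos_a == pos_b)  # False
```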

Decoder

The decoder is a standard Transformer-decoder. It contains a stack of DecoderLayers, each of which contains three sublayers: a CausalSelfAttention, a CrossAttention, and a FeedForward. The implementations are almost identical to those in the Transformer tutorial; refer to it for more details.

The CausalSelfAttention layer is below:

class CausalSelfAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    # Use Add instead of + so the keras mask propagates through.
    self.add = tf.keras.layers.Add() 
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    attn = self.mha(query=x, value=x,
                    use_causal_mask=True)
    x = self.add([x, attn])
    return self.layernorm(x)
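The effect of use_causal_mask=True can be sketched with single-head, unprojected scaled dot-product attention in plain Python. This is an illustration only: MultiHeadAttention also applies learned query/key/value/output projections and splits the work across heads. With the causal mask, position i can put weight only on positions 0..i, so every weight above the diagonal is exactly zero. The function names here are hypothetical.

```python
import math

def softmax(xs):
  m = max(xs)
  exps = [math.exp(x - m) for x in xs]
  total = sum(exps)
  return [e / total for e in exps]

def attention_weights(queries, keys, causal=False):
  """Scaled dot-product attention weights, one row per query."""
  d = len(queries[0])
  weights = []
  for i, q in enumerate(queries):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    if causal:
      # Mask out future positions j > i before the softmax.
      scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
    weights.append(softmax(scores))
  return weights

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy (sequence=3, depth=2) input
w = attention_weights(x, x, causal=True)
for row in w:
  print([round(v, 3) for v in row])
```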

The CrossAttention layer is below. Note the use of return_attention_scores.

class CrossAttention(tf.keras.layers.Layer):
  def __init__(self,**kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add() 
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y, **kwargs):
    attn, attention_scores = self.mha(
             query=x, value=y,
             return_attention_scores=True)

    self.last_attention_scores = attention_scores

    x = self.add([x, attn])
    return self.layernorm(x)
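In cross-attention the queries come from the text while the keys and values come from the image, so the weight matrix has one row per text position and one column per image location, with no causal mask. The plain-Python sketch below (single head, no learned projections, random toy features; all assumptions for illustration) uses the sizes of the example batch earlier: 19 text positions attending to the 7x7 = 49 flattened image locations. Because there are no projections, text and image vectors share one toy depth here; in the real layer the projections let a 256-unit text stream attend to 576-channel image features.

```python
import math
import random

def softmax(xs):
  m = max(xs)
  exps = [math.exp(x - m) for x in xs]
  total = sum(exps)
  return [e / total for e in exps]

def cross_attention(text, image):
  """Each text vector attends over all image vectors (no causal mask).

  Returns (outputs, weights); weights has shape (len(text), len(image)).
  """
  d = len(text[0])
  weights, outputs = [], []
  for q in text:
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in image]
    w = softmax(scores)
    weights.append(w)
    # The output is the attention-weighted average of the image vectors.
    outputs.append([sum(wj * v[c] for wj, v in zip(w, image)) for c in range(d)])
  return outputs, weights

rng = random.Random(0)
depth = 4  # toy depth
text = [[rng.gauss(0, 1) for _ in range(depth)] for _ in range(19)]
image = [[rng.gauss(0, 1) for _ in range(depth)] for _ in range(7 * 7)]

out, weights = cross_attention(text, image)
print(len(weights), len(weights[0]))  # 19 49: one row per token, one column per location
print(len(out), len(out[0]))          # 19 4: output keeps the text sequence length
```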

The FeedForward layer is below. Remember that a layers.Dense layer is applied to the last axis of the input. The input will have a shape of (batch, sequence, channels), so it automatically applies pointwise across the batch and sequence axes.

class FeedForward(tf.keras.layers.Layer):
  def __init__(self, units, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
        tf.keras.layers.Dense(units=2*units, activation='relu'),
        tf.keras.layers.Dense(units=units),
        tf.keras.layers.Dropout(rate=dropout_rate),
    ])

    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = x + self.seq(x)
    return self.layernorm(x)
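The claim that a Dense layer acts pointwise can be checked with a plain-Python stand-in, `dense_last_axis` (a hypothetical helper, not a Keras API). The weight matrix multiplies only the last (channels) axis, so each (batch, sequence) position is transformed independently of the others. The weights and inputs are arbitrary toy values.

```python
def dense_last_axis(x, weights, bias):
  """Apply y = v @ weights + bias to every vector v on the last axis of x.

  x has shape (batch, sequence, in_channels); weights has shape
  (in_channels, out_channels).
  """
  out_channels = len(weights[0])
  return [[[sum(v[i] * weights[i][j] for i in range(len(v))) + bias[j]
            for j in range(out_channels)]
           for v in seq]
          for seq in x]

weights = [[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]]  # (3 in, 2 out)
bias = [0.1, 0.2]
x = [[[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]]  # (batch=1, sequence=2, channels=3)
y = dense_last_axis(x, weights, bias)

# Changing the first position leaves the second position's output unchanged.
x2 = [[[9.0, 9.0, 9.0], [0.0, 1.0, 0.0]]]
y2 = dense_last_axis(x2, weights, bias)
print(y[0][1] == y2[0][1])  # True
```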

Next, arrange these three layers into a larger DecoderLayer. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of out_seq is (batch, sequence, channels). The decoder layer also stores the cross-attention scores in last_attention_scores for later visualization.

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, units, num_heads=1, dropout_rate=0.1):
    super().__init__()

    self.self_attention = CausalSelfAttention(num_heads=num_heads,
                                              key_dim=units,
                                              dropout=dropout_rate)
    self.cross_attention = CrossAttention(num_heads=num_heads,
                                          key_dim=units,
                                          dropout=dropout_rate)
    self.ff = FeedForward(units=units, dropout_rate=dropout_rate)


  def call(self, inputs, training=False):
    in_seq, out_seq = inputs

    # Text input
    out_seq = self.self_attention(out_seq)

    out_seq = self.cross_attention(out_seq, in_seq)

    self.last_attention_scores = self.cross_attention.last_attention_scores

    out_seq = self.ff(out_seq)

    return out_seq

Output

At minimum the output layer needs a layers.Dense layer to generate logit-predictions for each token at each location.

But there are a few other features you can add to make this work a little better:

  1. Handle bad tokens: The model will be generating text. It should never generate a pad, unknown, or start token ('', '[UNK]', '[START]'). So set the bias for these to a large negative value.

  2. Smart initialization: The default initialization of a dense layer gives a model that initially predicts each token with almost uniform likelihood, but the actual token distribution is far from uniform. The optimal initial bias for each output token is the log of that token's probability, so include an adapt method that counts the tokens and sets this bias. This reduces the initial loss from the entropy of the uniform distribution (log(vocabulary_size)) to the marginal entropy of the token distribution (-sum(p*log(p))).
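The arithmetic behind that adapt step can be sketched in plain Python. This is not the tutorial's TokenOutput class. It assumes a made-up six-word vocabulary and a hand-written list of label tokens in place of the real tokenizer and dataset, and it gives banned tokens a bias of -1e9 (an arbitrary large negative value). The two printed numbers play the same role as the "Uniform entropy" and "Marginal entropy" lines below.

```python
import math
from collections import Counter

vocab = ['', '[UNK]', '[START]', '[END]', 'a', 'surfer']  # hypothetical vocabulary
banned = {'', '[UNK]', '[START]'}

# Hypothetical label tokens (padding '' included, as in padded batches).
labels = ['a', 'surfer', '[END]', 'a', 'a', '[END]', '', '']

counts = Counter(labels)
for tok in banned:
  counts[tok] = 0  # banned tokens never count toward the distribution

total = sum(counts[tok] for tok in vocab)
probs = {tok: counts[tok] / total for tok in vocab}

# Initial bias: log-probability for allowed tokens, a huge negative value otherwise.
bias = {tok: (math.log(p) if p > 0 else -1e9) for tok, p in probs.items()}

uniform_entropy = math.log(len(vocab))
marginal_entropy = -sum(p * math.log(p) for p in probs.values() if p > 0)
print(f"Uniform entropy: {uniform_entropy:0.2f}")    # 1.79
print(f"Marginal entropy: {marginal_entropy:0.2f}")  # 1.01
```

With these biases the untrained model already predicts the marginal token distribution, so its starting cross-entropy is the lower, marginal value.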

The smart initialization will significantly reduce the initial loss:

output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
# This might run a little faster if the dataset didn't also have to load the image data.
output_layer.adapt(train_ds.map(lambda inputs, labels: labels))
100%|██████████| 938/938 [00:02<00:00, 343.51it/s]
Uniform entropy: 8.52
Marginal entropy: 5.29

Build the model

To build the model, you need to combine several parts:

  1. The image feature_extractor and the text tokenizer.
  2. The seq_embedding layer, to convert batches of token-IDs to vectors (batch, sequence, channels).
  3. The stack of DecoderLayers that will process the text and image data.
  4. The output_layer which returns a pointwise prediction of what the next word should be.
class Captioner(tf.keras.Model):
  @classmethod
  def add_method(cls, fun):
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
               units=256, max_length=50, num_heads=1, dropout_rate=0.1):
    super().__init__()
    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer
    self.word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    self.index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True) 

    self.seq_embedding = SeqEmbedding(
        vocab_size=tokenizer.vocabulary_size(),
        depth=units,
        max_length=max_length)

    self.decoder_layers = [
        DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
        for n in range(num_layers)]

    self.output_layer = output_layer

When you call the model for training, it receives an (image, txt) pair. To make this function more usable, be flexible about the input:

  • If the image has 3 channels, run it through the feature_extractor. Otherwise, assume that it has already been processed.
  • Similarly, if the text has dtype tf.string, run it through the tokenizer.

After that, running the model takes only a few steps:

  1. Flatten the extracted image features, so they can be input to the decoder layers.
  2. Look up the token embeddings.
  3. Run the stack of DecoderLayers, on the image features and text embeddings.
  4. Run the output layer to predict the next token at each position.
@Captioner.add_method
def call(self, inputs):
  image, txt = inputs

  if image.shape[-1] == 3:
    # Apply the feature-extractor, if you get an RGB image.
    image = self.feature_extractor(image)

  # Flatten the feature map: (batch, h, w, channels) -> (batch, h*w, channels)
  image = einops.rearrange(image, 'b h w c -> b (h w) c')

  if txt.dtype == tf.string:
    # Apply the tokenizer if you get string inputs.
    txt = self.tokenizer(txt)

  txt = self.seq_embedding(txt)

  # Look at the image
  for dec_layer in self.decoder_layers:
    txt = dec_layer(inputs=(image, txt))

  txt = self.output_layer(txt)

  return txt

model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
                  units=256, dropout_rate=0.5, num_layers=2, num_heads=2)
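The einops.rearrange pattern 'b h w c -> b (h w) c' in call merges the two spatial axes in row-major order, so grid cell (row, col) becomes sequence position row * w + col. A plain-Python equivalent for a single image, with made-up feature values and a hypothetical helper name, is below. The same index formula maps any weight over the 49 image locations back to a cell of the 7x7 grid.

```python
def flatten_hw(feature_map):
  """(h, w, c) nested lists -> (h*w, c), row-major like 'h w c -> (h w) c'."""
  return [cell for row in feature_map for cell in row]

h, w = 7, 7
# Toy 2-channel features: each cell stores its own (row, col) so the mapping is visible.
fmap = [[[float(r), float(col)] for col in range(w)] for r in range(h)]

flat = flatten_hw(fmap)
print(len(flat), len(flat[0]))  # 49 2
print(flat[3 * w + 5])          # [3.0, 5.0], i.e. cell (row=3, col=5)
```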

Generate captions

Before getting into training, write a bit of code to generate captions. You'll use this to see how training is progressing.

Start by downloading a test image:

image_url = 'https://tensorflow.org/images/surf.jpg'
image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
image = load_image(image_path)
Downloading data from https://tensorflow.org/images/surf.jpg
64400/64400 [==============================] - 0s 0us/step

To caption an image with this model:

  • Extract the img_features
  • Initialize the list of output tokens with a [START] token.
  • Pass img_features and tokens into the model.
    • It returns a list of logits.
    • Choose the next token based on those logits.
    • Add it to the list of tokens, and continue the loop.
    • If it generates an '[END]' token, break out of the loop.

So add a "simple" method to do just that:

@Captioner.add_method
def simple_gen(self, image, temperature=1):
  initial = self.word_to_index([['[START]']]) # (batch, sequence)
  img_features = self.feature_extractor(image[tf.newaxis, ...])

  tokens = initial # (batch, sequence)
  for _ in range(50):
    preds = self((img_features, tokens)).numpy()  # (batch, sequence, vocab)
    preds = preds[:,-1, :]  # (batch, vocab)
    if temperature == 0:
      next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
    else:
      next_token = tf.random.categorical(preds/temperature, num_samples=1)  # (batch, 1)
    tokens = tf.concat([tokens, next_token], axis=1) # (batch, sequence)

    if next_token[0] == self.word_to_index('[END]'):
      break
  # Drop the leading [START] and the final token ([END]) before joining.
  words = self.index_to_word(tokens[0, 1:-1])
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
  return result.numpy().decode()

Here are some generated captions for that image. The model is untrained, so they don't make much sense yet:

for t in (0.0, 0.5, 1.0):
  result = model.simple_gen(image, temperature=t)
  print(result)
a
a a the in a a young
car ball wave motorbike man off day a carries cones on bench a a wearing is the squinting car the points block walls street two

The temperature parameter allows you to interpolate between 3 modes:

  1. Greedy decoding (temperature=0.0) - Chooses the most likely next token at each step.
  2. Random sampling according to the logits (temperature=1.0).
  3. Uniform random sampling (temperature >> 1.0).
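The three modes come from dividing the logits by the temperature before the softmax. This plain-Python sketch is not the tutorial's code (simple_gen passes the scaled logits straight to tf.random.categorical); it computes the sampling distribution for a few toy logits. Temperature 0 is handled as a special case meaning argmax, as in simple_gen. `sampling_probs` and `sample_next` are hypothetical names.

```python
import math
import random

def sampling_probs(logits, temperature):
  """Softmax of logits / temperature; temperature == 0 means greedy (one-hot argmax)."""
  if temperature == 0:
    best = max(range(len(logits)), key=lambda i: logits[i])
    return [1.0 if i == best else 0.0 for i in range(len(logits))]
  scaled = [x / temperature for x in logits]
  m = max(scaled)
  exps = [math.exp(s - m) for s in scaled]
  total = sum(exps)
  return [e / total for e in exps]

def sample_next(logits, temperature, rng):
  probs = sampling_probs(logits, temperature)
  return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.0]  # toy logits for a 3-token vocabulary
for t in (0.0, 1.0, 100.0):
  print(t, [round(p, 3) for p in sampling_probs(logits, t)])

rng = random.Random(0)
print(sample_next(logits, 0.0, rng))  # greedy always picks index 0
```

At temperature 1.0 the distribution follows the logits; at 100.0 it is close to uniform.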

Since the model is untrained and its output layer uses the frequency-based initialization, the "greedy" output (first) usually contains only the most common tokens: ['a', '.', '[END]'].

Train

To train the model you'll need several additional components:

  • The Loss and metrics
  • The Optimizer
  • Optional Callbacks

Losses and metrics

Here's an implementation of a masked loss and accuracy:

When calculating the mask for the loss, note the loss < 1e8 condition. It discards the artificial, impossibly high losses for the banned_tokens.

def masked_loss(labels, preds):  
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)

  mask = (labels != 0) & (loss < 1e8) 
  mask = tf.cast(mask, loss.dtype)

  loss = loss*mask
  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
  return loss

def masked_acc(labels, preds):
  mask = tf.cast(labels!=0, tf.float32)
  preds = tf.argmax(preds, axis=-1)
  labels = tf.cast(labels, tf.int64)
  match = tf.cast(preds == labels, mask.dtype)
  acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)
  return acc
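The masking arithmetic can be checked by hand with a plain-Python version of the same idea: per-token cross-entropy, a mask that drops padding (label 0) and any loss above 1e8, and a mean over the surviving tokens only. The logits are made-up toy values, with a -1e9 logit standing in for a banned token's bias, so a label equal to that banned token produces one of the "impossibly high" losses. The function names mirror the ones above but are plain-Python stand-ins.

```python
import math

def cross_entropy(logits, label):
  """-log softmax(logits)[label], computed stably."""
  m = max(logits)
  log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
  return log_sum - logits[label]

def masked_loss(labels, logits_per_token):
  losses = [cross_entropy(lg, lab) for lab, lg in zip(labels, logits_per_token)]
  # Keep a token only if it is not padding and its loss is not artificially huge.
  mask = [1.0 if (lab != 0 and loss < 1e8) else 0.0 for lab, loss in zip(labels, losses)]
  return sum(loss * keep for loss, keep in zip(losses, mask)) / sum(mask)

# Toy vocabulary: 0 = padding, 1 = a banned token (bias -1e9), 2 and 3 = words.
row = [0.0, -1e9, 1.0, 1.0]
labels = [2, 1, 3, 0]  # a word, a banned-token label, a word, padding
logits = [row, row, row, row]
print(round(masked_loss(labels, logits), 4))  # 0.862: mean over the two word tokens only
```

Dividing by sum(mask) rather than by the sequence length keeps padded and banned positions from diluting the average.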

Callbacks

For feedback during training, set up a keras.callbacks.Callback that generates some captions for the surfer image at the end of each epoch.

class GenerateText(tf.keras.callbacks.Callback):
  def __init__(self):
    super().__init__()
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    self.image = load_image(image_path)

  def on_epoch_end(self, epoch=None, logs=None):
    print()
    print()
    # Greedy (0.0) first, then increasingly random sampling.
    for t in (0.0, 0.5, 1.0):
      result = self.model.simple_gen(self.image, temperature=t)
      print(result)
    print()

It generates three output strings, like the earlier example. As before, the first is "greedy", choosing the argmax of the logits at each step.

g = GenerateText()
g.model = model
g.on_epoch_end(0)
a
a
in in through

Also use callbacks.EarlyStopping to terminate training when the model starts to overfit.

callbacks = [
    GenerateText(),
    tf.keras.callbacks.EarlyStopping(
        patience=5, restore_best_weights=True)]
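Here is a plain-Python sketch of the stopping rule, assuming the Keras defaults of monitoring val_loss with min_delta=0. Training stops once the monitored value has failed to improve for `patience` consecutive epochs, and restore_best_weights=True rolls the model back to the best epoch. The val_loss values are copied from the training log later in this section; `early_stopping` is a hypothetical helper, not the Keras implementation.

```python
def early_stopping(val_losses, patience):
  """Return (stop_epoch, best_epoch), both 1-based, for a 'lower is better' metric."""
  best, best_epoch, wait = float('inf'), 0, 0
  for epoch, loss in enumerate(val_losses, start=1):
    if loss < best:
      best, best_epoch, wait = loss, epoch, 0
    else:
      wait += 1
      if wait >= patience:
        return epoch, best_epoch
  return len(val_losses), best_epoch

# val_loss per epoch, copied from the training log.
val_losses = [
    4.6019, 4.3118, 4.2734, 4.0899, 3.9376, 3.7882, 3.7504, 3.6968, 3.6542, 3.5882,
    3.5388, 3.5301, 3.4566, 3.4105, 3.4809, 3.4348, 3.3433, 3.3187, 3.2272, 3.2670,
    3.2880, 3.2540, 3.1672, 3.1425, 3.1347, 3.1248, 3.1369, 3.1596, 3.1156, 3.1330,
    3.1118, 3.0272, 3.0075, 2.9537, 3.0776, 3.0296, 2.9956, 3.0461, 3.0256,
]
print(early_stopping(val_losses, patience=5))  # (39, 34)
```

This matches the run below: the best val_loss (2.9537) comes at epoch 34, the next five epochs do not beat it, and training ends after epoch 39 with the epoch-34 weights restored.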

Train

Configure and execute the training.

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=masked_loss,
              metrics=[masked_acc])

For more frequent reporting, use the Dataset.repeat() method, and set the steps_per_epoch and validation_steps arguments to Model.fit.

With this setup on Flickr8k, a full pass over the dataset is 900+ batches, but the reporting epochs below are only 100 steps each.
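As a rough worked example of that bookkeeping, the snippet below uses the 938 training batches reported by the progress bar of output_layer.adapt earlier, and the 39 reporting epochs in the log below. These counts come from this particular run and will differ if the dataset or batch size changes.

```python
batches_per_pass = 938  # from the adapt() progress bar earlier in this tutorial
steps_per_epoch = 100   # the steps_per_epoch passed to Model.fit
epochs_run = 39         # last epoch shown in the training log

print(batches_per_pass / steps_per_epoch)               # 9.38 reporting epochs per full pass
print(epochs_run * steps_per_epoch / batches_per_pass)  # about 4.16 passes over the training set
```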

history = model.fit(
    train_ds.repeat(),
    steps_per_epoch=100,
    validation_data=test_ds.repeat(),
    validation_steps=20,
    epochs=100,
    callbacks=callbacks)
Epoch 1/100
100/100 [==============================] - ETA: 0s - loss: 4.9897 - masked_acc: 0.2048

a man in a man in the in a
a man on a black white down a on a rock
a shirt rafting by narrow walking path

100/100 [==============================] - 24s 140ms/step - loss: 4.9897 - masked_acc: 0.2048 - val_loss: 4.6019 - val_masked_acc: 0.2472
Epoch 2/100
 99/100 [============================>.] - ETA: 0s - loss: 4.6280 - masked_acc: 0.2525

a man in a red is in the water
a man is in a with the water
the ground on a playing in the air grass

100/100 [==============================] - 6s 61ms/step - loss: 4.6247 - masked_acc: 0.2530 - val_loss: 4.3118 - val_masked_acc: 0.2762
Epoch 3/100
100/100 [==============================] - ETA: 0s - loss: 4.4231 - masked_acc: 0.2753

a man in a red is in the water
a man in a pool in the water
the girl rock colored water in the hats is sitting on a dogs

100/100 [==============================] - 5s 55ms/step - loss: 4.4231 - masked_acc: 0.2753 - val_loss: 4.2734 - val_masked_acc: 0.2816
Epoch 4/100
100/100 [==============================] - ETA: 0s - loss: 4.2494 - masked_acc: 0.2938

a man in a red and white and white dog is running in the water
a boy is jumping on the snow
a little man posing yellow is jumping in the air

100/100 [==============================] - 5s 54ms/step - loss: 4.2494 - masked_acc: 0.2938 - val_loss: 4.0899 - val_masked_acc: 0.3003
Epoch 5/100
 98/100 [============================>.] - ETA: 0s - loss: 4.1317 - masked_acc: 0.3078

a man is jumping in the water
two people are playing in a pool
a black playing in a surfboard while frisbee a beach

100/100 [==============================] - 5s 49ms/step - loss: 4.1321 - masked_acc: 0.3076 - val_loss: 3.9376 - val_masked_acc: 0.3162
Epoch 6/100
100/100 [==============================] - ETA: 0s - loss: 4.0057 - masked_acc: 0.3185

a man in a red shirt is running in the water
a man is walking on a yellow shirt and a pool
a boy is riding over the beach and shore

100/100 [==============================] - 5s 54ms/step - loss: 4.0057 - masked_acc: 0.3185 - val_loss: 3.7882 - val_masked_acc: 0.3246
Epoch 7/100
 99/100 [============================>.] - ETA: 0s - loss: 3.9497 - masked_acc: 0.3253

a man in a blue shirt is jumping in the water
a girl in a blue is standing in the snow
a big man is walking from the background

100/100 [==============================] - 5s 51ms/step - loss: 3.9505 - masked_acc: 0.3254 - val_loss: 3.7504 - val_masked_acc: 0.3260
Epoch 8/100
 99/100 [============================>.] - ETA: 0s - loss: 3.8685 - masked_acc: 0.3301

a man is in a red shirt is in the water
a boy in a blue shirt is doing a pool
the man is and yellow turn in the billboards going at a

100/100 [==============================] - 5s 50ms/step - loss: 3.8707 - masked_acc: 0.3300 - val_loss: 3.6968 - val_masked_acc: 0.3369
Epoch 9/100
 99/100 [============================>.] - ETA: 0s - loss: 3.8177 - masked_acc: 0.3343

a man is jumping into the water
a man is jumping through the water
a boys wearing a pink into the water trees and wave

100/100 [==============================] - 5s 47ms/step - loss: 3.8152 - masked_acc: 0.3346 - val_loss: 3.6542 - val_masked_acc: 0.3332
Epoch 10/100
100/100 [==============================] - ETA: 0s - loss: 3.7342 - masked_acc: 0.3388

a man in a red shirt is jumping into the water
a man in a red jacket is riding a pool
a boy in a green dress in a pool

100/100 [==============================] - 5s 50ms/step - loss: 3.7342 - masked_acc: 0.3388 - val_loss: 3.5882 - val_masked_acc: 0.3445
Epoch 11/100
 99/100 [============================>.] - ETA: 0s - loss: 3.6388 - masked_acc: 0.3448

a man in a red shirt is jumping into the water
the person is in the water
a man in a body of the a chocolate at the snow

100/100 [==============================] - 5s 50ms/step - loss: 3.6357 - masked_acc: 0.3449 - val_loss: 3.5388 - val_masked_acc: 0.3480
Epoch 12/100
100/100 [==============================] - ETA: 0s - loss: 3.5974 - masked_acc: 0.3470

a man is jumping into the water
a man in a red jacket is in the water
two young in a blue purple running on a large snow

100/100 [==============================] - 5s 47ms/step - loss: 3.5974 - masked_acc: 0.3470 - val_loss: 3.5301 - val_masked_acc: 0.3547
Epoch 13/100
100/100 [==============================] - ETA: 0s - loss: 3.5613 - masked_acc: 0.3505

a man is jumping into the water
a skateboarder is pulling a wave
a skateboard boy with a red large black black brick n girl in green toy with surfboard

100/100 [==============================] - 5s 49ms/step - loss: 3.5613 - masked_acc: 0.3505 - val_loss: 3.4566 - val_masked_acc: 0.3602
Epoch 14/100
 99/100 [============================>.] - ETA: 0s - loss: 3.5301 - masked_acc: 0.3564

a man in a red shirt is jumping into the water
a man wearing a blue and blue hat is riding a blue wave
a boy jumping front of a rocks next to a on skating through a water

100/100 [==============================] - 5s 55ms/step - loss: 3.5272 - masked_acc: 0.3568 - val_loss: 3.4105 - val_masked_acc: 0.3564
Epoch 15/100
100/100 [==============================] - ETA: 0s - loss: 3.5040 - masked_acc: 0.3572

a man in a red shirt is jumping into the water
a young boy is jumping into the water
a person is playing in the ocean

100/100 [==============================] - 4s 45ms/step - loss: 3.5040 - masked_acc: 0.3572 - val_loss: 3.4809 - val_masked_acc: 0.3456
Epoch 16/100
 99/100 [============================>.] - ETA: 0s - loss: 3.4727 - masked_acc: 0.3573

a man in a red shirt is jumping into the water
a man in a white body of water
a surfer has his a river

100/100 [==============================] - 5s 46ms/step - loss: 3.4724 - masked_acc: 0.3573 - val_loss: 3.4348 - val_masked_acc: 0.3581
Epoch 17/100
 98/100 [============================>.] - ETA: 0s - loss: 3.4571 - masked_acc: 0.3614

a man in a red shirt is jumping into the water
a boy in a red is playing in the water
a dog is with a swimming water in the suit about to someone in a lake

100/100 [==============================] - 5s 53ms/step - loss: 3.4561 - masked_acc: 0.3617 - val_loss: 3.3433 - val_masked_acc: 0.3684
Epoch 18/100
 99/100 [============================>.] - ETA: 0s - loss: 3.4082 - masked_acc: 0.3662

a man in a red shirt is swimming pool
a man is on a wave
the boy is riding he jumps into a wave

100/100 [==============================] - 5s 48ms/step - loss: 3.4063 - masked_acc: 0.3662 - val_loss: 3.3187 - val_masked_acc: 0.3669
Epoch 19/100
 99/100 [============================>.] - ETA: 0s - loss: 3.3683 - masked_acc: 0.3698

a man in a red shirt is riding a wave
a person in a red shirt is playing in the water
the person jumping in a bright are off his head off pool

100/100 [==============================] - 5s 52ms/step - loss: 3.3700 - masked_acc: 0.3695 - val_loss: 3.2272 - val_masked_acc: 0.3723
Epoch 20/100
 98/100 [============================>.] - ETA: 0s - loss: 3.3192 - masked_acc: 0.3692

a man in a red shirt is swimming pool
a person in a red jacket is in the ocean
a family rides in his is fence in a pool

100/100 [==============================] - 5s 47ms/step - loss: 3.3165 - masked_acc: 0.3697 - val_loss: 3.2670 - val_masked_acc: 0.3670
Epoch 21/100
 99/100 [============================>.] - ETA: 0s - loss: 3.2656 - masked_acc: 0.3744

a man in a red shirt is jumping into the water
a man in a yellow shirt is playing in the ocean
a cyclist is reaching for the ocean is pool near the backyard in the water

100/100 [==============================] - 5s 51ms/step - loss: 3.2641 - masked_acc: 0.3745 - val_loss: 3.2880 - val_masked_acc: 0.3627
Epoch 22/100
100/100 [==============================] - ETA: 0s - loss: 3.2363 - masked_acc: 0.3776

a man in a red shirt is riding a wave
a person is jumping through the water
a man in a blue tank wave while wearing a yellow waves

100/100 [==============================] - 5s 48ms/step - loss: 3.2363 - masked_acc: 0.3776 - val_loss: 3.2540 - val_masked_acc: 0.3643
Epoch 23/100
100/100 [==============================] - ETA: 0s - loss: 3.2261 - masked_acc: 0.3769

a man in a red shirt is riding a wave
a man rides a wave
a person is eating a yellow surfboard

100/100 [==============================] - 4s 44ms/step - loss: 3.2261 - masked_acc: 0.3769 - val_loss: 3.1672 - val_masked_acc: 0.3839
Epoch 24/100
100/100 [==============================] - ETA: 0s - loss: 3.2116 - masked_acc: 0.3827

a man in a red shirt is riding a wave
a person in a red shirt is riding a wave
a person wearing a hat riding his head

100/100 [==============================] - 5s 47ms/step - loss: 3.2116 - masked_acc: 0.3827 - val_loss: 3.1425 - val_masked_acc: 0.3820
Epoch 25/100
 98/100 [============================>.] - ETA: 0s - loss: 3.2169 - masked_acc: 0.3789

a man in a red shirt is jumping into the water
a man is in the water
a person in a red kayak splashing in a blue competitor sunglasses in midair

100/100 [==============================] - 5s 51ms/step - loss: 3.2167 - masked_acc: 0.3786 - val_loss: 3.1347 - val_masked_acc: 0.3790
Epoch 26/100
100/100 [==============================] - ETA: 0s - loss: 3.1825 - masked_acc: 0.3804

a man in a red shirt is riding a wave
a man in a red and a blue surfboard
a surfer in bright orange collar swimming pool

100/100 [==============================] - 5s 48ms/step - loss: 3.1825 - masked_acc: 0.3804 - val_loss: 3.1248 - val_masked_acc: 0.3781
Epoch 27/100
 98/100 [============================>.] - ETA: 0s - loss: 3.1475 - masked_acc: 0.3823

a man in a red and white surfboard is swimming pool
a surfer in a orange and white kayak
a surfboard rides a wave on the edge of the wave and going into the surfboard

100/100 [==============================] - 5s 51ms/step - loss: 3.1478 - masked_acc: 0.3824 - val_loss: 3.1369 - val_masked_acc: 0.3747
Epoch 28/100
100/100 [==============================] - ETA: 0s - loss: 3.1581 - masked_acc: 0.3804

a man in a red wetsuit is surfing
a man in a blue shirt and white is riding a wave
a man wearing a life shorts is neighborhood wave in the pool

100/100 [==============================] - 5s 50ms/step - loss: 3.1581 - masked_acc: 0.3804 - val_loss: 3.1596 - val_masked_acc: 0.3775
Epoch 29/100
 99/100 [============================>.] - ETA: 0s - loss: 3.0912 - masked_acc: 0.3881

a man is surfing in the water
a person is wearing a red wave
a man rides opposing board on the water

100/100 [==============================] - 5s 46ms/step - loss: 3.0922 - masked_acc: 0.3880 - val_loss: 3.1156 - val_masked_acc: 0.3792
Epoch 30/100
100/100 [==============================] - ETA: 0s - loss: 3.0313 - masked_acc: 0.3966

a man in a red shirt is surfing
a person in a red shirt is skiing down a wave
a child in a bright orange wetsuit rides wave

100/100 [==============================] - 5s 47ms/step - loss: 3.0313 - masked_acc: 0.3966 - val_loss: 3.1330 - val_masked_acc: 0.3788
Epoch 31/100
100/100 [==============================] - ETA: 0s - loss: 3.0351 - masked_acc: 0.3909

a man in a red shirt is surfing
a child is riding a wave in a wave
a man with the guy riding its mouth in the orange surfboard

100/100 [==============================] - 5s 48ms/step - loss: 3.0351 - masked_acc: 0.3909 - val_loss: 3.1118 - val_masked_acc: 0.3808
Epoch 32/100
 99/100 [============================>.] - ETA: 0s - loss: 3.0301 - masked_acc: 0.3957

a man in a red wetsuit is surfing
a person in a pink wetsuit is swimming pool
a person rides a surfer is in been wave

100/100 [==============================] - 5s 47ms/step - loss: 3.0313 - masked_acc: 0.3954 - val_loss: 3.0272 - val_masked_acc: 0.3886
Epoch 33/100
 98/100 [============================>.] - ETA: 0s - loss: 3.0047 - masked_acc: 0.3967

a man in a red wetsuit is surfing
a man is in a black and white board with a red surfboard in a red and white surfboard
a surfer sits on the pool

100/100 [==============================] - 5s 52ms/step - loss: 3.0062 - masked_acc: 0.3963 - val_loss: 3.0075 - val_masked_acc: 0.3915
Epoch 34/100
100/100 [==============================] - ETA: 0s - loss: 3.0074 - masked_acc: 0.3979

a man in a red wetsuit is surfing
a surfer in a white surfboard is surfing
a surfer is wearing a a wave with a man is sliding

100/100 [==============================] - 5s 47ms/step - loss: 3.0074 - masked_acc: 0.3979 - val_loss: 2.9537 - val_masked_acc: 0.3980
Epoch 35/100
100/100 [==============================] - ETA: 0s - loss: 2.9716 - masked_acc: 0.3978

a man in a red wetsuit is surfing
a person in a red and white surfboard
a surfer in a yellow kayak is watching a wave

100/100 [==============================] - 4s 44ms/step - loss: 2.9716 - masked_acc: 0.3978 - val_loss: 3.0776 - val_masked_acc: 0.3780
Epoch 36/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9980 - masked_acc: 0.3957

a surfer is riding a wave
a person in a black wetsuit is in the waves
a person riding on a splash

100/100 [==============================] - 4s 43ms/step - loss: 2.9988 - masked_acc: 0.3956 - val_loss: 3.0296 - val_masked_acc: 0.3875
Epoch 37/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9871 - masked_acc: 0.3948

a man in a red wetsuit is surfing
a man in a red wetsuit is riding a wave
the person wearing a red surfboard lays in the waves

100/100 [==============================] - 5s 46ms/step - loss: 2.9870 - masked_acc: 0.3950 - val_loss: 2.9956 - val_masked_acc: 0.3866
Epoch 38/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9488 - masked_acc: 0.4022

a man in a red wetsuit is riding a wave
a man in a red and white and white dog is holding a wave
a man carrying a wave

100/100 [==============================] - 5s 50ms/step - loss: 2.9465 - masked_acc: 0.4024 - val_loss: 3.0461 - val_masked_acc: 0.3807
Epoch 39/100
100/100 [==============================] - ETA: 0s - loss: 2.8670 - masked_acc: 0.4077

a man in a red wetsuit is surfing
a surfer is riding a wave
a surfer is giving his goggles holding a surfboard

100/100 [==============================] - 5s 45ms/step - loss: 2.8670 - masked_acc: 0.4077 - val_loss: 3.0256 - val_masked_acc: 0.3890

Plot the loss and accuracy over the training run:

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('CE/token')
plt.legend()

Plot: training and validation loss (CE/token) by epoch
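The y-axis is the cross-entropy per (non-padding) token. One common way to read that number is as a perplexity, exp(CE/token): roughly, the number of equally likely words the model is choosing between at each step. This is a minimal sketch that takes the final-epoch values from the log above (epoch 39). It assumes the loss uses the natural log, as Keras cross-entropy losses do.

```python
import math

# Final-epoch values copied from the training log above (epoch 39).
train_ce = 2.8670  # loss: cross-entropy per token, assumed to be in nats
val_ce = 3.0256    # val_loss

# Perplexity = exp(cross-entropy) when the loss uses the natural log.
train_ppl = math.exp(train_ce)
val_ppl = math.exp(val_ce)

print(f"train perplexity: {train_ppl:.1f}")  # about 17.6
print(f"val perplexity:   {val_ppl:.1f}")    # about 20.6
```

So at the end of training, the model is about as uncertain as picking among roughly 20 equally likely words per step on the validation set.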

plt.plot(history.history['masked_acc'], label='accuracy')
plt.plot(history.history['val_masked_acc'], label='val_accuracy')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()

Plot: training and validation masked accuracy by epoch
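history.history is a plain dict that maps each metric name to a list with one value per epoch. That means you can also pick out the best epoch numerically instead of by eye. In the minimal sketch below, history_dict is a hypothetical stand-in for the real history.history. It holds the validation values for epochs 32 to 39, copied from the log above.

```python
# Hypothetical stand-in for history.history: validation values for
# epochs 32-39, copied from the training log above.
first_epoch = 32
history_dict = {
    'val_loss':       [3.0272, 3.0075, 2.9537, 3.0776, 3.0296, 2.9956, 3.0461, 3.0256],
    'val_masked_acc': [0.3886, 0.3915, 0.3980, 0.3780, 0.3875, 0.3866, 0.3807, 0.3890],
}

def best_epoch(values, mode='min'):
    """Return the index of the best value: lowest for 'min', highest for 'max'."""
    pick = min if mode == 'min' else max
    return pick(range(len(values)), key=values.__getitem__)

i_loss = best_epoch(history_dict['val_loss'], mode='min')
i_acc = best_epoch(history_dict['val_masked_acc'], mode='max')

print('lowest val_loss at epoch', first_epoch + i_loss)        # epoch 34
print('highest val_masked_acc at epoch', first_epoch + i_acc)  # epoch 34
```

The real history.history lists start at epoch 1, so with them you would add 1 to the index instead of first_epoch.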

Attention plots

Now, using the trained model, run the simple_gen method on the image:

result = model.simple_gen(image, temperature=0.0)
result
'a man in a red wetsuit is surfing'
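Setting temperature=0.0 makes the generated caption deterministic. The sketch below assumes simple_gen follows the usual convention: at temperature 0 it takes the argmax of the logits at each step, and at higher temperatures it samples from the softmax of the logits divided by the temperature. pick_next_token is a hypothetical helper, not part of the tutorial, and the logits are made up.

```python
import numpy as np

def pick_next_token(logits, temperature, rng):
    """Choose a token index from `logits` (a 1-D array of vocabulary scores).

    temperature == 0 -> greedy: always the highest-scoring token.
    temperature > 0  -> sample from softmax(logits / temperature).
    """
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])  # made-up scores for a 4-word vocabulary

greedy = [pick_next_token(logits, 0.0, rng) for _ in range(5)]
sampled = [pick_next_token(logits, 1.0, rng) for _ in range(5)]
print(greedy)   # [0, 0, 0, 0, 0]
print(sampled)  # varies with the seed; higher-scoring indices are more likely
```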

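Split the output back into tokens. Append [END] as well, since the model makes one more prediction, the end token, after the last word: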

str_tokens = result.split()
str_tokens.append('[END]')

The DecoderLayers each cache the attention scores for their CrossAttention layer. The shape of each attention map is (batch=1, heads, sequence, image):

attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]
[map.shape for map in attn_maps]
[TensorShape([1, 2, 9, 49]), TensorShape([1, 2, 9, 49])]

So stack the maps along the batch axis, then average over the (batch, heads) axes while splitting the image axis back into height and width:

attention_maps = tf.concat(attn_maps, axis=0)
attention_maps = einops.reduce(
    attention_maps,
    'batch heads sequence (height width) -> sequence height width',
    height=7, width=7,
    reduction='mean')
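If the einops pattern is unfamiliar, the same reduction can be written in plain NumPy, in three steps:

- Concatenate the maps along the batch axis.
- Reshape the 49-long image axis into a 7x7 grid, matching height=7, width=7 above.
- Average over the first two axes.

This is a minimal sketch that uses random stand-in scores with the same shapes as attn_maps. Normalizing them with a softmax over the image axis is an assumption about how the real scores are produced.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Stand-ins for the two cached maps: (batch=1, heads=2, sequence=9, image=49).
attn_maps = [softmax(rng.normal(size=(1, 2, 9, 49))) for _ in range(2)]

# Stack the two decoder layers along the batch axis -> (2, 2, 9, 49).
stacked = np.concatenate(attn_maps, axis=0)

# Split the image axis into height x width -> (2, 2, 9, 7, 7).
# '(height width)' in einops puts height as the outer axis, like a C-order reshape.
batch, heads, seq, _ = stacked.shape
grid = stacked.reshape(batch, heads, seq, 7, 7)

# Average over (batch, heads) -> (9, 7, 7), one map per output token.
attention_maps = grid.mean(axis=(0, 1))

print(attention_maps.shape)  # (9, 7, 7)
```

einops.reduce also accepts NumPy arrays, so running the same pattern string on stacked should return the same array.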

Now you have a single attention map for each sequence prediction. The values in each map should sum to 1:

einops.reduce(attention_maps, 'sequence height width -> sequence', reduction='sum')
<tf.Tensor: shape=(9,), dtype=float32, numpy=
array([1.       , 1.       , 0.9999999, 1.       , 1.       , 1.       ,
       1.       , 1.       , 1.       ], dtype=float32)>
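Why do the values sum to 1? Each head's cross-attention scores are (assuming the usual softmax attention) a probability distribution over the 49 image positions. The mean of several probability distributions is itself a probability distribution. The 0.9999999 above is just float32 rounding. Here is a small check with three made-up distributions over 4 positions:

```python
import numpy as np

# Three hypothetical attention distributions over 4 image positions
# (e.g. from different heads or layers); each row sums to 1.
dists = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.25, 0.25, 0.25, 0.25],
    [0.00, 0.50, 0.50, 0.00],
], dtype=np.float32)

averaged = dists.mean(axis=0)  # element-wise mean of the three rows

print(averaged)        # approximately [0.317 0.283 0.283 0.117]
print(averaged.sum())  # 1.0, up to float32 rounding
```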

So here is where the model was focusing attention while generating each token of the output:

def plot_attention_maps(image, str_tokens, attention_map):
    fig = plt.figure(figsize=(16, 9))

    len_result = len(str_tokens)
    # Lay the subplots out on a 3-row grid with at least 2 columns.
    grid_size = max(int(np.ceil(len_result/2)), 2)

    titles = []
    for i in range(len_result):
        # 7x7 attention map used while predicting token i.
        token_map = attention_map[i]
        ax = fig.add_subplot(3, grid_size, i+1)
        titles.append(ax.set_title(str_tokens[i]))
        img = ax.imshow(image)
        # Overlay the map, stretched to cover the whole image.
        ax.imshow(token_map, cmap='gray', alpha=0.6, extent=img.get_extent(),
                  clim=[0.0, np.max(token_map)])

    plt.tight_layout()

plot_attention_maps(image/255, str_tokens, attention_maps)

Plot: the image overlaid with the attention map for each generated token
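Passing extent=img.get_extent() makes matplotlib stretch the 7x7 attention map over the whole image, so each attention cell covers a block of pixels. This is a minimal NumPy sketch of roughly equivalent nearest-neighbour upsampling. It assumes a 224x224 input image, so each cell covers a 32x32 patch; the image size and the map values are assumptions for illustration.

```python
import numpy as np

# Hypothetical 7x7 attention map for one token; values are made up.
attn = np.arange(49, dtype=np.float32).reshape(7, 7)
attn /= attn.sum()

# Assumed input image size: 224x224, so each cell spans 224 // 7 = 32 pixels.
image_size = 224
scale = image_size // attn.shape[0]

# Nearest-neighbour upsampling: repeat every row and column `scale` times.
upsampled = np.repeat(np.repeat(attn, scale, axis=0), scale, axis=1)

print(upsampled.shape)  # (224, 224)
# Every pixel in the top-left 32x32 block carries the value of attn[0, 0].
print(np.all(upsampled[:32, :32] == attn[0, 0]))  # True
```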

Now put that together into a more usable function:

@Captioner.add_method
def run_and_show_attention(self, image, temperature=0.0):
  # Generate a caption; this also caches each layer's cross-attention scores.
  result_txt = self.simple_gen(image, temperature)
  str_tokens = result_txt.split()
  str_tokens.append('[END]')

  # Average the cached maps over layers and heads -> (sequence, 7, 7).
  attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]
  attention_maps = tf.concat(attention_maps, axis=0)
  attention_maps = einops.reduce(
      attention_maps,
      'batch heads sequence (height width) -> sequence height width',
      height=7, width=7,
      reduction='mean')

  plot_attention_maps(image/255, str_tokens, attention_maps)
  # Show the full caption above the grid of attention plots.
  t = plt.suptitle(result_txt)
  t.set_y(1.05)

run_and_show_attention(model, image)

Plot: attention maps for each token, titled with the generated caption
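run_and_show_attention is attached with Captioner.add_method, which is defined earlier in the tutorial. Assume it simply sets the function as a class attribute and returns it unchanged. Then the method can be called either as model.run_and_show_attention(image) or, as above, as a plain function: run_and_show_attention(model, image). Here is a toy sketch of that pattern, using a hypothetical Toy class and greet function:

```python
class Toy:
    @classmethod
    def add_method(cls, fun):
        """Attach `fun` to the class as a method, and return it unchanged."""
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, name):
        self.name = name


@Toy.add_method
def greet(self, greeting='hello'):
    return f'{greeting}, {self.name}'


toy = Toy('captioner')
print(toy.greet())                # hello, captioner
print(greet(toy, greeting='hi'))  # hi, captioner
```

Because add_method returns the original function, the module-level name keeps working after decoration.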

Try it on your own images

For fun, below is a method you can use to caption your own images with the model you've just trained. Keep in mind that it was trained on a relatively small amount of data, and your images may differ from the training data, so be prepared for strange results!

image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
image_path = tf.keras.utils.get_file(origin=image_url)
image = load_image(image_path)

run_and_show_attention(model, image)
Downloading data from https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg
67460/67460 [==============================] - 0s 0us/step

Plot: generated caption and attention maps for the downloaded bedroom image