Image captioning with visual attention


Given an image like the example below, your goal is to generate a caption such as "a surfer riding on a wave".

A man surfing, from wikimedia

The model architecture used here is inspired by Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, but has been updated to use a 2-layer Transformer-decoder. To get the most out of this tutorial you should have some experience with text generation, seq2seq models & attention, or transformers.

The model architecture built in this tutorial is shown below. Features are extracted from the image, and passed to the cross-attention layers of the Transformer-decoder.

The model architecture

The transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and it uses cross-attention to attend to the image.

By inspecting the attention weights of the cross-attention layers you will see what parts of the image the model is looking at as it generates words.

Prediction

This notebook is an end-to-end example. When you run the notebook, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images.

Setup

apt install --allow-change-held-packages libcudnn8=8.6.0.163-1+cuda11.8
pip uninstall -y tensorflow estimator keras
pip install -U tensorflow_text tensorflow tensorflow_datasets
pip install einops

This tutorial uses lots of imports, mostly for loading the dataset(s).
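The exact import cell lives in the notebook; a representative set, inferred from the code used throughout this tutorial, might look like this:

import collections

import einops
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as text
import tqdm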


[Optional] Data handling

This section downloads a captions dataset and prepares it for training. It tokenizes the input text, and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.
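Three artifacts from this step are used throughout the rest of the tutorial: the tokenizer, the mobilenet feature extractor, and the load_image helper. As a rough sketch (not the notebook's exact code), the extractor and loader could be built like this; MobileNetV3Small on 224×224 inputs yields the (7, 7, 576) feature maps shown below:

IMAGE_SHAPE = (224, 224, 3)

# A frozen, pretrained convolutional feature extractor.
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable = False

def load_image(image_path):
  # Read and decode a JPEG, resized to the extractor's input size.
  img = tf.io.read_file(image_path)
  img = tf.io.decode_jpeg(img, channels=3)
  img = tf.image.resize(img, IMAGE_SHAPE[:-1])
  return img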

Data ready for training

After those preprocessing steps, here are the datasets:

train_ds = load_dataset('train_cache')
test_ds = load_dataset('test_cache')
train_ds.element_spec
((TensorSpec(shape=(None, 7, 7, 576), dtype=tf.float32, name=None),
  TensorSpec(shape=(None, None), dtype=tf.int64, name=None)),
 TensorSpec(shape=(None, None), dtype=tf.int64, name=None))

The dataset now returns (input, label) pairs suitable for training with Keras. The inputs are (images, input_tokens) pairs. The images have been processed with the feature-extractor model. For each location in input_tokens, the model looks at the text so far and tries to predict the next token, which is lined up at the same location in the labels.

for (inputs, ex_labels) in train_ds.take(1):
  (ex_img, ex_in_tok) = inputs

print(ex_img.shape)
print(ex_in_tok.shape)
print(ex_labels.shape)
(32, 7, 7, 576)
(32, 23)
(32, 23)

The input tokens and the labels are the same, just shifted by 1 step:

print(ex_in_tok[0].numpy())
print(ex_labels[0].numpy())
[   3   14  446 1460   49    7    2   15  154 2064  205    0    0    0
    0    0    0    0    0    0    0    0    0]
[  14  446 1460   49    7    2   15  154 2064  205    4    0    0    0
    0    0    0    0    0    0    0    0    0]
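This shift is a standard construction. Given padded token tensors, the (input, label) pairs could be produced with a map along these lines (a sketch; the notebook builds them during the caching step above):

def to_input_and_label(images, tokens):
  # Inputs drop the last token; labels drop the first, so each input
  # position is paired with the token that follows it.
  input_tokens = tokens[..., :-1]
  label_tokens = tokens[..., 1:]
  return (images, input_tokens), label_tokens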

A Transformer decoder model

This model assumes that the pretrained image encoder is sufficient, and just focuses on building the text decoder. This tutorial uses a 2-layer Transformer-decoder.

The implementations are almost identical to those in the Transformers tutorial. Refer back to it for more details.

The Transformer encoder and decoder.

The model will be implemented in three main parts:

  1. Input - The token embedding and positional encoding (SeqEmbedding).
  2. Decoder - A stack of transformer decoder layers (DecoderLayer) where each contains:
    1. A causal self-attention layer (CausalSelfAttention), where each output location can attend to the output so far.
    2. A cross attention layer (CrossAttention) where each output location can attend to the input image.
    3. A feed forward network (FeedForward) layer which further processes each output location independently.
  3. Output - A multiclass-classification over the output vocabulary.

Input

The input text has already been split up into tokens and converted to sequences of IDs.

Remember that, unlike a CNN or RNN, the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, it just sees an unordered set, not a sequence. So in addition to a simple vector embedding for each token ID, the embedding layer will also include an embedding for each position in the sequence.

The SeqEmbedding layer defined below:

  • It looks up the embedding vector for each token.
  • It looks up an embedding vector for each sequence location.
  • It adds the two together.
  • It uses mask_zero=True to initialize the keras-masks for the model.
class SeqEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, max_length, depth):
    super().__init__()
    self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)

    self.token_embedding = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=depth,
        mask_zero=True)

    self.add = tf.keras.layers.Add()

  def call(self, seq):
    seq = self.token_embedding(seq) # (batch, seq, depth)

    x = tf.range(tf.shape(seq)[1])  # (seq)
    x = x[tf.newaxis, :]  # (1, seq)
    x = self.pos_embedding(x)  # (1, seq, depth)

    return self.add([seq,x])
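As a quick, illustrative shape check (the sizes here are placeholders, not the tutorial's real vocabulary):

embed = SeqEmbedding(vocab_size=5000, max_length=50, depth=256)
print(embed(tf.zeros([4, 23], dtype=tf.int64)).shape)  # (4, 23, 256)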

Decoder

The decoder is a standard Transformer-decoder. It contains a stack of DecoderLayers, where each contains three sublayers: a CausalSelfAttention, a CrossAttention, and a FeedForward. The implementations are almost identical to those in the Transformer tutorial; refer to it for more details.

The CausalSelfAttention layer is below:

class CausalSelfAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    # Use Add instead of + so the keras mask propagates through.
    self.add = tf.keras.layers.Add() 
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    attn = self.mha(query=x, value=x,
                    use_causal_mask=True)
    x = self.add([x, attn])
    return self.layernorm(x)

The CrossAttention layer is below. Note the use of return_attention_scores.

class CrossAttention(tf.keras.layers.Layer):
  def __init__(self,**kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add() 
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y, **kwargs):
    attn, attention_scores = self.mha(
             query=x, value=y,
             return_attention_scores=True)

    self.last_attention_scores = attention_scores

    x = self.add([x, attn])
    return self.layernorm(x)
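As an illustrative shape check (the sizes match the attention maps inspected later, but are otherwise arbitrary), here are 11 text positions attending over 49 flattened image locations:

cross = CrossAttention(num_heads=2, key_dim=256)
out = cross(tf.random.normal([1, 11, 256]), tf.random.normal([1, 49, 256]))
print(out.shape)                          # (1, 11, 256)
print(cross.last_attention_scores.shape)  # (1, 2, 11, 49)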

The FeedForward layer is below. Remember that a layers.Dense layer is applied to the last axis of the input. The input will have a shape of (batch, sequence, channels), so it automatically applies pointwise across the batch and sequence axes.

class FeedForward(tf.keras.layers.Layer):
  def __init__(self, units, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
        tf.keras.layers.Dense(units=2*units, activation='relu'),
        tf.keras.layers.Dense(units=units),
        tf.keras.layers.Dropout(rate=dropout_rate),
    ])

    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = x + self.seq(x)
    return self.layernorm(x)
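For example, this pointwise behavior means the layer works for any sequence length (illustrative sizes):

ff = FeedForward(units=256)
print(ff(tf.random.normal([4, 23, 256])).shape)  # (4, 23, 256)
print(ff(tf.random.normal([4, 50, 256])).shape)  # (4, 50, 256)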

Next arrange these three layers into a larger DecoderLayer. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of out_seq is (batch, sequence, channels). The decoder layer also returns the attention_scores for later visualizations.

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, units, num_heads=1, dropout_rate=0.1):
    super().__init__()

    self.self_attention = CausalSelfAttention(num_heads=num_heads,
                                              key_dim=units,
                                              dropout=dropout_rate)
    self.cross_attention = CrossAttention(num_heads=num_heads,
                                          key_dim=units,
                                          dropout=dropout_rate)
    self.ff = FeedForward(units=units, dropout_rate=dropout_rate)


  def call(self, inputs, training=False):
    in_seq, out_seq = inputs

    # Text input
    out_seq = self.self_attention(out_seq)

    out_seq = self.cross_attention(out_seq, in_seq)

    self.last_attention_scores = self.cross_attention.last_attention_scores

    out_seq = self.ff(out_seq)

    return out_seq

Output

At minimum the output layer needs a layers.Dense layer to generate logit-predictions for each token at each location.

But there are a few other features you can add to make this work a little better:

  1. Handle bad tokens: The model will be generating text. It should never generate a pad, unknown, or start token ('', '[UNK]', '[START]'). So set the bias for these to a large negative value.

  2. Smart initialization: The default initialization of a dense layer will give a model that initially predicts each token with almost uniform likelihood. The actual token distribution is far from uniform. The optimal value for the initial bias of the output layer is the log of each token's probability. So include an adapt method to count the tokens and set the optimal initial bias. This reduces the initial loss from the entropy of the uniform distribution (log(vocabulary_size)) to the marginal entropy of the distribution (-sum(p*log(p))); a sketch of such a layer follows this list.
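Here is a minimal sketch of what such a TokenOutput layer could look like (simplified; the notebook's actual implementation may differ in detail):

class TokenOutput(tf.keras.layers.Layer):
  def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]')):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units=tokenizer.vocabulary_size())
    self.tokenizer = tokenizer
    self.banned_tokens = banned_tokens
    self.bias = None

  def adapt(self, ds):
    # Count how often each token ID occurs in the dataset.
    counts = collections.Counter()
    vocab_dict = {name: id
                  for id, name in enumerate(self.tokenizer.get_vocabulary())}
    for tokens in tqdm.tqdm(ds):
      counts.update(tokens.numpy().flatten())

    counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
    counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = (
        list(counts.values()))

    # Banned tokens get zero probability, and later a large negative bias.
    for token in self.banned_tokens:
      counts_arr[vocab_dict[token]] = 0

    p = counts_arr / counts_arr.sum()
    p[counts_arr == 0] = 1.0  # avoid log(0); overwritten with -1e9 below
    log_p = np.log(p)

    print(f"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}")
    print(f"Marginal entropy: {-(p*log_p).sum():0.2f}")

    # The optimal initial bias for each token is log(p).
    self.bias = log_p
    self.bias[counts_arr == 0] = -1e9

  def call(self, x):
    return self.dense(x) + self.bias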

The smart initialization will significantly reduce the initial loss:

output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
# This might run a little faster if the dataset didn't also have to load the image data.
output_layer.adapt(train_ds.map(lambda inputs, labels: labels))
100%|██████████| 938/938 [00:02<00:00, 357.06it/s]
Uniform entropy: 8.52
Marginal entropy: 5.29

Build the model

To build the model, you need to combine several parts:

  1. The image feature_extractor and the text tokenizer.
  2. The seq_embedding layer, to convert batches of token-IDs to vectors (batch, sequence, channels).
  3. The stack of DecoderLayer layers that will process the text and image data.
  4. The output_layer which returns a pointwise prediction of what the next word should be.
class Captioner(tf.keras.Model):
  @classmethod
  def add_method(cls, fun):
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
               units=256, max_length=50, num_heads=1, dropout_rate=0.1):
    super().__init__()
    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer
    self.word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    self.index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True) 

    self.seq_embedding = SeqEmbedding(
        vocab_size=tokenizer.vocabulary_size(),
        depth=units,
        max_length=max_length)

    self.decoder_layers = [
        DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
        for n in range(num_layers)]

    self.output_layer = output_layer

When you call the model for training, it receives an (image, txt) pair. To make this function more usable, be flexible about the input:

  • If the image has 3 channels, run it through the feature_extractor. Otherwise assume that the features have already been extracted.
  • Similarly, if the text has dtype tf.string, run it through the tokenizer.

After that, running the model takes only a few steps:

  1. Flatten the extracted image features, so they can be input to the decoder layers.
  2. Look up the token embeddings.
  3. Run the stack of DecoderLayers, on the image features and text embeddings.
  4. Run the output layer to predict the next token at each position.
@Captioner.add_method
def call(self, inputs):
  image, txt = inputs

  if image.shape[-1] == 3:
    # Apply the feature-extractor, if you get an RGB image.
    image = self.feature_extractor(image)

  # Flatten the feature map
  image = einops.rearrange(image, 'b h w c -> b (h w) c')

  if txt.dtype == tf.string:
    # Apply the tokenizer if you get string inputs.
    txt = self.tokenizer(txt)

  txt = self.seq_embedding(txt)

  # Look at the image
  for dec_layer in self.decoder_layers:
    txt = dec_layer(inputs=(image, txt))

  txt = self.output_layer(txt)

  return txt
model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
                  units=256, dropout_rate=0.5, num_layers=2, num_heads=2)

Generate captions

Before getting into training, write a bit of code to generate captions. You'll use this to see how training is progressing.

Start by downloading a test image:

image_url = 'https://tensorflow.org/images/surf.jpg'
image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
image = load_image(image_path)
Downloading data from https://tensorflow.org/images/surf.jpg
64400/64400 [==============================] - 0s 1us/step

To caption an image with this model:

  • Extract the img_features
  • Initialize the list of output tokens with a [START] token.
  • Pass img_features and tokens into the model.
    • It returns a list of logits.
    • Choose the next token based on those logits.
    • Add it to the list of tokens, and continue the loop.
    • If it generates an '[END]' token, break out of the loop.

So add a "simple" method to do just that:

@Captioner.add_method
def simple_gen(self, image, temperature=1):
  initial = self.word_to_index([['[START]']]) # (batch, sequence)
  img_features = self.feature_extractor(image[tf.newaxis, ...])

  tokens = initial # (batch, sequence)
  for n in range(50):
    preds = self((img_features, tokens)).numpy()  # (batch, sequence, vocab)
    preds = preds[:,-1, :]  #(batch, vocab)
    if temperature==0:
      next = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
    else:
      next = tf.random.categorical(preds/temperature, num_samples=1)  # (batch, 1)
    tokens = tf.concat([tokens, next], axis=1) # (batch, sequence) 

    if next[0] == self.word_to_index('[END]'):
      break
  words = self.index_to_word(tokens[0, 1:-1])
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
  return result.numpy().decode()

Here are some generated captions for that image. The model is untrained, so they don't make much sense yet:

for t in (0.0, 0.5, 1.0):
  result = model.simple_gen(image, temperature=t)
  print(result)
a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a
a of
skateboard looking out dogs a white on in faces with sitting is black skateboarder tossed skateboard do a in two his surf white one jacket dogs across play

The temperature parameter allows you to interpolate between 3 modes:

  1. Greedy decoding (temperature=0.0) - Chooses the most likely next token at each step.
  2. Random sampling according to the logits (temperature=1.0).
  3. Uniform random sampling (temperature >> 1.0).

Since the model is untrained and it used the frequency-based initialization, the "greedy" output (first) usually contains only the most common tokens: ['a', '.', '[END]'].

Train

To train the model you'll need several additional components:

  • The Loss and metrics
  • The Optimizer
  • Optional Callbacks

Losses and metrics

Here's an implementation of a masked loss and accuracy:

When calculating the mask for the loss, note the loss < 1e8 condition: it discards the artificial, impossibly high losses for the banned_tokens.

def masked_loss(labels, preds):  
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)

  mask = (labels != 0) & (loss < 1e8) 
  mask = tf.cast(mask, loss.dtype)

  loss = loss*mask
  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
  return loss

def masked_acc(labels, preds):
  mask = tf.cast(labels!=0, tf.float32)
  preds = tf.argmax(preds, axis=-1)
  labels = tf.cast(labels, tf.int64)
  match = tf.cast(preds == labels, mask.dtype)
  acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)
  return acc
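A quick sanity check on toy data (illustrative only), confirming that padded positions are excluded:

toy_labels = tf.constant([[1, 2, 0, 0]])  # two real tokens, two pads
toy_preds = tf.random.normal([1, 4, 10])  # (batch, sequence, vocab)
print(masked_loss(toy_labels, toy_preds).numpy())
print(masked_acc(toy_labels, toy_preds).numpy())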

Callbacks

For feedback during training, set up a keras.callbacks.Callback to generate some captions for the surfer image at the end of each epoch.

class GenerateText(tf.keras.callbacks.Callback):
  def __init__(self):
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    self.image = load_image(image_path)

  def on_epoch_end(self, epochs=None, logs=None):
    print()
    print()
    for t in (0.0, 0.5, 1.0):
      result = self.model.simple_gen(self.image, temperature=t)
      print(result)
    print()

It generates three output strings. As in the earlier example, the first is "greedy", choosing the argmax of the logits at each step.

g = GenerateText()
g.model = model
g.on_epoch_end(0)
a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a
two through a a a in a
raft hockey a food pushed holding kicking two a

Also use callbacks.EarlyStopping to terminate training when the model starts to overfit.

callbacks = [
    GenerateText(),
    tf.keras.callbacks.EarlyStopping(
        patience=5, restore_best_weights=True)]

Train

Configure and execute the training.

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=masked_loss,
              metrics=[masked_acc])

For more frequent reporting, use the Dataset.repeat() method, and set the steps_per_epoch and validation_steps arguments to Model.fit.

With this setup, a full pass over the Flickr8k dataset takes 900+ batches, but the reporting-epochs below are 100 steps each.

history = model.fit(
    train_ds.repeat(),
    steps_per_epoch=100,
    validation_data=test_ds.repeat(),
    validation_steps=20,
    epochs=100,
    callbacks=callbacks)
Epoch 1/100
99/100 [============================>.] - ETA: 0s - loss: 5.0028 - masked_acc: 0.2014

a man in a man in a
a man in a woman a in a
a boy with the horizontal

100/100 [==============================] - 21s 110ms/step - loss: 5.0010 - masked_acc: 0.2017 - val_loss: 4.5998 - val_masked_acc: 0.2409
Epoch 2/100
 98/100 [============================>.] - ETA: 0s - loss: 4.5994 - masked_acc: 0.2569

a man in a white dog is in the water
a man is is in the in the brown and a
a man in a the the performs the dirt in the white groom table

100/100 [==============================] - 7s 67ms/step - loss: 4.5971 - masked_acc: 0.2573 - val_loss: 4.3980 - val_masked_acc: 0.2662
Epoch 3/100
 99/100 [============================>.] - ETA: 0s - loss: 4.4053 - masked_acc: 0.2733

a man in a red and a red and a red and a red and white dog is in the water
a man in a yellow shirt on a red and its in the water
a person sits while a yellow water on man in the soccer arms

100/100 [==============================] - 6s 65ms/step - loss: 4.3997 - masked_acc: 0.2739 - val_loss: 4.1910 - val_masked_acc: 0.2917
Epoch 4/100
 99/100 [============================>.] - ETA: 0s - loss: 4.2159 - masked_acc: 0.2962

a man in a red shirt is in the water
a man in a red yellow dog in the water
two child playing on the camera

100/100 [==============================] - 6s 57ms/step - loss: 4.2136 - masked_acc: 0.2963 - val_loss: 4.0654 - val_masked_acc: 0.3002
Epoch 5/100
100/100 [==============================] - ETA: 0s - loss: 4.1133 - masked_acc: 0.3101

a man in a red shirt is in the water
a man is running on a a ramp
a man rides a look in a one flips to

100/100 [==============================] - 5s 55ms/step - loss: 4.1133 - masked_acc: 0.3101 - val_loss: 3.9278 - val_masked_acc: 0.3084
Epoch 6/100
 99/100 [============================>.] - ETA: 0s - loss: 4.0273 - masked_acc: 0.3172

a man in a red shirt is jumping in the water
a man in the water
people in a large football swimming stand while a carry in a headphones

100/100 [==============================] - 5s 55ms/step - loss: 4.0270 - masked_acc: 0.3175 - val_loss: 3.8928 - val_masked_acc: 0.3188
Epoch 7/100
100/100 [==============================] - ETA: 0s - loss: 3.9356 - masked_acc: 0.3272

a man in a red shirt is in the water
a man in a blue pool
a man shirts in the stands in a snowy stop helmet holding a small boys while a beach

100/100 [==============================] - 6s 57ms/step - loss: 3.9356 - masked_acc: 0.3272 - val_loss: 3.8244 - val_masked_acc: 0.3310
Epoch 8/100
 99/100 [============================>.] - ETA: 0s - loss: 3.8725 - masked_acc: 0.3301

a man in a blue shirt is running in the water
a man in a blue shirt is wearing a blue shirt is running in the water
a person wearing in a surfboard has a to in swing in the with a outdoors

100/100 [==============================] - 6s 62ms/step - loss: 3.8699 - masked_acc: 0.3298 - val_loss: 3.7612 - val_masked_acc: 0.3305
Epoch 9/100
100/100 [==============================] - ETA: 0s - loss: 3.8222 - masked_acc: 0.3305

a man in a blue shirt is jumping in the water
a man in the water
a small boy in the through a cardboard valley

100/100 [==============================] - 5s 51ms/step - loss: 3.8222 - masked_acc: 0.3305 - val_loss: 3.6666 - val_masked_acc: 0.3392
Epoch 10/100
100/100 [==============================] - ETA: 0s - loss: 3.7326 - masked_acc: 0.3388

a man in a blue shirt is jumping in the water
a man is looks on a swimming pool
a catches a blue river off of a water with a tripod beak leaping

100/100 [==============================] - 6s 58ms/step - loss: 3.7326 - masked_acc: 0.3388 - val_loss: 3.5929 - val_masked_acc: 0.3495
Epoch 11/100
 98/100 [============================>.] - ETA: 0s - loss: 3.6332 - masked_acc: 0.3452

a man in a blue shirt is jumping in the water
a man in a blue is in a blue shirt is swimming pool
a person a girl in girl gallery

100/100 [==============================] - 6s 56ms/step - loss: 3.6333 - masked_acc: 0.3451 - val_loss: 3.5513 - val_masked_acc: 0.3508
Epoch 12/100
 99/100 [============================>.] - ETA: 0s - loss: 3.6076 - masked_acc: 0.3484

a man in a red shirt is jumping into the water
a person in a blue shirt is in the snow
the boy is an orange backpack in the wooded standing in the pool

100/100 [==============================] - 6s 56ms/step - loss: 3.6080 - masked_acc: 0.3483 - val_loss: 3.4182 - val_masked_acc: 0.3596
Epoch 13/100
100/100 [==============================] - ETA: 0s - loss: 3.5863 - masked_acc: 0.3506

a man in a red shirt is jumping into the water
a man in a red shirt is jumping into the water
the girl plays in a blue blue water

100/100 [==============================] - 5s 53ms/step - loss: 3.5863 - masked_acc: 0.3506 - val_loss: 3.4612 - val_masked_acc: 0.3501
Epoch 14/100
 98/100 [============================>.] - ETA: 0s - loss: 3.5353 - masked_acc: 0.3557

a man in a red shirt is jumping into the water
a man in a yellow shirt is jumping over a wave
a pool are surfing a grassy wave while being picture of a stick in the track

100/100 [==============================] - 6s 58ms/step - loss: 3.5324 - masked_acc: 0.3563 - val_loss: 3.4347 - val_masked_acc: 0.3573
Epoch 15/100
100/100 [==============================] - ETA: 0s - loss: 3.4977 - masked_acc: 0.3581

a man in a red shirt is jumping into the water
a young boy is doing a wave
a man is wearing a blue pants riding orange against a pool

100/100 [==============================] - 5s 53ms/step - loss: 3.4977 - masked_acc: 0.3581 - val_loss: 3.4759 - val_masked_acc: 0.3528
Epoch 16/100
 99/100 [============================>.] - ETA: 0s - loss: 3.4855 - masked_acc: 0.3578

a man in a blue shirt is in the water
a man is in a blue jacket
two adults playing in the water

100/100 [==============================] - 5s 49ms/step - loss: 3.4842 - masked_acc: 0.3579 - val_loss: 3.3827 - val_masked_acc: 0.3597
Epoch 17/100
 98/100 [============================>.] - ETA: 0s - loss: 3.4446 - masked_acc: 0.3613

a man in a blue shirt is jumping into the water
a man sits on the water
a couple riding the wave

100/100 [==============================] - 5s 47ms/step - loss: 3.4456 - masked_acc: 0.3611 - val_loss: 3.4095 - val_masked_acc: 0.3494
Epoch 18/100
100/100 [==============================] - ETA: 0s - loss: 3.3869 - masked_acc: 0.3632

a man in a red shirt is jumping into the water
a man in a wave
a is grey dog is pass pool lake

100/100 [==============================] - 5s 50ms/step - loss: 3.3869 - masked_acc: 0.3632 - val_loss: 3.2843 - val_masked_acc: 0.3672
Epoch 19/100
 98/100 [============================>.] - ETA: 0s - loss: 3.3967 - masked_acc: 0.3616

a man in a blue shirt is jumping into the water
a person in a blue shirt is jumping into a pool
a man in jeans is walking a frisbee in a water surfboard

100/100 [==============================] - 5s 55ms/step - loss: 3.3925 - masked_acc: 0.3623 - val_loss: 3.4035 - val_masked_acc: 0.3525
Epoch 20/100
 98/100 [============================>.] - ETA: 0s - loss: 3.3193 - masked_acc: 0.3707

a man in a red shirt is jumping into the water
a man in a red shirt is swimming pool
this two people are playing in the sleeves is him

100/100 [==============================] - 5s 53ms/step - loss: 3.3193 - masked_acc: 0.3703 - val_loss: 3.2862 - val_masked_acc: 0.3624
Epoch 21/100
 98/100 [============================>.] - ETA: 0s - loss: 3.2601 - masked_acc: 0.3770

a man in a red jacket is in the water
a man in a helmet is swimming pool
a doing a s being throwing solitary underneath water off of

100/100 [==============================] - 5s 53ms/step - loss: 3.2581 - masked_acc: 0.3767 - val_loss: 3.1930 - val_masked_acc: 0.3802
Epoch 22/100
 98/100 [============================>.] - ETA: 0s - loss: 3.2383 - masked_acc: 0.3797

a man in a red shirt is swimming pool
a man in a red shirt is swimming pool
a person smoking a surfboard in the water

100/100 [==============================] - 5s 52ms/step - loss: 3.2419 - masked_acc: 0.3790 - val_loss: 3.2384 - val_masked_acc: 0.3755
Epoch 23/100
 98/100 [============================>.] - ETA: 0s - loss: 3.2583 - masked_acc: 0.3730

a man in a red shirt is swimming pool
a girl in a red shirt is jumping into the water
a man in a red is in the trees

100/100 [==============================] - 5s 52ms/step - loss: 3.2562 - masked_acc: 0.3735 - val_loss: 3.2171 - val_masked_acc: 0.3659
Epoch 24/100
 99/100 [============================>.] - ETA: 0s - loss: 3.2180 - masked_acc: 0.3788

a man in a red shirt is swimming pool
a man in a red dress is swimming pool
a boy is swimming splashing in kayak

100/100 [==============================] - 5s 49ms/step - loss: 3.2206 - masked_acc: 0.3788 - val_loss: 3.2413 - val_masked_acc: 0.3650
Epoch 25/100
100/100 [==============================] - ETA: 0s - loss: 3.2149 - masked_acc: 0.3764

a man in a red jacket is riding a wave
a man in a red life jacket is being wave
a snowboarder poses in an orange wetsuit of a crown wave

100/100 [==============================] - 5s 53ms/step - loss: 3.2149 - masked_acc: 0.3764 - val_loss: 3.1733 - val_masked_acc: 0.3686
Epoch 26/100
100/100 [==============================] - ETA: 0s - loss: 3.1716 - masked_acc: 0.3861

a man in a red shirt is jumping into the water
a man in a blue suit is riding a wave
a man greyhound and the white and white played down a surfer riding people stand in the water

100/100 [==============================] - 6s 59ms/step - loss: 3.1716 - masked_acc: 0.3861 - val_loss: 3.0892 - val_masked_acc: 0.3820
Epoch 27/100
 99/100 [============================>.] - ETA: 0s - loss: 3.1653 - masked_acc: 0.3797

a man in a blue shirt is splashing in the water
a man in a pink shirt is jumping over a red wave
a young girl in a wave off a splash of a wave

100/100 [==============================] - 6s 55ms/step - loss: 3.1631 - masked_acc: 0.3797 - val_loss: 3.1378 - val_masked_acc: 0.3753
Epoch 28/100
100/100 [==============================] - ETA: 0s - loss: 3.1389 - masked_acc: 0.3861

a man in a red shirt is jumping into the water
a man is riding a wave
a man in yellow bathing suit with snow covered jeep

100/100 [==============================] - 5s 51ms/step - loss: 3.1389 - masked_acc: 0.3861 - val_loss: 3.1138 - val_masked_acc: 0.3809
Epoch 29/100
 99/100 [============================>.] - ETA: 0s - loss: 3.1237 - masked_acc: 0.3851

a man in a blue wetsuit is swimming pool
a man in a yellow surfboard
a man in a orange hat doing a in that during a body of the swimming pool with the waves

100/100 [==============================] - 6s 57ms/step - loss: 3.1235 - masked_acc: 0.3848 - val_loss: 3.1013 - val_masked_acc: 0.3758
Epoch 30/100
100/100 [==============================] - ETA: 0s - loss: 3.0536 - masked_acc: 0.3922

a man in a blue shirt is in a wave
a boy in a blue swimming pool
a young child is splashing in surfboard after a tan gear and digging in a person with his mouth

100/100 [==============================] - 6s 59ms/step - loss: 3.0536 - masked_acc: 0.3922 - val_loss: 3.0286 - val_masked_acc: 0.3932
Epoch 31/100
 99/100 [============================>.] - ETA: 0s - loss: 3.0457 - masked_acc: 0.3944

a man in a red shirt is swimming in the water
a man in a yellow shirt is going down a wave
a large lone man wearing a red stick in pink the ocean

100/100 [==============================] - 6s 57ms/step - loss: 3.0461 - masked_acc: 0.3946 - val_loss: 3.1659 - val_masked_acc: 0.3727
Epoch 32/100
 98/100 [============================>.] - ETA: 0s - loss: 3.0094 - masked_acc: 0.3935

a man in a red shirt is swimming in the ocean
a man in a life jacket is looking into the water
a surfer does a orange hat is riding a wave

100/100 [==============================] - 5s 53ms/step - loss: 3.0104 - masked_acc: 0.3933 - val_loss: 3.0844 - val_masked_acc: 0.3816
Epoch 33/100
 98/100 [============================>.] - ETA: 0s - loss: 3.0153 - masked_acc: 0.3932

a man in a yellow shirt is riding a wave
a man in a yellow kayak is wave
a woman in a wave as a surfer hanging the wave

100/100 [==============================] - 5s 52ms/step - loss: 3.0117 - masked_acc: 0.3934 - val_loss: 3.0415 - val_masked_acc: 0.3885
Epoch 34/100
 99/100 [============================>.] - ETA: 0s - loss: 3.0083 - masked_acc: 0.3977

a man in a red shirt is in a wave
a man in a yellow shirt is riding a surfboard
a man in a red wave

100/100 [==============================] - 5s 52ms/step - loss: 3.0061 - masked_acc: 0.3979 - val_loss: 3.0120 - val_masked_acc: 0.3901
Epoch 35/100
 98/100 [============================>.] - ETA: 0s - loss: 3.0024 - masked_acc: 0.3958

a man in a red shirt is surfing
a surfer rides a wave on a wave
a girl in a wetsuit is in a deep rapids

100/100 [==============================] - 5s 51ms/step - loss: 3.0009 - masked_acc: 0.3961 - val_loss: 2.9933 - val_masked_acc: 0.3956
Epoch 36/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9766 - masked_acc: 0.4011

a man in a blue wetsuit is surfing
a man in a wetsuit is surfing in a wave
a rider is flying through the water

100/100 [==============================] - 5s 49ms/step - loss: 2.9778 - masked_acc: 0.4008 - val_loss: 3.0575 - val_masked_acc: 0.3884
Epoch 37/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9700 - masked_acc: 0.3969

a man in a red shirt is surfing
a man in a helmet is splashing in the ocean
a mascot in a blue bathing pink pink shirt is wearing a red wetsuit

100/100 [==============================] - 5s 54ms/step - loss: 2.9713 - masked_acc: 0.3966 - val_loss: 3.0061 - val_masked_acc: 0.3910
Epoch 38/100
 99/100 [============================>.] - ETA: 0s - loss: 2.9407 - masked_acc: 0.3995

a man in a red shirt is surfing a wave
a man in a red jacket is jumping up a surfboard
a boy is is gets newspaper

100/100 [==============================] - 5s 52ms/step - loss: 2.9398 - masked_acc: 0.3991 - val_loss: 3.0022 - val_masked_acc: 0.3879
Epoch 39/100
 99/100 [============================>.] - ETA: 0s - loss: 2.8727 - masked_acc: 0.4091

a man in a red shirt is surfing a wave
a man in a red vest is riding a wave
a man wearing a red life jumps into a cliff

100/100 [==============================] - 5s 53ms/step - loss: 2.8710 - masked_acc: 0.4092 - val_loss: 2.9745 - val_masked_acc: 0.3874
Epoch 40/100
 98/100 [============================>.] - ETA: 0s - loss: 2.8862 - masked_acc: 0.4068

a man in a red shirt is surfing a wave
a man in a wetsuit is wearing a wave
a man and a man is riding a skateboard

100/100 [==============================] - 5s 51ms/step - loss: 2.8874 - masked_acc: 0.4064 - val_loss: 3.0524 - val_masked_acc: 0.3789
Epoch 41/100
 98/100 [============================>.] - ETA: 0s - loss: 2.8682 - masked_acc: 0.4089

a man in a red wetsuit is surfing
a person in a red kayak is jumping into a pool
a child is riding a wave in the ocean

100/100 [==============================] - 5s 51ms/step - loss: 2.8672 - masked_acc: 0.4093 - val_loss: 2.9746 - val_masked_acc: 0.3887
Epoch 42/100
100/100 [==============================] - ETA: 0s - loss: 2.8700 - masked_acc: 0.4101

a man in a red shirt is surfing on a wave
a man in a red hat is doing a wave
a man wearing a blue shirt is skateboarding

100/100 [==============================] - 5s 52ms/step - loss: 2.8700 - masked_acc: 0.4101 - val_loss: 3.0025 - val_masked_acc: 0.3923
Epoch 43/100
 99/100 [============================>.] - ETA: 0s - loss: 2.8751 - masked_acc: 0.4091

a man in a red wetsuit is surfing
a man in a red wetsuit is surfing a wave
a man wearing a red shirt geyser

100/100 [==============================] - 5s 49ms/step - loss: 2.8735 - masked_acc: 0.4095 - val_loss: 3.0148 - val_masked_acc: 0.3762
Epoch 44/100
100/100 [==============================] - ETA: 0s - loss: 2.8501 - masked_acc: 0.4093

a man in a red wetsuit is surfing
a man in a red life jacket is surfing
a man is swinging on the surfboard

100/100 [==============================] - 5s 50ms/step - loss: 2.8501 - masked_acc: 0.4093 - val_loss: 2.9621 - val_masked_acc: 0.3922
Epoch 45/100
 98/100 [============================>.] - ETA: 0s - loss: 2.8547 - masked_acc: 0.4087

a surfer in a red wetsuit is surfing
a surfer is in a wave
a surfer surfer in a wave

100/100 [==============================] - 5s 48ms/step - loss: 2.8574 - masked_acc: 0.4081 - val_loss: 2.9064 - val_masked_acc: 0.3952
Epoch 46/100
 98/100 [============================>.] - ETA: 0s - loss: 2.8426 - masked_acc: 0.4115

a man in a red wetsuit is surfing
a surfer in a red surfing in the water
a person in a red wetsuit crouches across the ocean

100/100 [==============================] - 5s 53ms/step - loss: 2.8438 - masked_acc: 0.4115 - val_loss: 2.9418 - val_masked_acc: 0.3855
Epoch 47/100
100/100 [==============================] - ETA: 0s - loss: 2.8170 - masked_acc: 0.4116

a man in a red wetsuit is surfing
a surfer is surfing on a wave
men sitting on his back board up against a surfboard

100/100 [==============================] - 5s 50ms/step - loss: 2.8170 - masked_acc: 0.4116 - val_loss: 2.9876 - val_masked_acc: 0.3937
Epoch 48/100
 99/100 [============================>.] - ETA: 0s - loss: 2.7442 - masked_acc: 0.4178

a man in a red wetsuit is surfing a wave
a man in a red hat is surfing
a person in a yellow blue yellow surfboard in the ocean

100/100 [==============================] - 5s 55ms/step - loss: 2.7423 - masked_acc: 0.4181 - val_loss: 2.8305 - val_masked_acc: 0.4088
Epoch 49/100
 98/100 [============================>.] - ETA: 0s - loss: 2.7574 - masked_acc: 0.4200

a man in a red wetsuit is surfing on a wave
a man in a red wetsuit is jumping up in the ocean
snow surfing in a red surfboard

100/100 [==============================] - 5s 52ms/step - loss: 2.7588 - masked_acc: 0.4197 - val_loss: 2.8785 - val_masked_acc: 0.3984
Epoch 50/100
 98/100 [============================>.] - ETA: 0s - loss: 2.7467 - masked_acc: 0.4176

a man in a red wetsuit is surfing
a surfer rides a wave in a wave
a surfer along a boat one leg

100/100 [==============================] - 5s 49ms/step - loss: 2.7484 - masked_acc: 0.4175 - val_loss: 2.8792 - val_masked_acc: 0.4074
Epoch 51/100
 99/100 [============================>.] - ETA: 0s - loss: 2.7406 - masked_acc: 0.4178

a man in a red wetsuit is surfing
a person is surfing in a wave
a surfer rides his head wave over a wave

100/100 [==============================] - 5s 50ms/step - loss: 2.7394 - masked_acc: 0.4178 - val_loss: 2.8424 - val_masked_acc: 0.3982
Epoch 52/100
 98/100 [============================>.] - ETA: 0s - loss: 2.7206 - masked_acc: 0.4225

a man in a red wetsuit is surfing on a wave
a man in a wetsuit is surfing on a wave
a man in a orange shirt balances a in the air above a wave

100/100 [==============================] - 6s 56ms/step - loss: 2.7190 - masked_acc: 0.4229 - val_loss: 2.9031 - val_masked_acc: 0.3995
Epoch 53/100
 98/100 [============================>.] - ETA: 0s - loss: 2.7171 - masked_acc: 0.4242

a surfer in a red wetsuit is surfing
a person in a red wetsuit is surfing on a wave
a adult child with a wetsuit in a blue water

100/100 [==============================] - 5s 54ms/step - loss: 2.7187 - masked_acc: 0.4236 - val_loss: 2.8908 - val_masked_acc: 0.3952

Plot the loss and accuracy over the training run:

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('CE/token')
plt.legend()
<matplotlib.legend.Legend at 0x7f4a244412e0>

png

plt.plot(history.history['masked_acc'], label='accuracy')
plt.plot(history.history['val_masked_acc'], label='val_accuracy')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()
<matplotlib.legend.Legend at 0x7f4a2428b640>

png

Attention plots

Now, using the trained model, run that simple_gen method on the image:

result = model.simple_gen(image, temperature=0.0)
result
'a man in a red wetsuit is surfing a wave'

Split the output back into tokens:

str_tokens = result.split()
str_tokens.append('[END]')

The DecoderLayers each cache the attention scores for their CrossAttention layer. The shape of each attention map is (batch=1, heads, sequence, image):

attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]
[map.shape for map in attn_maps]
[TensorShape([1, 2, 11, 49]), TensorShape([1, 2, 11, 49])]

So stack the maps along the batch axis, then average over the (batch, heads) axes, while splitting the image axis back into height, width:

attention_maps = tf.concat(attn_maps, axis=0)
attention_maps = einops.reduce(
    attention_maps,
    'batch heads sequence (height width) -> sequence height width',
    height=7, width=7,
    reduction='mean')

Now you have a single attention map, for each sequence prediction. The values in each map should sum to 1.

einops.reduce(attention_maps, 'sequence height width -> sequence', reduction='sum')
<tf.Tensor: shape=(11,), dtype=float32, numpy=
array([1.        , 0.99999994, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        ], dtype=float32)>

So here is where the model was focusing attention while generating each token of the output:

def plot_attention_maps(image, str_tokens, attention_map):
  fig = plt.figure(figsize=(16, 9))

  len_result = len(str_tokens)

  titles = []
  for i in range(len_result):
    map = attention_map[i]
    grid_size = max(int(np.ceil(len_result/2)), 2)
    ax = fig.add_subplot(3, grid_size, i+1)
    titles.append(ax.set_title(str_tokens[i]))
    img = ax.imshow(image)
    ax.imshow(map, cmap='gray', alpha=0.6, extent=img.get_extent(),
              clim=[0.0, np.max(map)])

  plt.tight_layout()
plot_attention_maps(image/255, str_tokens, attention_maps)

png

Now put that together into a more usable function:

@Captioner.add_method
def run_and_show_attention(self, image, temperature=0.0):
  result_txt = self.simple_gen(image, temperature)
  str_tokens = result_txt.split()
  str_tokens.append('[END]')

  attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]
  attention_maps = tf.concat(attention_maps, axis=0)
  attention_maps = einops.reduce(
      attention_maps,
      'batch heads sequence (height width) -> sequence height width',
      height=7, width=7,
      reduction='mean')

  plot_attention_maps(image/255, str_tokens, attention_maps)
  t = plt.suptitle(result_txt)
  t.set_y(1.05)
run_and_show_attention(model, image)

png

Try it on your own images

For fun, you can use the method below to caption your own images with the model you've just trained. Keep in mind that it was trained on a relatively small amount of data, and your images may differ from the training data (so be prepared for strange results!).

image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
image_path = tf.keras.utils.get_file(origin=image_url)
image = load_image(image_path)

run_and_show_attention(model, image)
Downloading data from https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg
67460/67460 [==============================] - 0s 1us/step

png