Decoding API

Overview

Auto-regressive language generation has been the subject of much recent research. In auto-regressive generation, the probability distribution of the token at time step K depends on the tokens the model predicted up to step K-1. For these models, decoding strategies such as Beam search, Greedy, Top-p, and Top-k are critical components and largely influence the style and nature of the output token generated at a given time step K.

For example, Beam search reduces the risk of missing hidden high-probability tokens by keeping the most likely num_beams hypotheses at each time step and eventually choosing the hypothesis with the highest overall probability. Murray et al. (2018) and Yang et al. (2018) show that beam search works well in Machine Translation tasks. Both Beam search and Greedy strategies can generate repeating tokens.

Fan et al. (2018) introduced Top-K sampling, in which the K most likely tokens are kept and the probability mass is redistributed among only those K tokens.
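The filtering step of Top-K sampling can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in sampling_module.py: the helper name top_k_filter and the toy distribution are hypothetical, and ties are broken by lower token index.

```python
def top_k_filter(probs, k):
    """Keep the k most likely tokens and renormalize their probabilities."""
    # Indices of the k largest probabilities (ties broken by lower index).
    keep = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:k]
    total = sum(probs[i] for i in keep)
    # Tokens outside the top k get zero mass; the rest are rescaled to sum to 1.
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


# Hypothetical next-token distribution over a 3-token vocabulary.
filtered = top_k_filter([0.1, 0.2, 0.7], k=2)
print(filtered)
```

The next token would then be sampled from the filtered distribution, so token 0 can never be chosen in this example.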

Holtzman et al. (2019) introduced Top-p sampling, which samples from the smallest possible set of tokens whose cumulative probability reaches the probability p. The probability mass is then redistributed among this set. This way, the size of the set can grow and shrink dynamically with the shape of the distribution. Top-p and Top-k are generally used in tasks such as story generation.
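The way the Top-p set changes size can be sketched in plain Python. This is an illustration under assumptions, not the library's implementation: the helper name top_p_filter is hypothetical, and the set is closed as soon as the cumulative probability reaches or exceeds p, a boundary choice that may differ from sampling_module.py.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of most likely tokens whose mass reaches p."""
    # Visit tokens from most to least likely (ties broken by lower index).
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:  # assumed boundary rule: stop once p is reached
            break
    total = sum(probs[i] for i in keep)
    # Renormalize the kept tokens; everything else gets zero mass.
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


probs = [0.5, 0.3, 0.15, 0.05]  # hypothetical next-token distribution
narrow = top_p_filter(probs, p=0.5)  # one token is enough to reach 0.5
wide = top_p_filter(probs, p=0.9)    # three tokens are needed to reach 0.9
print(narrow)
print(wide)
```

With the same distribution, a smaller p keeps a single token while a larger p keeps three, which is the dynamic set size described above.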

The Decoding API provides an interface to experiment with different decoding strategies on auto-regressive models.

  1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class: Top-p, Top-k, and Greedy.

  2. Beam search is provided in beam_search.py.

Setup

pip install -q -U tensorflow-text
pip install -q tf-models-nightly
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module
from official.nlp.modeling.ops import beam_search

Initialize Sampling Module in TF-NLP.

  • symbols_to_logits_fn : Use this closure to call the model to predict the logits for the index+1 step. Inputs and outputs for this closure are as follows:
Args:
  1] ids : Current decoded sequences. int tensor with shape (batch_size, index + 1), or (batch_size, 1) if padded_decode is True,
  2] index [scalar] : current decoded step,
  3] cache [nested dictionary of tensors] : Only used for faster decoding to store pre-computed attention hidden states for keys and values. More explanation in the cell below.
Returns:
  1] tensor for next-step logits [batch_size, vocab]
  2] the updated_cache [nested dictionary of tensors].

The cache is used for faster decoding. Here is a reference implementation for the above closure.

  • length_normalization_fn : Use this closure for returning length normalization parameter.
Args: 
  1] length : scalar for decoded step index.
  2] dtype : data-type of output tensor
Returns:
  1] value of length normalization factor.
  • vocab_size : Output vocabulary size.

  • max_decode_length : Scalar for total number of decoding steps.

  • eos_id : Decoding will stop if all output decoded ids in the batch have this eos_id.

  • padded_decode : Set this to True if running on TPU. Tensors are padded to max_decode_length if this is True.

  • top_k : top_k is enabled if this value is > 1.

  • top_p : top_p is enabled if this value is > 0 and < 1.0

  • sample_temperature : This is used to re-estimate the softmax output. The logits are divided by the temperature before the softmax, so the value has to be positive. A low temperature shifts mass towards high-probability tokens and away from the tail, making the distribution sharper and approaching greedy decoding; a high temperature makes it flatter.

  • enable_greedy : By default, this is True and greedy decoding is enabled. To experiment with other strategies, please set this to False.
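The effect of the temperature parameter described in the list above can be sketched in plain Python. This is a minimal illustration, assuming the usual formulation in which logits are divided by the temperature before the softmax; the helper name softmax_with_temperature and the toy logits are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature; temperature must be positive."""
    scaled = [x / temperature for x in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits, chosen so that temperature 1.0 gives (0.1, 0.2, 0.7).
logits = [math.log(p) for p in (0.1, 0.2, 0.7)]
sharp = softmax_with_temperature(logits, 0.5)     # mass moves to the top token
baseline = softmax_with_temperature(logits, 1.0)  # distribution unchanged
flat = softmax_with_temperature(logits, 2.0)      # mass spreads to the tail
print(sharp, baseline, flat)
```

As the temperature approaches zero, nearly all mass lands on the most likely token, which is why a low temperature behaves like greedy decoding.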

Initialize the Model hyperparameters

params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

In auto-regressive architectures such as Transformer-based encoder-decoder models, a cache is used for fast sequential decoding. It is a nested dictionary storing pre-computed hidden states (keys and values in the self-attention and cross-attention blocks) for every layer.

Initialize cache.

cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", cache['layer_1']['k'].shape)
cache key shape for layer 1 : (2, 4, 2, 128)

Define closure for length normalization if needed.

This is used for normalizing the final scores of generated sequences and is optional. With the exponent of 0.0 used below, the factor is always 1, so scores are left unchanged.

def length_norm(length, dtype):
  """Return length normalization factor."""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)
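To see what a non-zero exponent would do, here is a plain-Python sketch of the same formula. The parameter name alpha and the convention of dividing a summed log-probability by the factor are assumptions made for illustration; they are not taken from the closure above.

```python
import math


def length_penalty(length, alpha):
    """Length normalization factor ((5 + length) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob, length, alpha):
    # Assumed convention: divide the summed log-probability by the factor.
    return log_prob / length_penalty(length, alpha)


# Hypothetical sequence of three tokens with probabilities 0.4, 0.7, 0.8.
raw = math.log(0.4 * 0.7 * 0.8)
unchanged = normalized_score(raw, length=3, alpha=0.0)  # factor is 1
boosted = normalized_score(raw, length=3, alpha=0.6)    # factor is > 1
print(raw, unchanged, boosted)
```

Because log-probabilities are negative, dividing by a factor greater than 1 raises the score, which offsets the tendency of longer sequences to accumulate lower scores.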

Create model_fn

In practice, this will be replaced by an actual model implementation.

Args:
  i : Step that is being decoded.
Returns:
  probabilities of shape [batch_size, vocab_size]; symbols_to_logits_fn below converts them to logits with tf.math.log.
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def model_fn(i):
  return probabilities[:, i, :]

Initialize symbols_to_logits_fn

def _symbols_to_logits_fn():
  """Calculates logits of the next tokens."""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn

Greedy

Greedy decoding selects the token id with the highest probability as its next id: $id_t = \arg\max_{w} P(w \mid id_{1:t-1})$ at each time step $t$. The following example shows greedy decoding.

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=None,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("Greedy Decoded Ids:", ids)
Greedy Decoded Ids: tf.Tensor(
[[9 1 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)
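The decoded ids above can be checked by hand, because the toy model ignores the previous ids and returns a fixed row of probabilities per step. The plain-Python sketch below repeats the probabilities table as nested lists and takes the argmax at each step; the helper name greedy_decode is hypothetical and this is not how SamplingModule is implemented.

```python
# Same values as the `probabilities` constant, as nested Python lists:
# probabilities[batch][step] is the distribution over the 3-token vocabulary.
probabilities = [
    [[0.3, 0.4, 0.3], [0.3, 0.3, 0.4], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
    [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
]
initial_ids = [9, 1]


def greedy_decode(table, start_ids):
    """Append the most likely token id at every step, per batch element."""
    decoded = []
    for start, steps in zip(start_ids, table):
        seq = [start]
        for step_probs in steps:
            # argmax over the vocabulary for this step
            seq.append(max(range(len(step_probs)), key=step_probs.__getitem__))
        decoded.append(seq)
    return decoded


greedy_ids = greedy_decode(probabilities, initial_ids)
print(greedy_ids)  # [[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]]
```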

top_k sampling

In Top-K sampling, the K most likely next token ids are filtered and the probability mass is redistributed among only those K ids.

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-k sampled Ids:", ids)
top-k sampled Ids: tf.Tensor(
[[9 2 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)

top_p sampling

Instead of sampling only from the most likely K token ids, Top-p sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability p.

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-p sampled Ids:", ids)
top-p sampled Ids: tf.Tensor(
[[9 2 1 2 2]
 [1 0 1 2 2]], shape=(2, 5), dtype=int32)

Beam search decoding

Beam search reduces the risk of missing hidden high-probability token ids by keeping the most likely num_beams hypotheses at each time step (beam_size in the call below) and eventually choosing the hypothesis with the highest overall probability.

beam_size = 2
params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", beam_cache['layer_1']['k'].shape)
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False,
    dtype=tf.float32)
print("Beam search ids:", ids)
cache key shape for layer 1 : (1, 4, 2, 256)
Beam search ids: tf.Tensor(
[[[9 0 1 2 2]
  [9 1 2 2 2]]], shape=(1, 2, 5), dtype=int32)
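The core bookkeeping of beam search can be sketched in plain Python. This is a simplified illustration, not the algorithm in beam_search.py: it has no length normalization, no EOS handling, and no batching, and the helper name beam_search and the toy table step_probs are hypothetical. It is not meant to reproduce the output above.

```python
import math


def beam_search(step_probs, beam_size):
    """Keep the beam_size highest-scoring hypotheses after every step."""
    beams = [([], 0.0)]  # (token ids so far, summed log-probability)
    for probs in step_probs:
        # Extend every surviving hypothesis with every vocabulary token.
        candidates = [
            (seq + [tok], score + math.log(p))
            for seq, score in beams
            for tok, p in enumerate(probs)
        ]
        # Rank by summed log-probability and keep only the best beam_size.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


# Hypothetical per-step distributions over a 3-token vocabulary.
step_probs = [[0.5, 0.4, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
beams = beam_search(step_probs, beam_size=2)
for seq, score in beams:
    print(seq, round(math.exp(score), 3))
```

Scores are summed in log space so that long products of probabilities do not underflow; the first returned hypothesis is the one with the highest overall probability.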