
Implementation of beam search loop.

symbols_to_logits_fn: A function that provides logits; this is the interface to the Transformer model. Its arguments are:
  ids -> A tensor with shape [batch_size * beam_size, index].
  index -> A scalar.
  cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
  logits -> A tensor with shape [batch_size * beam_size, vocab_size].
  updated cache -> A nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for top-k computation.
beam_size: An integer, the number of beams for beam search.
alpha: A float defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode a sequence.
eos_id: An integer, the ID of the end-of-sentence token.
padded_decode: A bool indicating whether max_sequence_length padding is used for beam search.
dtype: A TensorFlow data type used for score computation. Defaults to tf.float32.
decoding_name: An optional name for the decoding loop tensors.
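The symbols_to_logits_fn contract above can be sketched in plain NumPy (the real interface passes TensorFlow tensors; VOCAB_SIZE and the toy scoring rule here are illustrative assumptions, not part of the library):

```python
import numpy as np

VOCAB_SIZE = 6  # toy vocabulary size; an assumption for illustration


def symbols_to_logits_fn(ids, index, cache):
    """Toy scorer matching the documented interface.

    ids:   decoded ids so far, shape [batch_size * beam_size, index].
    index: a scalar, the current decoding step.
    cache: a nested dictionary of per-beam values.
    Returns (logits with shape [batch_size * beam_size, vocab_size],
             updated cache with the same structure as the input cache).
    """
    last = ids[:, -1]
    # Strongly favor the token after the last decoded one, wrapping around.
    logits = np.full((ids.shape[0], VOCAB_SIZE), -1e9, dtype=np.float32)
    logits[np.arange(ids.shape[0]), (last + 1) % VOCAB_SIZE] = 0.0
    cache["steps"] = cache.get("steps", 0) + 1  # example cache update
    return logits, cache
```

In the real model, cache would hold attention key/value tensors of shape [batch_size * beam_size, ...] that the decoder reads and extends at each step.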



Beam search for sequences with highest scores.

initial_ids: Initial ids to pass into the symbols_to_logits_fn; an int tensor with shape [batch_size, 1].
initial_cache: A dictionary storing values to be passed into the symbols_to_logits_fn.

Returns finished_seq and finished_scores.
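Putting the arguments together, the loop they describe can be sketched in NumPy. This is a simplified illustration under stated assumptions, not the library's implementation: it always decodes to max_decode_length (no eos_id early stopping), does not reorder the cache when beams are gathered, and uses a plain length ** alpha divisor as a stand-in for the exact length-normalization penalty. The toy_fn scorer in the usage example is hypothetical.

```python
import numpy as np


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def beam_search(symbols_to_logits_fn, initial_ids, initial_cache,
                vocab_size, beam_size, alpha, max_decode_length, eos_id):
    """Minimal beam search sketch over the documented interface.

    Returns (finished_seq [batch_size, beam_size, max_decode_length + 1],
             finished_scores [batch_size, beam_size]).
    """
    batch = initial_ids.shape[0]
    # Replicate each batch item beam_size times: [batch * beam, 1].
    ids = np.repeat(initial_ids, beam_size, axis=0)
    # Only the first beam per batch item is alive at step 0.
    log_probs = np.tile([0.0] + [-1e9] * (beam_size - 1), batch)
    cache = initial_cache
    for i in range(max_decode_length):
        logits, cache = symbols_to_logits_fn(ids, i, cache)
        # Candidate scores for every (beam, next-token) pair.
        cand = log_probs[:, None] + log_softmax(logits)
        cand = cand.reshape(batch, beam_size * vocab_size)
        top = np.argsort(-cand, axis=1)[:, :beam_size]  # best expansions
        log_probs = np.take_along_axis(cand, top, axis=1).reshape(-1)
        # Gather the parent beams and append the chosen tokens.
        beam_idx = top // vocab_size + np.arange(batch)[:, None] * beam_size
        token = (top % vocab_size).reshape(-1, 1)
        ids = np.concatenate([ids[beam_idx.reshape(-1)], token], axis=1)
    length_penalty = float(ids.shape[1]) ** alpha  # simplified normalization
    finished_seq = ids.reshape(batch, beam_size, -1)
    finished_scores = (log_probs / length_penalty).reshape(batch, beam_size)
    return finished_seq, finished_scores


# Usage with a hypothetical scorer that always prefers token 1.
def toy_fn(ids, index, cache):
    logits = np.full((ids.shape[0], 4), -5.0, dtype=np.float32)
    logits[:, 1] = 0.0
    return logits, cache


seqs, scores = beam_search(toy_fn, np.zeros((2, 1), dtype=np.int64), {},
                           vocab_size=4, beam_size=3, alpha=0.6,
                           max_decode_length=5, eos_id=3)
```

Each beam's score is its accumulated token log-probability divided by the length penalty, so finished_scores ranks the beam_size decoded sequences per batch item.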