tfm.nlp.ops.sequence_beam_search

Searches for the sequence of subtoken ids with the largest probability.
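The search itself can be sketched in plain Python for a single batch item. This is a hypothetical illustration (toy_beam_search, BIGRAM and toy_next_log_probs are made up here), not the op's implementation. It has no batching, caching, or length normalization, and it ranks sequences by raw cumulative log-probability.

```python
import math
from typing import Callable, List, Tuple

Beam = Tuple[List[int], float]  # (token ids, cumulative log-probability)


def toy_beam_search(
    next_log_probs: Callable[[List[int]], List[float]],
    start_id: int,
    eos_id: int,
    beam_size: int,
    max_len: int,
) -> List[Beam]:
    """Return up to beam_size (ids, log_prob) pairs, best first."""
    alive: List[Beam] = [([start_id], 0.0)]
    finished: List[Beam] = []
    for _ in range(max_len):
        # Extend every alive beam by every token in the vocabulary.
        candidates: List[Beam] = []
        for seq, score in alive:
            for tok, lp in enumerate(next_log_probs(seq)):
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        # Move EOS-terminated candidates to `finished`; keep the best
        # beam_size unfinished candidates alive.
        alive = []
        for seq, score in candidates:
            if seq[-1] == eos_id:
                finished.append((seq, score))
            else:
                alive.append((seq, score))
            if len(alive) == beam_size:
                break
        if not alive:
            break
    # Unfinished beams still count if max_len was reached.
    results = finished + alive
    results.sort(key=lambda c: c[1], reverse=True)
    return results[:beam_size]


# Toy bigram model over ids {0, 1, 2}, keyed by the last id; 0 doubles as
# the start id and 2 is EOS.
BIGRAM = {0: [0.1, 0.5, 0.4], 1: [0.1, 0.3, 0.6]}


def toy_next_log_probs(seq: List[int]) -> List[float]:
    return [math.log(p) for p in BIGRAM[seq[-1]]]


best = toy_beam_search(toy_next_log_probs, start_id=0, eos_id=2,
                       beam_size=2, max_len=3)
```

Greedy decoding would commit to id 1 (p = 0.5) and end at [0, 1, 2] with probability 0.3; keeping two beams also retains [0, 2], whose probability of 0.4 is higher.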

symbols_to_logits_fn A function that takes ids, index, and cache as arguments, with these shapes: ids -> a tensor with shape [batch_size * beam_size, index]; index -> a scalar; cache -> a nested dictionary of tensors with shape [batch_size * beam_size, ...]. The function must return a tuple of logits and a new cache: logits -> a tensor with shape [batch_size * beam_size, vocab_size]; new cache -> a nested dictionary with the same shape and structure as the input cache.
initial_ids An int32 tensor with shape [batch_size]. Starting ids for each batch item.
initial_cache A dictionary containing the starting decoder variables; this is the cache passed to symbols_to_logits_fn on the first step.
vocab_size An integer, the size of the token vocabulary.
beam_size An integer, the number of beams.
alpha A float defining the strength of length normalization.
max_decode_length An integer, the maximum length of a decoded sequence.
eos_id An integer, the ID of the end-of-sequence (EOS) token, used to determine when a sequence has finished.
padded_decode A bool indicating whether max_sequence_length padding is used for beam search.
dtype A TensorFlow data type used for score computation. The default is tf.float32.
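The callback contract and the role of alpha can be sketched in plain Python, with nested lists standing in for tensors. Everything here is hypothetical: toy_symbols_to_logits_fn, VOCAB_SIZE, the "steps" cache key, and especially length_penalty, which assumes the GNMT-style factor ((5 + length) / 6) ** alpha commonly used in Transformer beam search rather than anything stated above.

```python
from typing import Dict, List, Tuple

VOCAB_SIZE = 4  # hypothetical vocabulary size for this sketch


def toy_symbols_to_logits_fn(
    ids: List[List[int]], index: int, cache: Dict[str, List[int]]
) -> Tuple[List[List[float]], Dict[str, List[int]]]:
    """Plain-list stand-in for the callback contract (not real tensors).

    ids has batch_size * beam_size rows; cache["steps"] has one entry per row.
    """
    # A real decoder might use `index` (e.g. to read the cache at the right
    # position); this toy rule only looks at the last id in each row.
    logits = [
        [2.0 if tok == (row[-1] + 1) % VOCAB_SIZE else 0.0
         for tok in range(VOCAB_SIZE)]
        for row in ids
    ]
    # new_cache keeps the same keys and the same number of rows as `cache`.
    new_cache = {"steps": [s + 1 for s in cache["steps"]]}
    return logits, new_cache


def length_penalty(length: int, alpha: float) -> float:
    # Assumed GNMT-style penalty; alpha = 0 disables normalization.
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob: float, length: int, alpha: float) -> float:
    # Dividing a negative log-probability by a penalty > 1 raises the score,
    # so a larger alpha offsets the bias toward short sequences.
    return log_prob / length_penalty(length, alpha)


batch_size, beam_size = 2, 3
ids = [[0] for _ in range(batch_size * beam_size)]
cache = {"steps": [0] * (batch_size * beam_size)}
logits, cache = toy_symbols_to_logits_fn(ids, 0, cache)
```

The flattened batch_size * beam_size leading dimension is what lets one callback invocation score every beam of every batch item at once.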

Returns a tuple of the top decoded sequences, with shape [batch_size, beam_size, max_decode_length], and the sequence scores, with shape [batch_size, beam_size].
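A common post-processing step is to keep one sequence per batch item. The sketch below works on nested Python lists (for example, outputs converted with `.numpy().tolist()`); pick_best and the sample values are hypothetical. It selects the top beam by score instead of assuming beam order, and it truncates at the first EOS. If your start id occupies position 0 of each sequence, you may also want to slice it off.

```python
from typing import List


def pick_best(
    decoded: List[List[List[int]]],  # [batch_size, beam_size, length]
    scores: List[List[float]],       # [batch_size, beam_size]
    eos_id: int,
) -> List[List[int]]:
    """Return the highest-scoring sequence per batch item, cut at EOS."""
    best = []
    for seqs, seq_scores in zip(decoded, scores):
        # Pick the top beam by score rather than assuming beam 0 is best.
        top = max(range(len(seq_scores)), key=lambda b: seq_scores[b])
        seq = seqs[top]
        # Keep only the ids before the first EOS; ignore whatever follows it.
        if eos_id in seq:
            seq = seq[: seq.index(eos_id)]
        best.append(seq)
    return best


decoded = [[[5, 7, 1, 0], [5, 8, 9, 1]],
           [[6, 1, 0, 0], [6, 2, 3, 4]]]
scores = [[-1.2, -0.7], [-0.3, -2.5]]
print(pick_best(decoded, scores, eos_id=1))  # [[5, 8, 9], [6]]
```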