tfm.nlp.ops.SamplingModule

Implementation for sampling strategies (go/decoding-tf-nlp).

Methods

generate

Implements the decoding strategy (beam_search or sampling).

Args
initial_ids: Initial IDs to pass into symbols_to_logits_fn; an int tensor with shape [batch_size, 1].
initial_cache: Dictionary for caching model outputs from the previous step.
initial_log_probs: Optional initial log probabilities, used when decoding starts from an existing prefix sequence.

Returns
A tuple of tensors (finished_sequence, finished_scores, first_cache): finished_sequence has shape [batch, max_seq_length], finished_scores has shape [batch], and first_cache is the cache after the initial token.
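The Args and Returns contract above can be illustrated with a minimal pure-Python sketch of greedy decoding, the simplest decoding strategy. This is not the library's implementation. Assumptions: the callback signature symbols_to_logits_fn(ids, step, cache) -> (logits, cache), the helper names greedy_generate, log_softmax, and toy_logits_fn, and the max_seq_length and eos_id parameters are all hypothetical, and nested lists stand in for tensors.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical callback: (ids so far, step, cache) -> (per-row logits, new cache).
LogitsFn = Callable[[List[List[int]], int, Dict], Tuple[List[List[float]], Dict]]


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def greedy_generate(
    symbols_to_logits_fn: LogitsFn,
    initial_ids: List[List[int]],  # shape [batch_size, 1]
    initial_cache: Dict,
    max_seq_length: int,
    eos_id: int,
    initial_log_probs: Optional[List[float]] = None,  # prefix scores, shape [batch]
) -> Tuple[List[List[int]], List[float], Optional[Dict]]:
    """Hypothetical sketch; returns (finished_sequence, finished_scores, first_cache)."""
    ids = [list(row) for row in initial_ids]
    # Start from the prefix log probs if given, otherwise from zero.
    if initial_log_probs is not None:
        scores = list(initial_log_probs)
    else:
        scores = [0.0] * len(ids)
    finished = [False] * len(ids)
    cache, first_cache = initial_cache, None
    for step in range(max_seq_length - 1):
        logits, cache = symbols_to_logits_fn(ids, step, cache)
        if first_cache is None:
            first_cache = dict(cache)  # cache after the initial token
        for b, row in enumerate(logits):
            if finished[b]:
                ids[b].append(eos_id)  # keep the shape fixed by padding with EOS
                continue
            log_probs = log_softmax(row)
            token = log_probs.index(max(log_probs))  # greedy: take the argmax
            ids[b].append(token)
            scores[b] += log_probs[token]
            finished[b] = token == eos_id
    return ids, scores, first_cache


# Toy model (hypothetical): 4-token vocabulary, each step favors (last_id + 1) % 4; EOS is 3.
def toy_logits_fn(ids, step, cache):
    new_cache = {"steps": cache.get("steps", 0) + 1}
    logits = [[5.0 if t == (row[-1] + 1) % 4 else 0.0 for t in range(4)] for row in ids]
    return logits, new_cache


seqs, scores, first_cache = greedy_generate(toy_logits_fn, [[0], [2]], {}, 5, eos_id=3)
print(seqs)         # [[0, 1, 2, 3, 3], [2, 3, 3, 3, 3]]
print(first_cache)  # {'steps': 1}
```

Replacing the argmax with a random draw from softmax(row) turns this into plain sampling; the bookkeeping for ids, scores, and cache stays the same.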

inf

Returns a value close to infinity that is still finite in dtype.

This is useful when you need a very large value whose product with zero is still zero. The floating-point Inf value, by contrast, yields NaN when multiplied by zero.
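A small sketch of why a large finite value is preferred over Inf, using plain Python floats. Here sys.float_info.max stands in for the value inf() returns (an assumption; the real value depends on dtype), and the mask_logits helper is hypothetical. It shows the common pattern of subtracting big * (1 - keep) from logits, where keep is 1.0 for kept positions.

```python
import math
import sys

BIG = sys.float_info.max  # largest finite float64; stand-in for inf() (assumption)
INF = float("inf")

# Multiplying by zero: the finite value vanishes, infinity becomes NaN.
print(BIG * 0.0)  # 0.0
print(INF * 0.0)  # nan


def mask_logits(logits, keep, big):
    """Hypothetical helper: push masked-out logits toward -big.

    keep[i] is 1.0 to keep logits[i] unchanged, 0.0 to mask it out.
    """
    return [x - big * (1.0 - k) for x, k in zip(logits, keep)]


logits = [2.0, 1.0, 0.5]
keep = [1.0, 0.0, 1.0]
print(mask_logits(logits, keep, BIG))  # [2.0, -1.7976931348623157e+308, 0.5]
print(mask_logits(logits, keep, INF))  # [nan, -inf, nan]
```

With the finite value, kept logits pass through untouched; with Inf, the kept positions are corrupted to NaN because Inf * 0 is NaN.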

Returns
A very large value.