tfa.seq2seq.GreedyEmbeddingSampler

A sampler for use during inference.

Inherits From: Sampler

Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
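The core idea of one greedy step can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the TFA implementation: `greedy_embedding_step` is a hypothetical helper, logits are assumed to have shape [batch_size, vocab_size], and plain row indexing stands in for `tf.nn.embedding_lookup`.

```python
import numpy as np

def greedy_embedding_step(logits, embedding):
    """Hypothetical helper: one greedy step (argmax, then embedding lookup).

    logits: float array [batch_size, vocab_size], treated as logits.
    embedding: float array [vocab_size, embed_dim].
    Returns (sample_ids, next_inputs).
    """
    # Greedy choice: the highest-scoring id for each batch element.
    sample_ids = np.argmax(logits, axis=-1).astype(np.int32)
    # Row lookup, a NumPy analogue of tf.nn.embedding_lookup(embedding, ids).
    next_inputs = embedding[sample_ids]
    return sample_ids, next_inputs

logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.0, 0.5]])
embedding = np.arange(6, dtype=np.float32).reshape(3, 2)  # vocab_size=3, embed_dim=2
ids, inputs = greedy_embedding_step(logits, embedding)
# ids == [1, 0]; inputs are rows 1 and 0 of the embedding matrix
```

Because the choice is a deterministic argmax, the same logits always produce the same next input; no randomness is involved.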

Args
embedding_fn An optional callable that takes a vector tensor of ids (the argmax ids). The returned tensor is passed to the decoder as its input. Defaults to tf.nn.embedding_lookup.
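The contract for embedding_fn is simply "vector of ids in, decoder inputs out". The sketch below, a NumPy illustration rather than TFA code, contrasts a default-style row lookup with a hypothetical custom function that scales the looked-up vectors by sqrt(embed_dim); both function names are invented for illustration. With TFA you would pass such a callable as `embedding_fn=` to the constructor (a Keras `Embedding` layer is one callable that fits this contract).

```python
import numpy as np

embedding = np.arange(6, dtype=np.float32).reshape(3, 2)  # vocab_size=3, embed_dim=2

def default_embedding_fn(ids):
    # Plain row lookup, analogous to the default tf.nn.embedding_lookup(embedding, ids).
    return embedding[ids]

def scaled_embedding_fn(ids):
    # Hypothetical custom embedding_fn: lookup followed by sqrt(embed_dim) scaling.
    return embedding[ids] * np.sqrt(embedding.shape[-1])

ids = np.array([2, 0], dtype=np.int32)     # one argmax id per batch element
default_inputs = default_embedding_fn(ids)  # shape [2, 2]
scaled_inputs = scaled_embedding_fn(ids)    # same shape, scaled values
```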

Attributes
batch_size The batch size of the tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is called; in that case, a ValueError is raised.

sample_ids_dtype DType of tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is called.

sample_ids_shape Shape of tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is called.
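The attribute behaviour described above can be mimicked with a small stand-in class. This is a hypothetical NumPy sketch, not the TFA class: it assumes greedy sampling yields int32 ids with an empty per-example shape (one scalar id per batch element), and that the batch size is only known once initialize() has seen start_tokens.

```python
import numpy as np

class AttributeSketch:
    """Hypothetical stand-in mirroring the documented attribute behaviour."""

    def __init__(self):
        self._batch_size = None  # unknown until initialize() runs

    @property
    def batch_size(self):
        # Mirrors the documented ValueError when accessed too early.
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size

    @property
    def sample_ids_dtype(self):
        # Greedy sampling produces integer token ids.
        return np.int32

    @property
    def sample_ids_shape(self):
        # One scalar id per batch element, so the per-example shape is empty.
        return ()

    def initialize(self, start_tokens):
        # The batch size is inferred from the start tokens.
        self._batch_size = int(np.size(start_tokens))
```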

Methods

initialize

Initialize the GreedyEmbeddingSampler.

Args
embedding A tensor containing the embedding matrix. It is used to embed start_tokens and the sampled ids into decoder inputs. The embedding is ignored if an embedding_fn was provided at construction (__init__()).
start_tokens int32 vector shaped [batch_size], the start tokens.
end_token int32 scalar, the token that marks end of decoding.

Returns
Tuple of two items: (finished, self.start_inputs).

Raises
ValueError if start_tokens is not a 1D tensor or end_token is not a scalar.
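The steps of initialize() can be sketched in NumPy as follows. This is an illustrative re-implementation under assumptions, not the TFA source: the function name and argument order are chosen for the sketch, and the embedding matrix is only consulted when no embedding_fn is supplied, matching the description above.

```python
import numpy as np

def initialize(embedding, start_tokens, end_token, embedding_fn=None):
    """NumPy sketch of initialize(): returns (finished, start_inputs)."""
    start_tokens = np.asarray(start_tokens, dtype=np.int32)
    end_token = np.asarray(end_token, dtype=np.int32)
    # Same validation as documented under Raises.
    if start_tokens.ndim != 1:
        raise ValueError("start_tokens must be a vector")
    if end_token.ndim != 0:
        raise ValueError("end_token must be a scalar")
    if embedding_fn is None:
        # Fall back to a plain row lookup into the embedding matrix.
        embedding_fn = lambda ids: np.asarray(embedding)[ids]
    batch_size = start_tokens.shape[0]
    # No sequence has finished before the first decoding step.
    finished = np.zeros(batch_size, dtype=bool)
    start_inputs = embedding_fn(start_tokens)
    return finished, start_inputs
```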

next_inputs

next_inputs_fn for GreedyEmbeddingHelper.
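A NumPy sketch of what next_inputs computes, assuming it marks a sequence as finished when its sampled id equals end_token and otherwise embeds the sampled ids. The function signature here is hypothetical and simplified: the decoder state, which a real sampler would pass through unchanged, is omitted.

```python
import numpy as np

def next_inputs(sample_ids, end_token, start_inputs, embedding_fn):
    """NumPy sketch of next_inputs(): returns (finished, next_inputs)."""
    # A sequence is finished when its greedy id equals end_token.
    finished = sample_ids == end_token
    if np.all(finished):
        # Every sequence is done; the next input value no longer matters,
        # so start_inputs is reused as a placeholder.
        return finished, start_inputs
    # Otherwise embed the sampled ids to form the next decoder inputs.
    return finished, embedding_fn(sample_ids)
```

Reusing start_inputs when everything has finished avoids an unnecessary embedding lookup whose result would be discarded anyway.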

sample

sample for GreedyEmbeddingHelper.
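Putting initialize, sample and next_inputs together, a toy greedy loop might look like the NumPy sketch below. Everything here is an illustrative assumption: `greedy_decode` and the toy `step_fn` are hypothetical, and in TFA the loop would instead be driven by a decoder such as tfa.seq2seq.BasicDecoder. The toy step uses one-hot embeddings and a roll so that token t always scores t + 1 highest.

```python
import numpy as np

def sample(logits):
    """NumPy sketch of sample(): greedy argmax over the last axis, as int32 ids."""
    return np.argmax(logits, axis=-1).astype(np.int32)

def greedy_decode(step_fn, embedding, start_tokens, end_token, max_steps=10):
    """Toy greedy loop; step_fn is a hypothetical decoder step mapping
    inputs [batch, dim] to logits [batch, vocab]."""
    embedding_fn = lambda ids: embedding[ids]
    start_tokens = np.asarray(start_tokens, dtype=np.int32)
    # initialize(): nothing finished, inputs are the embedded start tokens.
    finished = np.zeros(start_tokens.shape[0], dtype=bool)
    inputs = embedding_fn(start_tokens)
    outputs = []
    for _ in range(max_steps):
        ids = sample(step_fn(inputs))
        outputs.append(ids)
        # Once a sequence emits end_token it stays finished.
        finished = finished | (ids == end_token)
        if finished.all():
            break
        # next_inputs(): embed the sampled ids for the following step.
        inputs = embedding_fn(ids)
    return np.stack(outputs, axis=1)  # shape [batch, steps]

toy_embedding = np.eye(4, dtype=np.float32)       # one-hot "embeddings", vocab_size=4
shift_by_one = lambda x: np.roll(x, 1, axis=-1)   # toy step: token t scores t + 1 highest
decoded = greedy_decode(shift_by_one, toy_embedding, start_tokens=[0, 1], end_token=3)
# decoded rows: [1, 2, 3] and [2, 3, 0]
```

In the second row, the trailing 0 is produced after that sequence had already emitted end_token; a real decoder should treat such positions as padding and ignore them.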