
tfa.seq2seq.GreedyEmbeddingSampler


Class GreedyEmbeddingSampler

A sampler for use during inference.

Inherits From: Sampler


Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
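The loop below is a minimal pure-Python sketch of the greedy embedding strategy just described. It is not the TFA API. The vocabulary, the embedding matrix and the hypothetical `toy_decoder_step` (a stand-in for a decoder cell plus output layer) are all invented for illustration. Each step takes the argmax of the logits, then embeds the chosen id to form the next input.

```python
# Pure-Python sketch of greedy embedding decoding (not the TFA API).
# Assumptions: a made-up 4-token vocabulary, a 2-d embedding matrix, and a
# hypothetical toy_decoder_step standing in for a decoder cell + output layer.

EMBEDDING = [
    [0.0, 0.0],  # id 0: start token
    [1.0, 0.0],  # id 1
    [0.0, 1.0],  # id 2
    [1.0, 1.0],  # id 3: end token
]
START_TOKEN, END_TOKEN = 0, 3


def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def toy_decoder_step(inp):
    """Hypothetical decoder: hard-coded logits keyed on the current input vector."""
    if inp == EMBEDDING[0]:
        return [0.1, 2.0, 0.3, 0.0]  # argmax -> id 1
    if inp == EMBEDDING[1]:
        return [0.0, 0.2, 1.5, 0.4]  # argmax -> id 2
    return [0.0, 0.1, 0.2, 3.0]      # argmax -> id 3 (end token)


def greedy_decode(max_steps=10):
    inp = EMBEDDING[START_TOKEN]      # first input: embedding of the start token
    ids = []
    for _ in range(max_steps):
        logits = toy_decoder_step(inp)
        sample_id = argmax(logits)    # greedy choice: argmax of the logits
        ids.append(sample_id)
        if sample_id == END_TOKEN:    # stop once the end token is produced
            break
        inp = EMBEDDING[sample_id]    # embed the chosen id to get the next input
    return ids


print(greedy_decode())  # [1, 2, 3]
```

Because each step commits to the single most likely id, greedy decoding is cheap but cannot revisit earlier choices.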

__init__


__init__(embedding_fn=None)

Initializer.

Args:

  • embedding_fn: An optional callable that takes a vector tensor of ids (the argmax ids) and returns the tensor that is passed to the decoder as its next input. Defaults to tf.nn.embedding_lookup applied to the embedding passed to initialize().
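To illustrate the embedding_fn contract (ids in, one decoder input per id out), here is a pure-Python sketch with a made-up embedding matrix. `default_lookup` mimics a plain row lookup. `scaled_embedding_fn` is a hypothetical custom callable that scales each embedding by sqrt(d), a convention borrowed from Transformer models and used here only as an example. In real code the callable would operate on TensorFlow tensors and be passed as `GreedyEmbeddingSampler(embedding_fn=...)`.

```python
import math

# Hypothetical embedding matrix: 3 ids, 4-dimensional embeddings.
EMBEDDING = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 1.0, 0.0],
    [0.5, 0.5, 0.5, 0.5],
]


def default_lookup(ids):
    """Plain row lookup: one embedding row per id."""
    return [EMBEDDING[i] for i in ids]


def scaled_embedding_fn(ids):
    """Hypothetical custom embedding_fn: look up rows, then scale by sqrt(d)."""
    scale = math.sqrt(len(EMBEDDING[0]))  # d = 4, so scale = 2.0
    return [[x * scale for x in row] for row in default_lookup(ids)]


print(default_lookup([2, 1]))       # [[0.5, 0.5, 0.5, 0.5], [1.0, 0.0, 1.0, 0.0]]
print(scaled_embedding_fn([2, 1]))  # [[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 2.0, 0.0]]
```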

Properties

batch_size

Batch size of tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.

sample_ids_dtype

DType of tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is invoked.

sample_ids_shape

Shape of tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is invoked.
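The sketch below uses a hypothetical stand-in class, not the TFA class. It shows why batch_size can raise before initialize(): the batch size is only known once start_tokens has been seen. A greedy sampler produces one integer id per batch entry. So in this sketch sample_ids_shape is empty (the batch dimension is excluded), and a string stands in for the int32 DType.

```python
class SamplerPropertiesSketch:
    """Hypothetical stand-in illustrating the three properties."""

    def __init__(self):
        self._batch_size = None  # unknown until initialize() sees start_tokens

    def initialize(self, start_tokens):
        self._batch_size = len(start_tokens)

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_dtype(self):
        return "int32"  # a string stands in for a TensorFlow DType here

    @property
    def sample_ids_shape(self):
        return ()  # one scalar id per batch entry (batch dimension excluded)


sampler = SamplerPropertiesSketch()
try:
    sampler.batch_size
except ValueError as err:
    print("ValueError:", err)  # raised because initialize() has not run yet
sampler.initialize([0, 0, 0])
print(sampler.batch_size)  # 3
```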

Methods

initialize


initialize(
    embedding,
    start_tokens=None,
    end_token=None
)

Initialize the GreedyEmbeddingSampler.

Args:

  • embedding: A tensor containing the embedding matrix. It is used to look up the decoder inputs for start_tokens and for the sampled ids. It is ignored if embedding_fn was provided to __init__().
  • start_tokens: int32 vector shaped [batch_size], the start tokens.
  • end_token: int32 scalar, the token that marks end of decoding.

Returns:

Tuple of two items: (finished, self.start_inputs).

Raises:

  • ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.
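Here is a minimal pure-Python sketch of the initialize steps documented above. It represents tensors as nested lists and uses a plain row lookup as the embedding. It also assumes that no sequence starts out finished. The sketch validates the ranks of start_tokens and end_token, derives the batch size, and embeds the start tokens. `rank` and `initialize_sketch` are hypothetical helpers, not TFA functions.

```python
def rank(x):
    """Nesting depth of a list-based 'tensor' (0 for a plain scalar)."""
    r = 0
    while isinstance(x, list):
        r += 1
        x = x[0] if x else None
    return r


def initialize_sketch(embedding, start_tokens, end_token):
    if rank(start_tokens) != 1:
        raise ValueError("start_tokens must be a vector")
    if rank(end_token) != 0:
        raise ValueError("end_token must be a scalar")
    batch_size = len(start_tokens)
    finished = [False] * batch_size                       # assumption: nothing finished yet
    start_inputs = [embedding[t] for t in start_tokens]  # embed each start token
    return finished, start_inputs


EMBEDDING = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(initialize_sketch(EMBEDDING, [0, 0], 2))  # ([False, False], [[0.0, 0.0], [0.0, 0.0]])
```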

next_inputs


next_inputs(
    time,
    outputs,
    state,
    sample_ids
)

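Computes the next decoder inputs from sample_ids: each sampled id is embedded to form the next input, and a sequence is marked finished once its sample id equals end_token.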
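The following is a hedged pure-Python sketch of next_inputs. It makes two assumptions. First, the method returns a (finished, next_inputs, state) triple. Second, a sequence counts as finished once its sample id equals end_token. It omits time and outputs because greedy embedding does not need them here. `next_inputs_sketch` and its argument list are hypothetical.

```python
def next_inputs_sketch(embedding, start_inputs, end_token, state, sample_ids):
    # A sequence is finished once it has emitted the end token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every sequence is done, so the next input value no longer matters;
        # reuse the start inputs as a cheap placeholder.
        next_inputs = start_inputs
    else:
        # Embed the sampled ids to form the next decoder inputs.
        next_inputs = [embedding[i] for i in sample_ids]
    return finished, next_inputs, state  # state passes through unchanged


EMBEDDING = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
START_INPUTS = [[0.0, 0.0], [0.0, 0.0]]
print(next_inputs_sketch(EMBEDDING, START_INPUTS, 2, None, [1, 2]))
# ([False, True], [[1.0, 0.0], [0.0, 1.0]], None)
```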

sample


sample(
    time,
    outputs,
    state
)

Returns the greedy sample ids: the argmax of outputs, treated as logits, for each batch entry.
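Here is a pure-Python sketch of the greedy sample step. It assumes outputs is a batch of logit rows shaped [batch_size, vocab_size], and takes each row's argmax as that entry's sample id. `sample_sketch` is a hypothetical name, not part of TFA.

```python
def sample_sketch(outputs):
    """Greedy sample ids: argmax over the last axis of [batch_size, vocab_size] logits."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


logits = [[0.1, 2.5, -1.0], [3.0, 0.0, 0.2]]
print(sample_sketch(logits))  # [1, 0]
```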