tfa.seq2seq.SampleEmbeddingSampler


Class SampleEmbeddingSampler

A sampler for use during inference.

Inherits From: GreedyEmbeddingSampler


Uses sampling (from a distribution) instead of argmax and passes the result through an embedding layer to get the next input.
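The difference from the parent class can be sketched in plain Python. This is a minimal, framework-free illustration, not the TensorFlow Addons implementation. The helper names softmax and next_input, the toy embedding table, and the use of Python's random module are all assumptions made for the example. Greedy decoding always takes the argmax id, while sampling draws an id in proportion to softmax(logits). In both cases the chosen id is then mapped through the embedding table to form the next decoder input.

```python
import math
import random

def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def next_input(logits, embedding, rng, greedy=False):
    """Hypothetical helper: pick the next token id and return (id, embedding row).

    greedy=True mimics GreedyEmbeddingSampler (argmax of the logits);
    greedy=False mimics SampleEmbeddingSampler (draw from softmax(logits)).
    """
    if greedy:
        token_id = logits.index(max(logits))
    else:
        probs = softmax(logits)
        token_id = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    # The chosen id is looked up in the embedding table to form the next input.
    return token_id, embedding[token_id]

# Hypothetical 3-token vocabulary with 2-dimensional embeddings.
embedding = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
logits = [0.5, 2.0, 1.0]
rng = random.Random(0)

print(next_input(logits, embedding, rng, greedy=True))  # always (1, [1.0, 1.1])
print(next_input(logits, embedding, rng))  # any id 0-2, weighted by softmax(logits)
```

Repeated sampling mostly returns the highest-scoring id but occasionally picks the others, which is what makes sampled decoding more diverse than greedy decoding.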

__init__


__init__(
    embedding_fn=None,
    softmax_temperature=None,
    seed=None
)

Initializer.

Args:

  • embedding_fn: (Optional) A callable that takes a vector tensor of ids (here, the sampled ids) and returns the tensor passed to the decoder as its next input. If omitted, the embedding argument of initialize() is used instead.
  • softmax_temperature: (Optional) float32 scalar, value to divide the logits by before computing the softmax. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
  • seed: (Optional) The sampling seed.
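The effect of softmax_temperature can be checked numerically. The sketch below is a plain-Python assumption, not library code. The name temperature_softmax is hypothetical. It divides the logits by the temperature before the softmax, as the argument description states, and rejects non-positive temperatures.

```python
import math

def temperature_softmax(logits, temperature=1.0):
    """Hypothetical helper: softmax of logits / temperature."""
    if temperature <= 0:
        # Mirrors the documented requirement: strictly greater than 0.
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # stabilize exp()
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    probs = temperature_softmax(logits, t)
    print(t, [round(p, 3) for p in probs])
```

For these logits, the probability of the top token falls from roughly 0.87 at temperature 0.5 to roughly 0.51 at temperature 2.0. Lower temperatures move sampling toward the argmax, and higher ones spread probability across the other tokens.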

Raises:

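  • ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar. Both arguments are supplied to initialize(), not to the constructor, so this check takes effect when initialize() is called.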

Properties

batch_size

Batch size of the tensor returned by sample.

Returns a scalar int32 tensor. The value might not be available before initialize() is invoked; in that case, a ValueError is raised.

sample_ids_dtype

DType of the tensor returned by sample.

Returns a DType. The value might not be available before initialize() is invoked.

sample_ids_shape

Shape of the tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The value might not be available before initialize() is invoked.

Methods

initialize


initialize(
    embedding,
    start_tokens=None,
    end_token=None
)

Initialize the sampler. This method is inherited from GreedyEmbeddingSampler.

Args:

  • embedding: tensor containing the embedding matrix. It is used to look up the decoder inputs for start_tokens and for the ids produced by sample. It is ignored if embedding_fn was provided to __init__().
  • start_tokens: int32 vector shaped [batch_size], the start tokens.
  • end_token: int32 scalar, the token that marks end of decoding.

Returns:

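A tuple (finished, start_inputs): finished is a boolean vector shaped [batch_size] that is all False, and start_inputs holds the decoder inputs for the first step (the embeddings of start_tokens).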

Raises:

  • ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.
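The contract of initialize() can be sketched without TensorFlow. This is a hypothetical plain-Python analogue, assuming lists stand in for tensors. It is not the library's code, and its signature differs from the real method. It validates the inputs as described under Raises, marks every sequence as unfinished, and embeds the start tokens to produce the first decoder inputs.

```python
def initialize(embedding, start_tokens, end_token):
    """Hypothetical analogue of initialize(); returns (finished, start_inputs).

    embedding: list of rows (vocab_size x dim); start_tokens: list of ints,
    one per batch entry; end_token: a single int.
    """
    # start_tokens must be a 1-D sequence of token ids.
    if not isinstance(start_tokens, (list, tuple)) or any(
        not isinstance(t, int) for t in start_tokens
    ):
        raise ValueError("start_tokens must be a vector of ids")
    # end_token must be a scalar.
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar")
    batch_size = len(start_tokens)
    # No sequence has finished before the first decoding step.
    finished = [False] * batch_size
    # The first decoder inputs are the embeddings of the start tokens.
    start_inputs = [embedding[t] for t in start_tokens]
    return finished, start_inputs

embedding = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
finished, start_inputs = initialize(embedding, [0, 0], end_token=2)
print(finished, start_inputs)  # [False, False] [[0.0, 0.0], [0.0, 0.0]]
```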

next_inputs


next_inputs(
    time,
    outputs,
    state,
    sample_ids
)

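Compute the inputs for the next decoding step (inherited from GreedyEmbeddingSampler). Returns a tuple (finished, next_inputs, state): finished marks the batch entries whose sample_ids equal end_token, next_inputs are the embeddings of sample_ids, and state is passed through unchanged.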
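A rough plain-Python picture of the step that next_inputs performs, assuming lists stand in for tensors. The function below is a hypothetical sketch with its own signature, not the library's method. One extra detail is assumed here: when every sequence has finished, the next inputs are never used, so the sketch simply returns the start inputs instead of embedding the ids.

```python
def next_inputs(sample_ids, end_token, embedding, start_inputs, state):
    """Hypothetical analogue of next_inputs(); returns (finished, inputs, state)."""
    # A sequence is finished once its sampled id equals end_token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every sequence is done; these inputs are never consumed.
        next_in = start_inputs
    else:
        # Feed the embeddings of the sampled ids back to the decoder.
        next_in = [embedding[i] for i in sample_ids]
    # The decoder state is passed through unchanged.
    return finished, next_in, state

embedding = [[0.0], [1.0], [2.0]]
start = [[0.0], [0.0]]
print(next_inputs([1, 2], 2, embedding, start, state="s"))
# ([False, True], [[1.0], [2.0]], 's')
```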

sample


sample(
    time,
    outputs,
    state
)

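Sample the next token ids. outputs is treated as logits. If softmax_temperature was set, the logits are divided by it, and one id per batch entry is drawn from the resulting categorical distribution using seed. time and state are not used.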
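A batched version of this sampling step, written as a plain-Python sketch. The function name, the list-of-lists input, and the use of random.Random as the seeded generator are assumptions. The real class uses TensorFlow random ops, whose seeding rules differ. The sketch draws one id per row from softmax(logits / temperature). With the same seed it returns the same ids, and a very small temperature makes it behave like argmax.

```python
import math
import random

def sample(logits_batch, softmax_temperature=None, seed=None):
    """Hypothetical sketch: draw one id per batch row from softmax(logits / T)."""
    rng = random.Random(seed)  # same seed -> same sequence of draws
    ids = []
    for logits in logits_batch:
        if softmax_temperature is not None:
            logits = [z / softmax_temperature for z in logits]
        m = max(logits)
        # Unnormalized softmax weights; random.choices normalizes them itself.
        weights = [math.exp(z - m) for z in logits]
        ids.append(rng.choices(range(len(logits)), weights=weights, k=1)[0])
    return ids

batch = [[0.1, 3.0, 0.2], [2.5, 0.0, 0.0]]
print(sample(batch, softmax_temperature=0.7, seed=7))
```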