An inference sampler that randomly samples from the output distribution.
Inherits From: `GreedyEmbeddingSampler`, `Sampler`
tfa.seq2seq.SampleEmbeddingSampler(
embedding_fn: Optional[Callable] = None,
softmax_temperature: Optional[TensorLike] = None,
seed: Optional[TensorLike] = None
)
Uses sampling (from a distribution) instead of argmax and passes the result through an embedding layer to get the next input.
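The contrast with greedy decoding can be made concrete with a minimal plain-Python sketch. This is not the TensorFlow Addons implementation: the helpers `softmax`, `greedy_id` and `sampled_id`, the 4-token vocabulary and the embedding table are all hypothetical, and lists stand in for tensors. Greedy decoding takes the argmax of the logits. This sampler instead draws an id from softmax(logits) and then looks that id up in the embedding to form the next input.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_id(logits):
    """Greedy decoding: the index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sampled_id(logits, rng):
    """Sampling: draw one index from the categorical distribution softmax(logits)."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Hypothetical 4-token vocabulary with 2-dimensional embeddings.
embedding = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
logits = [0.5, 2.0, 1.5, -1.0]

rng = random.Random(0)
token = sampled_id(logits, rng)
next_input = embedding[token]  # embedding lookup yields the next decoder input
print("greedy:", greedy_id(logits), "sampled:", token, "next input:", next_input)
```

Greedy decoding always returns the same id for the same logits, while repeated sampling can return any id with nonzero probability.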
Args | |
---|---|
`embedding_fn` | (Optional) A callable that takes a vector tensor of `ids` (the sampled ids). The returned tensor will be passed to the decoder input.
`softmax_temperature` | (Optional) `float32` scalar, the value to divide the logits by before computing the softmax. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
`seed` | (Optional) The sampling seed.
Raises | |
---|---|
`ValueError` | If `start_tokens` is not a 1D tensor or `end_token` is not a scalar.
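The effect of `softmax_temperature` can be sketched in plain Python, assuming the logits are a simple list of floats. The helpers `softmax_with_temperature` and `entropy` are hypothetical and not part of the library. Dividing by a temperature below 1.0 sharpens the distribution toward the argmax. A temperature above 1.0 flattens it, which shows up as rising entropy.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        # Mirrors the documented constraint: must be strictly greater than 0.
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a more random distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs], round(entropy(probs), 3))
```

A temperature of exactly 1.0 leaves the logits unchanged, which is why it is the documented default.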
Attributes | |
---|---|
`batch_size` | Batch size of the tensor returned by `sample`. Returns a scalar `int32` tensor. The return value might not be available before `initialize()` is invoked; in that case, `ValueError` is raised.
`sample_ids_dtype` | DType of the tensor returned by `sample`. Returns a `DType`. The return value might not be available before `initialize()` is invoked.
`sample_ids_shape` | Shape of the tensor returned by `sample`, excluding the batch dimension. Returns a
Methods
initialize
initialize(
embedding, start_tokens=None, end_token=None
)
Initialize the GreedyEmbeddingSampler.
Args | |
---|---|
`embedding` | Tensor that contains the embedding states matrix. It will be used to generate outputs with `start_tokens` and `end_token`. The embedding will be ignored if `embedding_fn` was provided to the constructor.
`start_tokens` | `int32` vector shaped `[batch_size]`, the start tokens.
`end_token` | `int32` scalar, the token that marks the end of decoding.
Returns | |
---|---|
Tuple of two items: `(finished, self.start_inputs)`. |
Raises | |
---|---|
`ValueError` | If `start_tokens` is not a 1D tensor or `end_token` is not a scalar.
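The `initialize` contract can be mirrored with a hypothetical pure-Python class, `ToyEmbeddingSampler`, as an illustration only. Lists stand in for tensors, and the validation checks are simplified stand-ins for the 1D/scalar shape checks described above. The class also shows why `batch_size` is unavailable, and raises `ValueError`, until `initialize()` has run.

```python
class ToyEmbeddingSampler:
    """Hypothetical pure-Python stand-in for the documented initialize() contract."""

    def __init__(self, embedding_fn=None):
        self.embedding_fn = embedding_fn
        self._batch_size = None  # unknown until initialize() runs

    @property
    def batch_size(self):
        # Mirrors the documented behavior: unavailable before initialize().
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size

    def initialize(self, embedding, start_tokens=None, end_token=None):
        # start_tokens must be 1-D: a flat sequence of ints, one per batch entry.
        if not isinstance(start_tokens, (list, tuple)) or not all(
            isinstance(t, int) for t in start_tokens
        ):
            raise ValueError("start_tokens must be a 1-D sequence of ints")
        # end_token must be a scalar.
        if not isinstance(end_token, int):
            raise ValueError("end_token must be a scalar int")
        if self.embedding_fn is None:
            if embedding is None:
                raise ValueError("embedding is required when no embedding_fn was given")
            # Fall back to a row lookup in the embedding matrix.
            self.embedding_fn = lambda ids: [embedding[i] for i in ids]
        self.end_token = end_token
        self._batch_size = len(start_tokens)
        self.start_inputs = self.embedding_fn(list(start_tokens))
        finished = [False] * self._batch_size  # nothing is finished at step 0
        return finished, self.start_inputs
```

Passing an `embedding_fn` to the constructor makes the `embedding` argument irrelevant, matching the note in the Args table.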
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
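`next_inputs_fn` for `GreedyEmbeddingSampler`, inherited by this class. Marks a batch entry as finished when its `sample_ids` value equals `end_token` and, while any entry is still unfinished, passes `sample_ids` through `embedding_fn` to produce the next decoder inputs. Returns `(finished, next_inputs, state)`.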
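A sketch of the `next_inputs` step, again in plain Python with lists in place of tensors. The function name and argument order are hypothetical. Returning `start_inputs` once every entry is finished is an assumption about a reasonable placeholder, since at that point the next input is never consumed.

```python
def next_inputs(sample_ids, end_token, embedding_fn, start_inputs, state):
    """Hypothetical pure-Python mirror of the next_inputs step."""
    # An entry is finished once it emits the end token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Assumption: once every entry is done, the next input is never used,
        # so any well-shaped placeholder (here start_inputs) will do.
        next_in = start_inputs
    else:
        # Otherwise embed the sampled ids to feed the next decoding step.
        next_in = embedding_fn(sample_ids)
    return finished, next_in, state
```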
sample
sample(
time, outputs, state
)
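`sample` for `SampleEmbeddingSampler`. Treats `outputs` as logits, divides them by `softmax_temperature` if one was given, and draws one id per batch entry from the resulting categorical distribution, seeded with `seed`.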
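The batched sampling step can be sketched as follows, assuming each row of `batch_logits` holds one batch entry's logits. `sample_batch` is a hypothetical helper, and `random.Random(seed)` stands in for TensorFlow's seeded categorical sampling. Passing the same seed reproduces the same draws, and a very small temperature makes the result match the argmax.

```python
import math
import random


def sample_batch(batch_logits, softmax_temperature=None, seed=None):
    """Draw one id per batch row from softmax(logits / temperature)."""
    rng = random.Random(seed)  # same seed -> same sequence of draws
    ids = []
    for logits in batch_logits:
        if softmax_temperature is not None:
            logits = [x / softmax_temperature for x in logits]
        # Unnormalized softmax weights; random.choices normalizes them itself.
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        ids.append(rng.choices(range(len(weights)), weights=weights, k=1)[0])
    return ids


batch = [[0.1, 0.2, 0.3], [1.0, 0.0, -1.0]]
print(sample_batch(batch, softmax_temperature=1.5, seed=42))
```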