A helper for use during inference.
Inherits From: GreedyEmbeddingHelper
tf.contrib.seq2seq.SampleEmbeddingHelper(
embedding, start_tokens, end_token, softmax_temperature=None, seed=None
)
Samples ids from a categorical distribution over the logits, instead of taking the argmax as GreedyEmbeddingHelper does, and passes the result through an embedding layer to get the next input.
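
The difference from GreedyEmbeddingHelper can be illustrated with a small, framework-free sketch in plain Python. It is not the TensorFlow implementation: the names softmax, greedy_id, sampled_id, embed and the toy EMBEDDING table are hypothetical stand-ins, and Python's random.choices plays the role of the categorical sampler.

```python
import math
import random


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_id(logits):
    # Greedy choice: the index of the largest logit.
    return max(range(len(logits)), key=lambda i: logits[i])


def sampled_id(logits, rng):
    # Sampled choice: index i is drawn with probability softmax(logits)[i].
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]


# Hypothetical toy embedding table: 3 vocabulary ids, 2-dimensional vectors.
EMBEDDING = [[0.0, 0.0], [1.0, 0.5], [-1.0, 2.0]]


def embed(ids):
    # Plays the role of embedding_lookup(EMBEDDING, ids): one row per id.
    return [EMBEDDING[i] for i in ids]


rng = random.Random(0)
batch_logits = [[2.0, 1.0, 0.1], [0.1, 0.2, 3.0]]
greedy_ids = [greedy_id(row) for row in batch_logits]
sampled_ids = [sampled_id(row, rng) for row in batch_logits]
next_inputs = embed(sampled_ids)  # what would be fed back to the decoder
print(greedy_ids)   # [0, 2]
print(sampled_ids)  # varies with the random draw
```

Greedy selection always returns the same ids for the same logits, whereas sampled ids can differ between draws; in both cases the chosen ids are what gets embedded and fed back as the next input.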

Args | |
---|---|
embedding | A callable that takes a vector tensor of ids (for this helper, the sampled ids), or the params argument for embedding_lookup. The returned tensor will be passed to the decoder input.
start_tokens | int32 vector shaped [batch_size], the start tokens.
end_token | int32 scalar, the token that marks the end of decoding.
softmax_temperature | (Optional) float32 scalar, the value to divide the logits by before computing the softmax. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
seed | (Optional) The sampling seed.
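
To make the effect of softmax_temperature concrete, the following plain-Python sketch (an illustration, not TensorFlow code; softmax_with_temperature is a hypothetical name) divides the logits by the temperature and normalizes them. Defaulting to 1.0 and rejecting non-positive values follows the parameter description for softmax_temperature.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    # The temperature must be strictly positive.
    if temperature <= 0:
        raise ValueError("softmax_temperature must be strictly greater than 0")
    # Divide the logits by the temperature before the softmax.
    scaled = [x / temperature for x in logits]
    # Subtract the maximum for numerical stability, then normalize.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

At temperature 0.5 the largest logit takes a bigger share of the probability mass than at 1.0, and at 2.0 the distribution is flatter, which is why higher temperatures give more varied samples.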

Raises | |
---|---|
ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar.
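
The two ValueError conditions amount to rank checks on the constructor inputs. As an assumption-laden sketch, nested Python lists stand in for tensors below; rank and check_helper_args are hypothetical helpers, and the error messages are illustrative rather than quoted from TensorFlow.

```python
def rank(value):
    # Hypothetical stand-in for tensor rank: nesting depth of Python lists/tuples.
    depth = 0
    while isinstance(value, (list, tuple)):
        depth += 1
        value = value[0] if value else None
    return depth


def check_helper_args(start_tokens, end_token):
    # start_tokens must be a vector (rank 1), one token per batch entry.
    if rank(start_tokens) != 1:
        raise ValueError("start_tokens must be a vector")
    # end_token must be a single scalar shared by the whole batch.
    if rank(end_token) != 0:
        raise ValueError("end_token must be a scalar")
    # With valid inputs, the batch size is the length of start_tokens.
    return len(start_tokens)


print(check_helper_args([1, 1, 1], 2))  # 3
```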

Attributes | |
---|---|
batch_size | Batch size of the tensor returned by sample. Returns a scalar int32 tensor.
sample_ids_dtype | DType of the tensor returned by sample. Returns a DType.
sample_ids_shape | Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.
Methods
initialize
initialize(
name=None
)
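Returns (initial_finished, initial_inputs), where initial_finished is False for every batch entry and initial_inputs is the embedding of start_tokens.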
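
A minimal plain-Python sketch of what initialize hands to the decoder, assuming a toy dictionary EMBEDDING and a hypothetical embed function in place of the real embedding argument:

```python
# Hypothetical toy embedding table: id -> 2-d vector.
EMBEDDING = {0: [0.0, 0.0], 1: [1.0, 0.5], 2: [-1.0, 2.0]}


def embed(ids):
    # Stand-in for the embedding argument: map each id to its vector.
    return [EMBEDDING[i] for i in ids]


def initialize(start_tokens, embed_fn):
    # No batch entry has finished before decoding starts.
    initial_finished = [False] * len(start_tokens)
    # The first decoder inputs are the embedded start tokens.
    initial_inputs = embed_fn(start_tokens)
    return initial_finished, initial_inputs


finished, inputs = initialize([1, 1], embed)
print(finished)  # [False, False]
print(inputs)    # [[1.0, 0.5], [1.0, 0.5]]
```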
next_inputs
next_inputs(
time, outputs, state, sample_ids, name=None
)
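next_inputs_fn inherited from GreedyEmbeddingHelper. Marks each batch entry as finished when its sample id equals end_token, and embeds sample_ids to form the next decoder inputs. Returns (finished, next_inputs, state).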
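
The inherited next_inputs logic can be sketched in plain Python as follows. The signature is simplified and hypothetical: time and outputs are dropped because this logic does not use them, and end_token, start_inputs and embed_fn are passed explicitly instead of being stored on a helper object.

```python
# Hypothetical toy embedding table: id -> 2-d vector.
EMBEDDING = {0: [0.0, 0.0], 1: [1.0, 0.5], 2: [-1.0, 2.0]}


def embed(ids):
    return [EMBEDDING[i] for i in ids]


def next_inputs(sample_ids, state, end_token, start_inputs, embed_fn):
    # An entry is finished once it emits end_token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every entry is done, so the fed-back value is irrelevant; reuse the start inputs.
        inputs = start_inputs
    else:
        # Otherwise embed the chosen ids to form the next decoder inputs.
        inputs = embed_fn(sample_ids)
    # The decoder state is passed through unchanged.
    return finished, inputs, state


start_inputs = embed([1, 1])
state = "rnn-state"  # placeholder for the decoder's recurrent state
result = next_inputs([2, 0], state, end_token=2, start_inputs=start_inputs, embed_fn=embed)
print(result)
```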
sample
sample(
time, outputs, state, name=None
)
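sample for SampleEmbeddingHelper. Treats outputs as logits, divides them by softmax_temperature (if given), and draws sample ids from the resulting categorical distribution using seed.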
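
A plain-Python sketch of the sampling step, assuming logits arrive as a list of per-entry lists. The function name and the use of random.Random(seed) are assumptions for illustration: TensorFlow's op-level seed interacts with graph-level seeding, so this only shows the role of the parameter, namely that a fixed seed makes the draws repeatable here.

```python
import math
import random


def sample(logits_batch, softmax_temperature=None, seed=None):
    # One generator per call, seeded so that a fixed seed gives repeatable draws.
    rng = random.Random(seed)
    sample_ids = []
    for logits in logits_batch:
        # Without a temperature the logits are used as-is (equivalent to 1.0).
        if softmax_temperature is not None:
            logits = [x / softmax_temperature for x in logits]
        # Unnormalized softmax weights; random.choices only needs relative weights.
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        sample_ids.append(rng.choices(range(len(logits)), weights=weights, k=1)[0])
    return sample_ids


batch_logits = [[2.0, 1.0, 0.1], [0.1, 0.2, 3.0]]
print(sample(batch_logits, softmax_temperature=0.7, seed=7))
```

Because the sketch creates a fresh generator on every call, two calls with the same seed return identical ids; passing seed=None gives an independent draw each time.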