tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper

A training helper that adds scheduled sampling.

Inherits From: TrainingHelper

Returns -1s for sample_ids where no sampling took place; valid sample id values elsewhere.

Args

inputs: A (structure of) input tensors.
sequence_length: An int32 vector tensor holding the length of each sequence in the batch.
embedding: A callable that takes a vector tensor of ids (argmax ids), or the params argument for embedding_lookup.
sampling_probability: A 0D float32 tensor giving the probability of sampling an id categorically from the outputs instead of reading the next input directly from inputs.
time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
seed: The seed for sampling ids from the outputs.
scheduling_seed: The seed for the schedule decision rule, i.e. the per-step choice of whether to sample.
name: Name scope for any created operations.
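The embedding argument can be either a function of ids or an embedding table. The pure-Python sketch below, using nested lists as stand-ins for tensors, shows one way such an argument could be normalized into a single ids-to-vectors function. With time_major=False, inputs shaped [batch, time, depth] have to be read one time step at a time, which amounts to a transpose to [time, batch, depth]. make_embedding_fn and to_time_major are hypothetical names for illustration, not part of the TensorFlow API.

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingArg = Union[Callable[[Sequence[int]], List[Vector]], Sequence[Vector]]


def make_embedding_fn(embedding: EmbeddingArg) -> Callable[[Sequence[int]], List[Vector]]:
    """Hypothetical helper: normalize `embedding` into an ids -> vectors function."""
    if callable(embedding):
        # Already a function of a vector of ids; use it as-is.
        return embedding
    # Otherwise treat it as an embedding table (one row per id),
    # mimicking what embedding_lookup does with its params argument.
    table = [list(row) for row in embedding]
    return lambda ids: [list(table[i]) for i in ids]


def to_time_major(batch_major: Sequence[Sequence[Vector]]) -> List[List[Vector]]:
    """Hypothetical helper: transpose [batch, time, depth] lists to [time, batch, depth]."""
    return [list(step) for step in zip(*batch_major)]


# Usage: a 3-word vocabulary with 2-dimensional embeddings.
table = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]
embed = make_embedding_fn(table)
print(embed([2, 1]))  # [[0.2, 0.9], [1.0, 0.5]]

# Two sequences (batch of 2) with three time steps each, depth 2.
inputs = [
    [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]],
    [[4.0, 0.0], [5.0, 0.0], [6.0, 0.0]],
]
print(len(to_time_major(inputs)))  # 3
```

Passing a callable lets the caller control the lookup (for example, adding projection or dropout), while passing a table keeps the common case short.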

Raises

ValueError: if sampling_probability is not a scalar or vector.
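The rank check behind this error can be sketched in plain Python, with nested lists standing in for tensors. rank and check_sampling_probability are hypothetical names; the sketch only mirrors the rule stated above (rank 0 or 1 is accepted, anything else raises ValueError).

```python
from typing import Any


def rank(value: Any) -> int:
    """Nesting depth of a rectangular nested list; 0 for a scalar."""
    r = 0
    while isinstance(value, (list, tuple)):
        r += 1
        if not value:
            # An empty list still counts as one level.
            break
        value = value[0]
    return r


def check_sampling_probability(p: Any) -> None:
    """Hypothetical check: accept a scalar or a vector, reject anything else."""
    if rank(p) not in (0, 1):
        raise ValueError(
            "sampling_probability must be either a scalar or a vector, "
            f"got rank {rank(p)}"
        )


check_sampling_probability(0.25)        # scalar: accepted
check_sampling_probability([0.1, 0.5])  # vector: accepted
try:
    check_sampling_probability([[0.1], [0.5]])  # matrix: rejected
except ValueError as err:
    print(err)  # prints: sampling_probability must be either a scalar or a vector, got rank 2
```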

Attributes

batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor.
inputs: The input tensors given to the constructor.
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.
sequence_length: The int32 vector of sequence lengths given to the constructor.

Methods

initialize(name=None)

Returns (initial_finished, initial_inputs).
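A pure-Python sketch of what initialize computes, assuming it follows TrainingHelper: an entry is finished from the start if its sequence length is 0, and the initial inputs are the ground-truth inputs at time step 0 (or zeros if every entry is already finished). Inputs are time-major nested lists with at least one time step; the function name and arguments are illustrative, not the TensorFlow signature.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def initialize(
    time_major_inputs: Sequence[Sequence[Vector]],
    sequence_length: Sequence[int],
) -> Tuple[List[bool], List[Vector]]:
    """Sketch: returns (initial_finished, initial_inputs)."""
    # A sequence of length 0 is finished before decoding starts.
    finished = [length == 0 for length in sequence_length]
    if all(finished):
        # Nothing to decode: feed zeros with the depth of one input vector.
        depth = len(time_major_inputs[0][0])
        return finished, [[0.0] * depth for _ in sequence_length]
    # Otherwise feed the ground-truth inputs of time step 0.
    return finished, [list(v) for v in time_major_inputs[0]]


steps = [  # time-major: 2 time steps, batch of 2, depth 1
    [[1.0], [3.0]],
    [[2.0], [4.0]],
]
print(initialize(steps, [2, 0]))  # ([False, True], [[1.0], [3.0]])
```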

next_inputs(time, outputs, state, sample_ids, name=None)

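Returns (finished, next_inputs, state). Batch entries whose sample_ids value is -1 read their next input from inputs, as in TrainingHelper; the other entries receive the embedding of their sampled id.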
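The merge of ground-truth and sampled inputs can be sketched in plain Python. This assumes TrainingHelper's rule that an entry is finished once time + 1 reaches its sequence length and that zeros are fed when every entry is finished; state is passed through unchanged by the helper and is omitted here. The function and its arguments are hypothetical stand-ins operating on time-major nested lists.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs(
    time: int,
    time_major_inputs: Sequence[Sequence[Vector]],
    sequence_length: Sequence[int],
    sample_ids: Sequence[int],
    embed: Callable[[Sequence[int]], List[Vector]],
) -> Tuple[List[bool], List[Vector]]:
    """Sketch: returns (finished, next_inputs); state is omitted."""
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    batch_size = len(sequence_length)
    if all(finished):
        # Every sequence is done: feed zeros of the input depth.
        depth = len(time_major_inputs[0][0])
        return finished, [[0.0] * depth for _ in range(batch_size)]
    # Ground-truth (teacher-forced) inputs for the next step.
    base = [list(v) for v in time_major_inputs[next_time]]
    # Entries with a sampled id (> -1) get the embedding of that id instead.
    sampled_rows = [b for b in range(batch_size) if sample_ids[b] > -1]
    sampled_vectors = embed([sample_ids[b] for b in sampled_rows])
    for b, vec in zip(sampled_rows, sampled_vectors):
        base[b] = list(vec)
    return finished, base


table = [[0.0], [10.0], [20.0]]


def embed(ids: Sequence[int]) -> List[Vector]:
    return [list(table[i]) for i in ids]


steps = [  # time-major: 2 time steps, batch of 2, depth 1
    [[1.0], [3.0]],
    [[2.0], [4.0]],
]
# Entry 0 was not sampled (-1); entry 1 sampled id 2.
print(next_inputs(0, steps, [2, 2], [-1, 2], embed))  # ([False, False], [[2.0], [20.0]])
```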

sample(time, outputs, state, name=None)

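Returns sample_ids: an int32 vector of shape [batch_size]. With probability sampling_probability, each entry holds an id sampled categorically from outputs; otherwise it holds -1.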
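A plain-Python sketch of the sampling rule, treating each row of outputs as logits: a Bernoulli draw seeded by scheduling_seed decides whether an entry samples, and a categorical draw seeded by seed picks the id. This illustrates the semantics only; it is not TensorFlow's implementation, and its random streams will not reproduce TensorFlow's results for the same seeds. The function name and arguments are illustrative.

```python
import math
import random
from typing import List, Optional, Sequence


def sample(
    logits: Sequence[Sequence[float]],
    sampling_probability: float,
    seed: Optional[int] = None,
    scheduling_seed: Optional[int] = None,
) -> List[int]:
    """Sketch: a categorically sampled id where sampling happens, -1 elsewhere."""
    schedule_rng = random.Random(scheduling_seed)  # decides *whether* to sample
    id_rng = random.Random(seed)  # decides *which* id to sample
    sample_ids = []
    for row in logits:
        if schedule_rng.random() < sampling_probability:
            # Unnormalized softmax weights (shifted by the max for stability);
            # random.choices normalizes them itself.
            m = max(row)
            weights = [math.exp(x - m) for x in row]
            sample_ids.append(id_rng.choices(range(len(row)), weights=weights)[0])
        else:
            sample_ids.append(-1)
    return sample_ids


logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]]
print(sample(logits, 0.0))  # [-1, -1]: never sample, always read from inputs
print(sample(logits, 1.0, seed=0, scheduling_seed=1))  # two ids in {0, 1, 2}
```

Keeping separate random streams for the schedule decision and the id draw mirrors the separate seed and scheduling_seed arguments, so either can be fixed independently.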