tfa.seq2seq.ScheduledEmbeddingTrainingSampler

Class ScheduledEmbeddingTrainingSampler

A training sampler that adds scheduled sampling.

Inherits From: TrainingSampler

Returns -1s for sample_ids where no sampling took place; valid sample id values elsewhere.
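
The -1 convention can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper name scheduled_sample_ids, the nested-list representation of the outputs, and the per-row random draw are all assumptions made for illustration.

```python
import math
import random


def scheduled_sample_ids(logits, sampling_probability, rng):
    """Return one id per batch row: a sampled id, or -1 where no sampling took place.

    `logits` is a list of rows (one per batch entry) of unnormalized scores.
    Hypothetical helper for illustration; not part of tfa.seq2seq.
    """
    sample_ids = []
    for row in logits:
        if rng.random() < sampling_probability:
            # Sample categorically from the softmax of this row's scores.
            peak = max(row)
            weights = [math.exp(score - peak) for score in row]
            sample_ids.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
        else:
            # Marker for "no sampling took place for this batch entry".
            sample_ids.append(-1)
    return sample_ids


rng = random.Random(0)
logits = [[0.1, 2.0, -1.0], [1.5, 0.0, 0.3], [0.0, 0.0, 4.0]]
never = scheduled_sample_ids(logits, 0.0, rng)   # every entry is -1
always = scheduled_sample_ids(logits, 1.0, rng)  # every entry is a valid id
```

With a probability of 0.0 every entry is -1, and with 1.0 every entry is an index into its row, which matches the two cases described above.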

__init__

__init__(
    sampling_probability,
    embedding_fn=None,
    time_major=False,
    seed=None,
    scheduling_seed=None
)

Initializer.

Args:

  • sampling_probability: A float32 0-D or 1-D tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
  • embedding_fn: A callable that takes a vector tensor of ids (argmax ids) and returns their embeddings, or the params argument for embedding_lookup.
  • time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
  • seed: The sampling seed.
  • scheduling_seed: The schedule decision rule sampling seed.

Raises:

  • ValueError: if sampling_probability is not a scalar or vector.
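
The two forms of embedding_fn described above (a callable over ids, or the params of an embedding lookup) can be reduced to a single callable. The following pure-Python sketch shows one way to picture that. The helper name as_embedding_fn and the use of nested lists in place of tensors are assumptions, and the sketch does not claim to mirror how the library handles the argument.

```python
def as_embedding_fn(embedding_fn_or_params):
    """Return a callable mapping a list of ids to a list of embedding vectors.

    Hypothetical helper for illustration; not part of tfa.seq2seq.
    """
    if callable(embedding_fn_or_params):
        # Already a callable over ids: use it as-is.
        return embedding_fn_or_params
    # Otherwise treat the argument as a lookup table indexed by id.
    params = embedding_fn_or_params
    return lambda ids: [params[i] for i in ids]


table = [[0.0, 0.0], [0.5, 0.5], [0.9, 0.9]]  # one row per id
from_table = as_embedding_fn(table)
from_callable = as_embedding_fn(lambda ids: [table[i] for i in ids])
looked_up = from_table([2, 0])
```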

Properties

batch_size

Batch size of tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.

sample_ids_dtype

DType of tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is invoked.

sample_ids_shape

Shape of tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is invoked.

Methods

initialize

initialize(
    inputs,
    sequence_length=None,
    embedding=None
)

Initialize the TrainingSampler.

Args:

  • inputs: A (structure of) input tensors.
  • sequence_length: An int32 vector tensor.

Returns:

(finished, next_inputs), a tuple of two items. The first item is a boolean vector indicating whether each item in the batch has finished. The second item is the first slice of the input data along the timestep dimension (usually the second dimension of the input).
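
The shape of this return value can be shown with a minimal pure-Python sketch. The helper name initialize_sketch, the nested-list inputs, and the rule that an entry starts out finished only when its length is zero are assumptions for illustration, not the library's code.

```python
def initialize_sketch(inputs, sequence_length, time_major=False):
    """Return (finished, next_inputs) for the first time step.

    `inputs` is a nested list shaped [batch][time][features], or
    [time][batch][features] when time_major is True.
    Hypothetical helper for illustration; not part of tfa.seq2seq.
    """
    # An entry is finished from the start only if its sequence is empty.
    finished = [length == 0 for length in sequence_length]
    if time_major:
        # The first slice along the leading (time) dimension.
        next_inputs = inputs[0]
    else:
        # The first slice along the second (time) dimension.
        next_inputs = [sequence[0] for sequence in inputs]
    return finished, next_inputs


batch_major = [
    [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]],  # batch entry 0, three time steps
    [[4.0, 4.1], [5.0, 5.1], [0.0, 0.0]],  # batch entry 1, padded after two
]
finished, first_step = initialize_sketch(batch_major, sequence_length=[3, 2])
```

Here finished holds one boolean per batch entry and first_step holds the time-0 feature vector of each entry.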

next_inputs

next_inputs(
    time,
    outputs,
    state,
    sample_ids
)

Returns (finished, next_inputs, next_state).
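
A minimal pure-Python sketch of how scheduled sampling could choose the next inputs: entries whose sample id is not -1 are fed the embedding of the sampled id, and the rest read the next time step from the inputs. The helper name next_inputs_sketch, the nested-list tensors, the zero inputs once every sequence has finished, and the omission of state handling are all assumptions for illustration.

```python
def next_inputs_sketch(time, inputs, sequence_length, sample_ids, embedding_fn):
    """Return (finished, next_inputs) for step time + 1; state is omitted.

    `inputs` is batch major, [batch][time][features], padded to a common
    length. The embedding size must equal the feature size.
    Hypothetical helper for illustration; not part of tfa.seq2seq.
    """
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Nothing left to read: feed zeros of the right shape.
        feature_size = len(inputs[0][0])
        return finished, [[0.0] * feature_size for _ in inputs]
    next_inputs = []
    for b, sample_id in enumerate(sample_ids):
        if sample_id > -1:
            # Sampling took place: feed the embedding of the sampled id.
            next_inputs.append(embedding_fn([sample_id])[0])
        else:
            # No sampling: read the next time step directly from the inputs.
            next_inputs.append(inputs[b][next_time])
    return finished, next_inputs


table = [[0.0, 0.0], [0.5, 0.5], [0.9, 0.9]]  # one embedding row per id
embedding_fn = lambda ids: [table[i] for i in ids]
batch_major = [
    [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]],
    [[4.0, 4.1], [5.0, 5.1], [0.0, 0.0]],
]
finished, step_inputs = next_inputs_sketch(
    0, batch_major, [3, 2], sample_ids=[-1, 2], embedding_fn=embedding_fn
)
```

In this call the first batch entry reads its time-1 input, while the second is fed the embedding of the sampled id 2.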

sample

sample(
    time,
    outputs,
    state
)

Returns sample_ids.