A training sampler that adds scheduled sampling directly to outputs.
Inherits From: `TrainingSampler`, `Sampler`
```python
tfa.seq2seq.ScheduledOutputTrainingSampler(
    sampling_probability: tfa.types.TensorLike,
    time_major: bool = False,
    seed: Optional[int] = None,
    next_inputs_fn: Optional[Callable] = None
)
```
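In the `sample_ids` produced by this sampler, an entry is `False` where no sampling took place (the next input is read from the inputs) and `True` elsewhere (the next input is taken from the outputs).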
Args | |
---|---|
`sampling_probability` | A `float32` scalar tensor: the probability of sampling from the outputs instead of reading directly from the inputs. |
`time_major` | Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. |
`seed` | The sampling seed. |
`next_inputs_fn` | (Optional) Callable to apply to the RNN outputs to create the next input when sampling. If `None` (default), the RNN outputs will be used as the next inputs. |
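
The `time_major` flag only changes which axis is treated as time. The sketch below uses plain nested Python lists (an assumption for illustration; the real sampler works on TensorFlow tensors) to show that reading time step `t` from a batch-major `[batch, time, depth]` structure gives the same values as reading it from the equivalent time-major `[time, batch, depth]` structure. The helpers `to_time_major` and `step_slice` are hypothetical.

```python
from typing import List

# Hypothetical helpers operating on nested lists, not TensorFlow tensors.
Tensor3D = List[List[List[float]]]


def to_time_major(batch_major: Tensor3D) -> Tensor3D:
    """Transpose [batch, time, depth] into [time, batch, depth]."""
    num_steps = len(batch_major[0])
    return [[example[t] for example in batch_major] for t in range(num_steps)]


def step_slice(inputs: Tensor3D, t: int, time_major: bool) -> List[List[float]]:
    """Return the [batch, depth] slice of `inputs` at time step `t`."""
    if time_major:
        return inputs[t]  # time is the first axis
    return [example[t] for example in inputs]  # time is the second axis


# Two examples (batch=2), three steps (time=3), depth 1.
batch_major = [
    [[1.0], [2.0], [3.0]],
    [[4.0], [5.0], [6.0]],
]
tm_inputs = to_time_major(batch_major)

print(step_slice(batch_major, 1, time_major=False))  # [[2.0], [5.0]]
print(step_slice(tm_inputs, 1, time_major=True))     # [[2.0], [5.0]]
```

Either layout yields the same per-step slice; only the indexing differs.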
Raises | |
---|---|
`ValueError` | If `sampling_probability` is not a scalar or vector. |
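
The `ValueError` condition can be illustrated with a small stand-alone check. This is a hypothetical helper on nested Python lists, not the library's own validation code: it computes the rank of the probability argument and rejects anything that is not 0-D (scalar) or 1-D (vector).

```python
from typing import Any


def rank(value: Any) -> int:
    """Number of nested list levels; a bare number has rank 0."""
    r = 0
    while isinstance(value, (list, tuple)):
        r += 1
        value = value[0] if value else None
    return r


def check_sampling_probability(p: Any) -> None:
    """Raise ValueError unless `p` is a scalar or a vector (hypothetical helper)."""
    if rank(p) not in (0, 1):
        raise ValueError(
            f"sampling_probability must be a scalar or a vector, got rank {rank(p)}"
        )


check_sampling_probability(0.25)         # scalar: accepted
check_sampling_probability([0.1, 0.9])   # vector: accepted
try:
    check_sampling_probability([[0.5]])  # rank 2: rejected
except ValueError as err:
    print(err)  # sampling_probability must be a scalar or a vector, got rank 2
```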
Attributes | |
---|---|
`batch_size` | Batch size of the tensor returned by `sample`. Returns a scalar `int32` tensor. The return value might not be available before `initialize()` is invoked; in that case, a `ValueError` is raised. |
`sample_ids_dtype` | DType of the tensor returned by `sample`. Returns a `DType`. The return value might not be available before `initialize()` is invoked. |
`sample_ids_shape` | Shape of the tensor returned by `sample`, excluding the batch dimension. Returns a |
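
The attribute notes describe a common pattern: values such as `batch_size` are only known once `initialize()` has seen the inputs. A minimal sketch of that pattern, using a hypothetical `LazyBatchSampler` class (not the `tfa.seq2seq` implementation) and batch-major list inputs:

```python
from typing import List, Optional


class LazyBatchSampler:
    """Hypothetical sampler whose batch size is set by initialize()."""

    def __init__(self) -> None:
        self._batch_size: Optional[int] = None

    def initialize(self, inputs: List[list]) -> None:
        # Batch-major inputs: the first axis is the batch dimension.
        self._batch_size = len(inputs)

    @property
    def batch_size(self) -> int:
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size


sampler = LazyBatchSampler()
try:
    sampler.batch_size
except ValueError as err:
    print(err)  # batch_size is not available before initialize()
sampler.initialize([[1.0], [2.0], [3.0]])
print(sampler.batch_size)  # 3
```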
Methods
initialize
```python
initialize(
    inputs, sequence_length=None, mask=None, auxiliary_inputs=None
)
```
Initialize the sampler.
Args | |
---|---|
`inputs` | A (structure of) input tensors. |
`sequence_length` | An `int32` vector tensor. |
`mask` | A boolean 2D tensor. |
Returns | |
---|---|
A tuple `(finished, next_inputs)`. The first item is a boolean vector indicating whether each item in the batch has finished. The second item is the first slice of the input data along the time-step dimension (usually the second dimension of the input). |
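
The two return values of `initialize` can be sketched with plain Python lists. This assumes batch-major inputs of shape `[batch, time, depth]` and marks an entry as finished when its sequence length is 0; the function name `initialize_sketch` is hypothetical, and the real method operates on tensors and also handles `mask` and `auxiliary_inputs`, which are omitted here.

```python
from typing import List, Tuple

Step = List[float]  # one time step's feature vector


def initialize_sketch(
    inputs: List[List[Step]], sequence_length: List[int]
) -> Tuple[List[bool], List[Step]]:
    """Hypothetical list-based version of initialize() for batch-major inputs.

    inputs has shape [batch, time, depth]; sequence_length has shape [batch].
    """
    # An entry is finished before decoding starts only if it has no steps.
    finished = [length == 0 for length in sequence_length]
    # First slice along the time axis (axis 1 for batch-major inputs).
    first_inputs = [example[0] for example in inputs]
    return finished, first_inputs


inputs = [
    [[0.1], [0.2], [0.3]],
    [[0.4], [0.5], [0.6]],
]
finished, first_inputs = initialize_sketch(inputs, sequence_length=[3, 0])
print(finished)      # [False, True]
print(first_inputs)  # [[0.1], [0.4]]
```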
next_inputs
```python
next_inputs(
    time, outputs, state, sample_ids
)
```
Returns `(finished, next_inputs, next_state)`.
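
For this sampler, `next_inputs` is where scheduled sampling happens: for each batch entry whose sample id is `True`, the next input comes from the RNN output (passed through `next_inputs_fn` if one was given); otherwise it is the ground-truth input for the next step. The list-based sketch below assumes padded, batch-major inputs and that an entry is finished once `time + 1` reaches its sequence length. The function `next_inputs_sketch` and its `inputs`/`sequence_length` parameters are hypothetical; `state` is passed through unchanged, and zeros are fed once every entry has finished.

```python
from typing import Callable, List, Optional, Tuple

Step = List[float]


def next_inputs_sketch(
    time: int,
    outputs: List[Step],
    state: object,
    sample_ids: List[bool],
    inputs: List[List[Step]],
    sequence_length: List[int],
    next_inputs_fn: Optional[Callable[[Step], Step]] = None,
) -> Tuple[List[bool], List[Step], object]:
    """Hypothetical list-based next_inputs() for padded batch-major inputs."""
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    # Default behaviour: use the raw RNN output as the next input.
    transform = next_inputs_fn if next_inputs_fn is not None else (lambda out: out)
    all_finished = all(finished)
    next_inputs: List[Step] = []
    for b, use_output in enumerate(sample_ids):
        if all_finished:
            # Nothing left to read: feed zeros shaped like one input step.
            next_inputs.append([0.0] * len(inputs[b][0]))
        elif use_output:
            next_inputs.append(transform(outputs[b]))  # sample from the outputs
        else:
            next_inputs.append(inputs[b][next_time])   # read the ground truth
    return finished, next_inputs, state


inputs = [
    [[1.0], [2.0], [3.0]],
    [[4.0], [5.0], [6.0]],
]
outputs = [[9.0], [8.0]]  # pretend RNN outputs at time 0
finished, nxt, _ = next_inputs_sketch(
    time=0, outputs=outputs, state=None, sample_ids=[True, False],
    inputs=inputs, sequence_length=[3, 3],
    next_inputs_fn=lambda out: [v * 10 for v in out],
)
print(finished)  # [False, False]
print(nxt)       # [[90.0], [5.0]]
```

Entry 0 was sampled, so its next input is the transformed output; entry 1 was not, so it reads the ground-truth step at `time + 1`.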
sample
```python
sample(
    time, outputs, state
)
```
Returns `sample_ids`.
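
Conceptually, `sample` draws one Bernoulli variable per batch entry with success probability `sampling_probability`; a `True` id means that entry's next input will be taken from the outputs. The sketch below reproduces that idea with Python's `random.Random` seeded from `seed`. In this sketch the draw does not depend on `outputs` or `state`. The function `sample_sketch` is hypothetical, and because the library uses TensorFlow's random ops, its actual draws for a given seed will differ from these.

```python
import random
from typing import List, Optional


def sample_sketch(
    batch_size: int, sampling_probability: float, seed: Optional[int] = None
) -> List[bool]:
    """Hypothetical stand-in for sample(): one Bernoulli draw per batch entry.

    True means "feed this entry's output as its next input";
    False means "read the ground-truth input instead".
    """
    rng = random.Random(seed)
    return [rng.random() < sampling_probability for _ in range(batch_size)]


print(sample_sketch(4, 0.0))  # [False, False, False, False]
print(sample_sketch(4, 1.0))  # [True, True, True, True]
# A fixed seed makes the draw reproducible.
assert sample_sketch(8, 0.5, seed=7) == sample_sketch(8, 0.5, seed=7)
```

A probability of 0 reduces to pure teacher forcing, and a probability of 1 always feeds the model's own outputs back in.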