tfa.seq2seq.Sampler


Interface for implementing sampling in seq2seq decoders.

Sampler instances are used by BasicDecoder. Typical usage of a sampler looks like this:

sampler = Sampler(init_args)  # configuration only, no input tensors
(initial_finished, initial_inputs) = sampler.initialize(input_tensors)
cell_input = initial_inputs
cell_state = cell.get_initial_state(...)
for time_step in tf.range(max_output_length):
    # Run one step of the RNN cell.
    cell_output, cell_state = cell(cell_input, cell_state)
    # Choose the ids emitted at this step.
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    # Compute the finished flags plus the inputs and state for the next step.
    (finished, cell_input, cell_state) = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
    # Stop once every batch entry has finished.
    if tf.reduce_all(finished):
        break
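To make the control flow above concrete without depending on TensorFlow, here is a minimal pure-Python sketch. `ToyGreedySampler`, `toy_cell`, `decode`, and the `end_token` parameter are hypothetical stand-ins, not TensorFlow Addons APIs: the "cell" returns a fixed row of scores per step, the sampler picks the highest-scoring id, and an entry finishes when it emits the assumed end token.

```python
from typing import List, Tuple


class ToyGreedySampler:
    """Hypothetical sampler following the initialize / sample / next_inputs protocol."""

    def __init__(self, end_token: int):
        # Configuration only; input tensors are passed to initialize() instead.
        self.end_token = end_token

    def initialize(self, start_tokens: List[int]) -> Tuple[List[bool], List[int]]:
        # Nothing is finished yet, and the first cell inputs are the start tokens.
        return [False] * len(start_tokens), list(start_tokens)

    def sample(self, time: int, outputs: List[List[float]], state) -> List[int]:
        # Greedy choice: index of the highest score for each batch entry.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time: int, outputs, state, sample_ids: List[int]):
        # An entry finishes when it emits end_token; sampled ids become the next inputs.
        finished = [sid == self.end_token for sid in sample_ids]
        return finished, list(sample_ids), state


def toy_cell(cell_input, cell_state):
    # Hypothetical cell: ignores its input and reads the next row of a fixed score table.
    table, step = cell_state
    return table[step], (table, step + 1)


def decode(sampler, table, start_tokens, max_output_length):
    finished, cell_input = sampler.initialize(start_tokens)
    cell_state = (table, 0)
    history = []
    for time_step in range(max_output_length):
        cell_output, cell_state = toy_cell(cell_input, cell_state)
        sample_ids = sampler.sample(time_step, cell_output, cell_state)
        history.append(sample_ids)
        step_finished, cell_input, cell_state = sampler.next_inputs(
            time_step, cell_output, cell_state, sample_ids)
        # Once an entry has finished it stays finished.
        finished = [a or b for a, b in zip(finished, step_finished)]
        if all(finished):
            break
    return history


# Scores for a batch of 2 over a vocabulary of 3 ids, one row per time step.
TABLE = [
    [[0.1, 0.9, 0.0], [0.2, 0.1, 0.7]],
    [[0.8, 0.1, 0.1], [0.1, 0.6, 0.3]],
    [[0.5, 0.2, 0.3], [0.9, 0.0, 0.1]],
    [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]],
]
ids = decode(ToyGreedySampler(end_token=0), TABLE, start_tokens=[5, 5], max_output_length=4)
```

Here the loop stops after three steps because both entries have emitted id 0 by then, so the fourth row of `TABLE` is never read. The sketch also keeps a finished entry finished across steps, which the loop above leaves implicit.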

Note that input_tensors should not be passed to the Sampler as constructor (init()) arguments. Instead, decoders pass them in via initialize().

Attributes

batch_size Batch size of the tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.

sample_ids_dtype DType of the tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is invoked.

sample_ids_shape Shape of the tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is invoked.
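The following sketch illustrates the attribute contract described above, together with the earlier note that configuration goes to the constructor while inputs go to initialize(). `ToySampler` is hypothetical and uses plain Python values; in TensorFlow Addons these attributes return a scalar int32 tensor, a DType, and a TensorShape rather than an int, a Python type, and a tuple.

```python
class ToySampler:
    """Hypothetical sampler illustrating batch_size / sample_ids_dtype / sample_ids_shape."""

    def __init__(self, end_token: int):
        # Configuration only; the batch is unknown until initialize() runs.
        self.end_token = end_token
        self._batch_size = None

    def initialize(self, inputs):
        # inputs: one start token id per example (an assumption for this sketch).
        self._batch_size = len(inputs)
        return [False] * self._batch_size, list(inputs)

    @property
    def batch_size(self) -> int:
        # Mirrors the documented behavior: unavailable before initialize().
        if self._batch_size is None:
            raise ValueError("batch_size is unknown until initialize() is called")
        return self._batch_size

    @property
    def sample_ids_dtype(self) -> type:
        # Each sample id is an integer token index.
        return int

    @property
    def sample_ids_shape(self) -> tuple:
        # One scalar id per batch entry, so the per-example shape is empty.
        return ()


sampler = ToySampler(end_token=0)
try:
    sampler.batch_size
    raised = False
except ValueError:
    raised = True
sampler.initialize([7, 7, 7])
```

Reading `batch_size` before `initialize()` raises `ValueError`; afterwards it reports 3. The dtype and shape here are static, but a real sampler may also derive them from its inputs, which is why the documentation warns that they might be unavailable early.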

Methods

initialize


Initialize the sampler with the input tensors.

This method must be invoked exactly once before calling other methods of the Sampler.

Args
inputs A (possibly nested) structure of input tensors: either a nested tuple or a single tensor.
**kwargs Additional keyword arguments for initialization. These may include tensors, such as a mask for the inputs, or non-tensor parameters.

Returns
A tuple (initial_finished, initial_inputs).
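As a sketch of what initialize() might do for a teacher-forcing style sampler (an assumption; this is not the TensorFlow Addons `TrainingSampler`), the hypothetical `ToyTeacherForcingSampler` below takes a batch of per-example input sequences plus an optional `mask` keyword argument. It derives each example's length from the mask, marks empty examples as already finished, and returns the first time step as the initial inputs. It handles only a flat list of sequences, not arbitrary nested structures.

```python
from typing import List, Optional, Sequence, Tuple


class ToyTeacherForcingSampler:
    """Hypothetical sampler that feeds ground-truth inputs step by step."""

    def __init__(self, pad_value: float = 0.0):
        # Value fed to examples that have no input at the current step.
        self.pad_value = pad_value
        self.inputs: List[Sequence[float]] = []
        self.lengths: List[int] = []

    def initialize(
        self,
        inputs: List[Sequence[float]],
        mask: Optional[List[Sequence[bool]]] = None,
    ) -> Tuple[List[bool], List[float]]:
        self.inputs = inputs
        # Valid length per example: number of True mask entries (assumes a prefix
        # mask), or the full sequence length when no mask is given.
        if mask is None:
            self.lengths = [len(seq) for seq in inputs]
        else:
            self.lengths = [sum(1 for m in row if m) for row in mask]
        # Examples with no valid steps are finished before decoding starts.
        initial_finished = [n == 0 for n in self.lengths]
        # First input per example, or the pad value if the example is empty.
        initial_inputs = [
            seq[0] if n > 0 else self.pad_value
            for seq, n in zip(inputs, self.lengths)
        ]
        return initial_finished, initial_inputs


sampler = ToyTeacherForcingSampler()
finished, first = sampler.initialize(
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    mask=[[True, True, False], [False, False, False]],
)
```

The second example is fully masked out, so it starts as finished and receives the pad value instead of its first element.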

next_inputs


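Given the current time step, cell outputs, cell state, and sample_ids, returns the tuple (finished, next_inputs, next_state) for the next decoding step.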
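One way next_inputs() can work is sketched below, loosely in the spirit of an embedding-based greedy sampler; the class, the `embedding` table, and `end_token` are hypothetical and use plain Python lists instead of tensors. An entry is marked finished when its sampled id equals the end token, the next inputs are the embeddings of the sampled ids, and the cell state passes through unchanged.

```python
from typing import List, Sequence


class ToyEmbeddingSampler:
    """Hypothetical sampler whose next inputs are embeddings of the sampled ids."""

    def __init__(self, embedding: Sequence[Sequence[float]], end_token: int):
        # Configuration: an embedding table (one row per token id) and the end token.
        self.embedding = embedding
        self.end_token = end_token

    def next_inputs(self, time: int, outputs, state, sample_ids: List[int]):
        # A batch entry is finished when it has just emitted the end token.
        finished = [sid == self.end_token for sid in sample_ids]
        # Feed the embedding of each sampled id into the next cell step.
        next_inputs = [list(self.embedding[sid]) for sid in sample_ids]
        # This sampler does not modify the cell state.
        return finished, next_inputs, state


embedding = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.8]]
sampler = ToyEmbeddingSampler(embedding, end_token=0)
finished, step_inputs, next_state = sampler.next_inputs(
    time=3, outputs=None, state="s", sample_ids=[2, 0])
```

Samplers used for training would typically ignore `sample_ids` here and read the next ground-truth input instead; the return shape of the tuple stays the same.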

sample


Given the current time step, cell outputs, and cell state, returns sample_ids.
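As an illustration of sample(), here is a minimal sketch of a stochastic sampler that draws each id from softmax(logits / temperature). `ToySoftmaxSampler`, `temperature`, and `seed` are hypothetical; the sketch treats the cell outputs as per-entry logits over a vocabulary and uses Python's `random.choices`, which normalizes the weights itself.

```python
import math
import random
from typing import List, Optional


class ToySoftmaxSampler:
    """Hypothetical sampler that draws sample ids from softmax(logits / temperature)."""

    def __init__(self, temperature: float = 1.0, seed: Optional[int] = None):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature
        # A seeded generator makes the draws reproducible.
        self.rng = random.Random(seed)

    def sample(self, time: int, outputs: List[List[float]], state) -> List[int]:
        ids = []
        for logits in outputs:
            scaled = [x / self.temperature for x in logits]
            # Subtract the max logit before exponentiating for numerical stability.
            top = max(scaled)
            weights = [math.exp(x - top) for x in scaled]
            # random.choices normalizes the weights, so this is a softmax draw.
            ids.append(self.rng.choices(range(len(logits)), weights=weights, k=1)[0])
        return ids


sampler = ToySoftmaxSampler(temperature=1.0, seed=0)
ids = sampler.sample(time=0, outputs=[[0.0, 50.0, 0.0], [1.0, 1.0, 1.0]], state=None)
```

The first entry's logits are so peaked that id 1 is effectively certain, while the second entry's uniform logits make any of the three ids equally likely. Lowering the temperature pushes this sampler toward greedy argmax behavior.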