Interface for implementing sampling in seq2seq decoders.
Sampler instances are used by BasicDecoder. Typical usage of a sampler looks like this:
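# init_args, input_tensors, cell, initial_state and max_time are placeholders.
sampler = Sampler(init_args)
(initial_finished, initial_inputs) = sampler.initialize(input_tensors)
cell_input = initial_inputs
cell_state = initial_state
for time_step in range(max_time):
    cell_output, cell_state = cell(cell_input, cell_state)
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    (finished, cell_input, cell_state) = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)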
Note that tensor inputs should not be passed to the Sampler as __init__() parameters; instead, the decoder should pass them in through initialize().
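The loop above can be made concrete with a framework-free sketch. It assumes plain Python lists in place of tensors, a hypothetical `GreedySampler` that picks the highest-scoring id, and a toy `cell` function; none of these are the TensorFlow Addons implementation. Only the non-tensor `end_token` goes to `__init__`, while the input batch arrives through `initialize()`.

```python
from abc import ABC, abstractmethod


class Sampler(ABC):
    """Framework-free sketch of the sampler interface (not the tfa class)."""

    @abstractmethod
    def initialize(self, inputs, **kwargs):
        """Return (initial_finished, initial_inputs)."""

    @abstractmethod
    def sample(self, time, outputs, state):
        """Return sample_ids for this step."""

    @abstractmethod
    def next_inputs(self, time, outputs, state, sample_ids):
        """Return (finished, next_inputs, next_state)."""


class GreedySampler(Sampler):
    """Hypothetical sampler: takes the arg-max id and feeds it back in."""

    def __init__(self, end_token):
        # Only non-tensor configuration is passed to the constructor.
        self.end_token = end_token

    def initialize(self, inputs, **kwargs):
        # `inputs` is a batch of start ids; nothing is finished yet.
        return [False] * len(inputs), list(inputs)

    def sample(self, time, outputs, state):
        # Index of the largest score in each batch row.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = [i == self.end_token for i in sample_ids]
        return finished, sample_ids, state


def cell(cell_input, cell_state):
    # Toy cell: scores favor id (input + 1) mod 4, so ids count upward.
    outputs = [[1.0 if j == (x + 1) % 4 else 0.0 for j in range(4)]
               for x in cell_input]
    return outputs, cell_state


sampler = GreedySampler(end_token=3)
finished, cell_input = sampler.initialize([0, 1])
cell_state = None  # the toy cell keeps no state
decoded = []
for time_step in range(10):
    cell_output, cell_state = cell(cell_input, cell_state)
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    step_finished, cell_input, cell_state = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
    # Keep a sequence finished once it has emitted end_token.
    finished = [f or s for f, s in zip(finished, step_finished)]
    decoded.append(sample_ids)
    if all(finished):
        break

print(decoded)  # [[1, 2], [2, 3], [3, 0]]
```

This sketch keeps the "stay finished" bookkeeping in the driving loop rather than in the sampler, so `next_inputs` only has to report whether each id it just saw is the end token.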
batch_size: Batch size of the tensor returned by sample.
Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked, in which case a ValueError is raised.
sample_ids_dtype: DType of the tensor returned by sample.
Returns a DType. The return value might not be available before initialize() is invoked.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension.
Returns a TensorShape. The return value might not be available before initialize() is invoked.
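The "not available before initialize()" behavior of these properties can be sketched as follows. This assumes a hypothetical `LazyPropertiesSampler` that learns its batch size from the inputs and raises ValueError until then; the dtype and shape values are illustrative stand-ins (Python `int`, and an empty tuple for one scalar id per batch entry), not TensorFlow `DType` or `TensorShape` objects.

```python
class LazyPropertiesSampler:
    """Hypothetical sampler whose batch_size depends on initialize()."""

    def __init__(self):
        self._batch_size = None

    def initialize(self, inputs, **kwargs):
        # The batch size is only known once the inputs arrive.
        self._batch_size = len(inputs)
        return [False] * self._batch_size, list(inputs)

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size

    @property
    def sample_ids_dtype(self):
        # Stand-in for a DType: each sample id is a Python int.
        return int

    @property
    def sample_ids_shape(self):
        # Stand-in for a TensorShape: one scalar id per batch entry,
        # so the shape excluding the batch dimension is empty.
        return ()


sampler = LazyPropertiesSampler()
try:
    sampler.batch_size
except ValueError as err:
    # Raised because initialize() has not been called yet.
    error_message = str(err)
sampler.initialize([5, 7, 9])
print(error_message)
print(sampler.batch_size)  # 3
```

Here only `batch_size` depends on the inputs; the dtype and shape are fixed by the sampler's design, so they are available immediately.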
initialize(inputs, **kwargs)
Initialize the sampler with the input tensors.
This method is expected to be invoked only once, before any other method of the Sampler is called.
inputs: A (possibly nested) structure of input tensors; it can be a nested tuple or a single tensor.
**kwargs: Additional keyword arguments for initialization. These may include tensors, such as a mask for the inputs, or non-tensor parameters.
Returns (initial_finished, initial_inputs).
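As an illustration of passing a mask through **kwargs, here is a hypothetical teacher-forcing sampler whose initialize() accepts a batch of input sequences plus an optional `mask` keyword. It assumes nested lists stand in for tensors and that each mask row is a run of 1s followed by 0s, so its sum is the sequence length. A sequence with no valid steps is reported as finished immediately, and the first step of every sequence becomes the initial input. The class name and the `mask` handling are assumptions for illustration.

```python
class TeacherForcingSampler:
    """Hypothetical sampler that feeds ground-truth inputs step by step."""

    def initialize(self, inputs, **kwargs):
        # inputs: batch of sequences, shape [batch, time] as nested lists.
        # mask (optional kwarg): same shape, 1 for real steps, 0 for padding.
        mask = kwargs.get("mask")
        if mask is None:
            mask = [[1] * len(seq) for seq in inputs]
        self.inputs = inputs
        # Assumes each mask row is a prefix of 1s, so its sum is the length.
        self.sequence_length = [sum(row) for row in mask]
        # A sequence with no valid steps is finished before decoding starts.
        initial_finished = [length == 0 for length in self.sequence_length]
        initial_inputs = [seq[0] for seq in inputs]
        return initial_finished, initial_inputs


sampler = TeacherForcingSampler()
finished, first_inputs = sampler.initialize(
    [[4, 2, 0], [7, 0, 0]], mask=[[1, 1, 0], [0, 0, 0]])
print(finished, first_inputs)  # [False, True] [4, 7]
```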
next_inputs(time, outputs, state, sample_ids)
Returns (finished, next_inputs, next_state).
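A sketch of next_inputs for greedy decoding, written as a standalone hypothetical function. It assumes the sampler knows an `end_token` and an `embedding` lookup table (both illustrative parameters): a sequence is finished when it emits the end token, the sampled ids are embedded to form the next inputs, and the state passes through unchanged. `time` and `outputs` are accepted only to match the signature.

```python
def greedy_next_inputs(time, outputs, state, sample_ids, end_token, embedding):
    """Hypothetical next_inputs: returns (finished, next_inputs, next_state)."""
    # A sequence is finished once it produces the end token.
    finished = [sample_id == end_token for sample_id in sample_ids]
    # Feed the embedding of each sampled id back in as the next input.
    next_inputs = [embedding[sample_id] for sample_id in sample_ids]
    # This sketch does not modify the cell state.
    next_state = state
    return finished, next_inputs, next_state


embedding = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # one row per token id
finished, next_inputs, next_state = greedy_next_inputs(
    time=0, outputs=None, state="s0", sample_ids=[1, 2],
    end_token=2, embedding=embedding)
print(finished, next_inputs)  # [False, True] [[1.0, 0.0], [0.0, 1.0]]
```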
sample(time, outputs, state)
Returns sample_ids.
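Two hypothetical sample() strategies, as a sketch: greedy selection takes the arg-max of each row of `outputs`, while random sampling draws an id with probability proportional to softmax(outputs). It assumes `outputs` is a batch of per-id scores (logits) given as plain lists; `time` and `state` are accepted to match the signature but unused.

```python
import math
import random


def greedy_sample(time, outputs, state):
    # Index of the highest score in each batch row.
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def random_sample(time, outputs, state, rng=random):
    # Draw one id per row with probability proportional to softmax(row).
    sample_ids = []
    for row in outputs:
        top = max(row)
        # Unnormalized softmax weights; subtracting the max avoids overflow.
        weights = [math.exp(x - top) for x in row]
        sample_ids.append(rng.choices(range(len(row)), weights=weights)[0])
    return sample_ids


logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]
print(greedy_sample(0, logits, None))  # [1, 0]
print(random_sample(0, logits, None, rng=random.Random(0)))
```

`random.choices` normalizes the weights itself, so the softmax denominator can be skipped; passing a seeded `random.Random` makes the draws reproducible.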