Interface for implementing sampling in seq2seq decoders.
Sampler instances are used by BasicDecoder. A sampler is typically used as follows:
sampler = Sampler(init_args)
(initial_finished, initial_inputs) = sampler.initialize(input_tensors)
for time_step in range(time):
    cell_output, cell_state = cell.call(cell_input, previous_state)
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    (finished, next_inputs, next_state) = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
Note that tensor inputs should not be passed to the Sampler as __init__() parameters; instead, decoders should feed them in via initialize().
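The call order described above can be sketched without TensorFlow. The following is a minimal plain-Python stand-in, assuming lists in place of tensors; `ToyTeacherForcingSampler`, `toy_cell`, and `run_decoder` are hypothetical names invented for illustration and are not part of the library. It shows that sequence data reaches the sampler through initialize() rather than the constructor.

```python
# Plain-Python stand-in for the Sampler protocol; lists replace tensors.
# All names below are hypothetical and exist only for this illustration.

class ToyTeacherForcingSampler:
    """Feeds a fixed, pre-recorded input sequence to the decoder."""

    def __init__(self):
        # No sequence data here: it arrives later through initialize().
        self._inputs = None

    def initialize(self, inputs):
        # inputs: list of time steps, each a list with one value per batch entry.
        self._inputs = inputs
        batch = len(inputs[0])
        initial_finished = [False] * batch
        return initial_finished, inputs[0]

    def sample(self, time, outputs, state):
        # "Sample" by rounding each output to the nearest integer id.
        return [int(round(o)) for o in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        next_time = time + 1
        done = next_time >= len(self._inputs)
        finished = [done] * len(outputs)
        # Once the recorded sequence is exhausted, feed zeros.
        nxt = [0.0] * len(outputs) if done else self._inputs[next_time]
        return finished, nxt, state


def toy_cell(cell_input, state):
    # Hypothetical recurrent cell: a running sum per batch entry.
    new_state = [s + x for s, x in zip(state, cell_input)]
    return new_state, new_state


def run_decoder(sampler, inputs):
    finished, cell_input = sampler.initialize(inputs)
    state = [0.0] * len(cell_input)
    all_ids = []
    time_step = 0
    while not all(finished):
        cell_output, state = toy_cell(cell_input, state)
        sample_ids = sampler.sample(time_step, cell_output, state)
        finished, cell_input, state = sampler.next_inputs(
            time_step, cell_output, state, sample_ids)
        all_ids.append(sample_ids)
        time_step += 1
    return all_ids
```

Calling `run_decoder(ToyTeacherForcingSampler(), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])` walks three time steps and collects one list of sample ids per step.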
Batch size of tensor returned by sample.
Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.
DType of tensor returned by sample.
Returns a DType. The return value might not be available before initialize() is invoked.
Shape of tensor returned by sample, excluding the batch dimension.
Returns a TensorShape. The return value might not be available before initialize() is invoked.
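The "not available before initialize()" behavior can be sketched as a guarded property. This is a plain-Python illustration only, assuming a list stands in for the input tensor; the class name `LazyPropsSampler` and the property names `sample_ids_shape` and `sample_ids_dtype` are assumptions chosen for this sketch, not taken from the text above.

```python
class LazyPropsSampler:
    """Hypothetical sampler whose batch size is only known after initialize()."""

    def __init__(self):
        self._batch_size = None  # unknown until input data arrives

    def initialize(self, inputs):
        # inputs: a list with one entry per batch element.
        self._batch_size = len(inputs)
        finished = [False] * self._batch_size
        return finished, inputs

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        # Each sample id is a scalar, so the per-example shape is empty.
        return ()

    @property
    def sample_ids_dtype(self):
        # Python int stands in for an integer DType in this sketch.
        return int
```

Reading `batch_size` on a fresh instance raises ValueError; after `initialize([0.5, 0.1, 0.9])` it reports the length of the input list.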
initialize( inputs, **kwargs )
Initialize the sampler with the input tensors.
This method is supposed to be invoked only once, before calling any other method of the Sampler.
inputs: A (structure of) input tensors; it can be a nested tuple or a single tensor.
**kwargs: Other keyword arguments for initialization. These can contain tensors, such as a mask for the inputs, or non-tensor parameters.
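As a rough illustration of accepting either a single input or a nested tuple, plus an optional keyword argument, consider the plain-Python sketch below. Lists stand in for tensors, and `flatten`, `MaskedSampler`, and the `mask` keyword are hypothetical names assumed for this example; the real library handles nested structures with its own utilities.

```python
def flatten(structure):
    """Return the leaves of a nested tuple, or a one-element list for a single value."""
    if isinstance(structure, tuple):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


class MaskedSampler:
    """Hypothetical sampler; `mask` is an illustrative keyword argument."""

    def initialize(self, inputs, **kwargs):
        leaves = flatten(inputs)
        batch = len(leaves[0])
        # Each leaf is a list with one entry per batch element.
        if any(len(leaf) != batch for leaf in leaves):
            raise ValueError("all input leaves must share the batch dimension")
        # mask[i] is True when batch entry i has real input; default: all real.
        mask = kwargs.get("mask", [True] * batch)
        initial_finished = [not m for m in mask]
        return initial_finished, inputs
```

For instance, `MaskedSampler().initialize(([1, 2], ([3, 4], [5, 6])), mask=[True, False])` marks the second batch entry as finished from the start, while a call with a single list and no mask marks nothing as finished.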
next_inputs( time, outputs, state, sample_ids )
Returns (finished, next_inputs, next_state).
sample( time, outputs, state )
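One way to picture how sample() and next_inputs() cooperate is a greedy strategy: pick the highest-scoring id, then turn it into the next decoder input and a finished flag. The sketch below is plain Python under stated assumptions: `END_TOKEN` and `EMBEDDINGS` are hypothetical values, lists replace tensors, and the functions are free-standing rather than methods of a real Sampler subclass.

```python
END_TOKEN = 0  # hypothetical id marking the end of a sequence

# Hypothetical embedding table: token id -> next decoder input.
EMBEDDINGS = {0: [0.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}


def sample(time, outputs, state):
    """Greedy choice: the index of the largest score in each row of outputs."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def next_inputs(time, outputs, state, sample_ids):
    """Mark entries that emitted END_TOKEN as finished and embed every id."""
    finished = [i == END_TOKEN for i in sample_ids]
    nxt = [EMBEDDINGS[i] for i in sample_ids]
    return finished, nxt, state  # state passes through unchanged
```

With scores `[[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]`, the first batch entry picks id 1 and continues, while the second picks the end token and is flagged as finished.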