A helper to use during inference with a custom sampling function.
__init__( sample_fn, sample_shape, sample_dtype, end_fn, next_inputs_fn=None )
sample_fn: A callable that takes outputs and emits the tensor sample_ids.
sample_shape: Either a list of integers or a 1-D Tensor of type int32, giving the shape of each sample in the batch returned by sample_fn.
sample_dtype: The dtype of the sample returned by sample_fn.
end_fn: A callable that takes sample_ids and emits a bool vector shaped [batch_size] indicating whether each sample is an end token.
next_inputs_fn: (Optional) A callable that takes sample_ids and returns the next batch of inputs. If not provided, sample_ids is used as the next batch of inputs.
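The following sketch makes the constructor arguments concrete by showing the three callables for greedy decoding. It uses plain Python lists in place of tensors so it runs without TensorFlow. END_TOKEN, EMBEDDINGS, and the greedy argmax strategy are illustrative assumptions, not part of the API. With callables like these, each sample is a single scalar id, so sample_shape would presumably be [] and sample_dtype an integer dtype such as tf.int32.

```python
# Hypothetical callables illustrating the contracts of sample_fn, end_fn and
# next_inputs_fn. Plain Python lists stand in for tensors.
END_TOKEN = 2  # assumed id of the end-of-sequence token

# Assumed toy embedding table: token id -> 2-dim vector.
EMBEDDINGS = {0: [0.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}


def sample_fn(outputs):
    """Greedy sampling: pick the highest-scoring id for each batch entry."""
    return [max(range(len(logits)), key=logits.__getitem__) for logits in outputs]


def end_fn(sample_ids):
    """Return one bool per batch entry: True where the end token was sampled."""
    return [sid == END_TOKEN for sid in sample_ids]


def next_inputs_fn(sample_ids):
    """Map each sampled id to its embedding to feed the next decoder step."""
    return [EMBEDDINGS[sid] for sid in sample_ids]


outputs = [[0.1, 0.7, 0.2], [0.3, 0.1, 0.6]]  # batch of 2, vocabulary of 3
ids = sample_fn(outputs)
print(ids)                  # [1, 2]
print(end_fn(ids))          # [False, True]
print(next_inputs_fn(ids))  # [[1.0, 0.0], [0.0, 1.0]]
```

If next_inputs_fn were omitted, the ids [1, 2] themselves would become the next inputs. That default suits a decoder that consumes token ids directly.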
batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor. The value might not be available before initialize() is invoked, in which case a ValueError is raised.
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType. The value might not be available before initialize() is invoked.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. The value might not be available before initialize() is invoked.
initialize: Initialize the sampler with the input tensors. This method is meant to be invoked only once, before any other method of the Sampler is called.
inputs: A (structure of) input tensors; it can be a nested tuple or a single tensor.
**kwargs: Other kwargs for initialization. These can include tensors, such as a mask for the inputs, or non-tensor parameters.
Returns (initial_finished, initial_inputs).
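The attributes above are tied to initialize(). A value that depends on the inputs, such as the batch size, cannot be known until the sampler has seen them.

The following plain-Python sketch shows that pattern. ToySamplerAttributes is a hypothetical name, not part of TensorFlow Addons, and the sketch assumes the inputs are a single flat batch rather than a nested structure. In it:

- batch_size raises ValueError until initialize() runs.
- sample_ids_shape and sample_ids_dtype are taken from constructor arguments.
- initialize() returns an all-False finished list together with the unchanged inputs.

```python
class ToySamplerAttributes:
    """Hypothetical stand-in for the Sampler attributes; not the TFA class.

    sample_ids_shape and sample_ids_dtype come from constructor arguments,
    while batch_size is only known once initialize() has seen the inputs.
    """

    def __init__(self, sample_shape, sample_dtype):
        self._sample_ids_shape = tuple(sample_shape)
        self._sample_ids_dtype = sample_dtype
        self._batch_size = None  # unknown until initialize() is called

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize() was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self._sample_ids_shape  # per-sample shape, batch dim excluded

    @property
    def sample_ids_dtype(self):
        return self._sample_ids_dtype

    def initialize(self, inputs):
        # Assumes a flat batch: the batch size is the length of the inputs.
        self._batch_size = len(inputs)
        initial_finished = [False] * self._batch_size  # nothing has ended yet
        return initial_finished, inputs


sampler = ToySamplerAttributes(sample_shape=[], sample_dtype="int32")
try:
    sampler.batch_size
except ValueError as err:
    print(err)  # batch_size accessed before initialize() was called

finished, first_inputs = sampler.initialize([5, 7, 9])
print(sampler.batch_size, finished)  # 3 [False, False, False]
print(sampler.sample_ids_shape)      # ()
```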
next_inputs( time, outputs, state, sample_ids )
Returns (finished, next_inputs, next_state).
sample( time, outputs, state )
Returns sample_ids.
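The methods fit together in a decode loop: initialize() runs once, then sample() and next_inputs() run at every step until every sequence has finished. The sketch below mirrors that contract in plain Python so it runs without TensorFlow. It is not the TensorFlow Addons implementation. ToyInferenceSampler and toy_decoder_step are hypothetical names, lists stand in for tensors, and the "decoder" simply scores the token after the previous one highest. Because next_inputs_fn is omitted here, the sampled ids are fed back as the next inputs.

```python
from typing import Callable, List, Optional, Sequence, Tuple


class ToyInferenceSampler:
    """Hypothetical plain-Python mirror of the InferenceSampler contract."""

    def __init__(self, sample_fn: Callable, end_fn: Callable,
                 next_inputs_fn: Optional[Callable] = None):
        self.sample_fn = sample_fn
        self.end_fn = end_fn
        self.next_inputs_fn = next_inputs_fn

    def initialize(self, start_inputs: Sequence) -> Tuple[List[bool], Sequence]:
        # Nothing has finished yet; the start inputs feed the first step.
        return [False] * len(start_inputs), start_inputs

    def sample(self, time, outputs, state):
        # Only the decoder outputs are used for sampling in this sketch.
        return self.sample_fn(outputs)

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = self.end_fn(sample_ids)
        if self.next_inputs_fn is None:
            inputs = sample_ids  # default: feed the sampled ids back in
        else:
            inputs = self.next_inputs_fn(sample_ids)
        return finished, inputs, state  # state passes through unchanged


def toy_decoder_step(inputs, state):
    """Hypothetical decoder: scores id (prev_id + 1) % 4 highest, vocab of 4."""
    outputs = [[1.0 if v == (i + 1) % 4 else 0.0 for v in range(4)] for i in inputs]
    return outputs, state


END_TOKEN = 3  # assumed end-of-sequence id
sampler = ToyInferenceSampler(
    sample_fn=lambda outs: [max(range(len(o)), key=o.__getitem__) for o in outs],
    end_fn=lambda ids: [i == END_TOKEN for i in ids],
)

finished, inputs = sampler.initialize([0, 1])  # two sequences, start ids 0 and 1
state, history = None, []
for time in range(10):  # cap on decoding steps
    outputs, state = toy_decoder_step(inputs, state)
    sample_ids = sampler.sample(time, outputs, state)
    history.append(sample_ids)
    step_finished, inputs, state = sampler.next_inputs(time, outputs, state, sample_ids)
    # A sequence that has finished stays finished.
    finished = [f or s for f, s in zip(finished, step_finished)]
    if all(finished):
        break

print(history)  # [[1, 2], [2, 3], [3, 0]]
```

The second sequence emits the end token at step 1, but the loop keeps stepping it until the first sequence also ends at step 2. That is why its last sampled id (0) comes after its end token.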