A helper to use during inference with a custom sampling function.
__init__( sample_fn, sample_shape, sample_dtype, start_inputs, end_fn, next_inputs_fn=None )
sample_fn: A callable that takes outputs and emits tensor sample_ids.
sample_shape: Either a list of integers, or a 1-D Tensor of type int32, the shape of each sample in the batch returned by sample_fn.
sample_dtype: The dtype of the sample returned by sample_fn.
start_inputs: The initial batch of inputs.
end_fn: A callable that takes sample_ids and emits a boolean tensor of shape [batch_size] indicating whether each sample is an end token.
next_inputs_fn: (Optional) A callable that takes sample_ids and returns the next batch of inputs. If not provided, sample_ids is used as the next batch of inputs.
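To make the contract of the three callables concrete, here is a minimal framework-free sketch that uses plain Python lists in place of tensors. It is not TensorFlow code: EOS_ID, VOCAB_SIZE, the greedy sampling rule, and the one-hot encoding are assumptions chosen only for illustration.

```python
# Hypothetical stand-ins for the callables; plain lists play the role of tensors.
EOS_ID = 2      # assumed end-of-sequence token id
VOCAB_SIZE = 4  # assumed vocabulary size


def sample_fn(outputs):
    """Greedy sampling: pick the index of the largest score in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def end_fn(sample_ids):
    """Return one boolean per batch entry, True where the sample is EOS_ID."""
    return [sid == EOS_ID for sid in sample_ids]


def next_inputs_fn(sample_ids):
    """Turn each sampled id into a one-hot vector used as the next input."""
    return [[1.0 if i == sid else 0.0 for i in range(VOCAB_SIZE)]
            for sid in sample_ids]


# A batch of two score vectors, one per sequence being decoded.
outputs = [[0.1, 0.7, 0.1, 0.1],
           [0.0, 0.1, 0.8, 0.1]]

sample_ids = sample_fn(outputs)            # one id per batch entry
finished = end_fn(sample_ids)              # one flag per batch entry
next_inputs = next_inputs_fn(sample_ids)   # one input vector per batch entry
```

In this sketch each sample is a scalar id, so the matching sample_shape would be an empty list and sample_dtype an integer type.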
batch_size
Batch size of tensor returned by sample.
Returns a scalar int32 tensor.
sample_ids_dtype
DType of tensor returned by sample.
Returns a DType.
sample_ids_shape
Shape of tensor returned by sample, excluding the batch dimension.
next_inputs( time, outputs, state, sample_ids, name=None )
Returns (finished, next_inputs, next_state).
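The way this step combines the constructor callables can be sketched in plain Python. This is an illustrative approximation rather than the library source: the function name next_inputs_sketch is hypothetical, lists stand in for tensors, and it assumes the decoder state is passed through unchanged and that time and outputs are not needed at this step.

```python
def next_inputs_sketch(sample_ids, state, end_fn, next_inputs_fn=None):
    """Return (finished, next_inputs, next_state) for one decoding step."""
    # end_fn marks which batch entries have just produced an end token.
    finished = end_fn(sample_ids)
    # Without next_inputs_fn, the sampled ids themselves become the next inputs.
    if next_inputs_fn is None:
        next_inputs = sample_ids
    else:
        next_inputs = next_inputs_fn(sample_ids)
    # Assumption: the state is forwarded as-is.
    return finished, next_inputs, state


def is_end(ids):
    """Hypothetical end_fn that treats id 2 as the end token."""
    return [i == 2 for i in ids]


# Default behaviour: sample_ids are reused as the next inputs.
result_default = next_inputs_sketch([5, 2], "state", is_end)

# With a custom next_inputs_fn: here each id is simply doubled.
result_custom = next_inputs_sketch(
    [5, 2], "state", is_end, next_inputs_fn=lambda ids: [2 * i for i in ids])
```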
sample( time, outputs, state, name=None )