A helper to use during inference with a custom sampling function.
tfa.seq2seq.InferenceSampler(
    sample_fn: Union[TensorLike, Callable],
    sample_shape: TensorLike,
    sample_dtype: tf.int32,
    end_fn: Callable,
    next_inputs_fn: Optional[Callable] = None
)
sample_fn: A callable that takes outputs and emits the tensor sample_ids.
sample_shape: Either a list of integers or a 1-D Tensor of type int32; the shape of each sample in the batch returned by sample_fn.
sample_dtype: The dtype of the sample returned by sample_fn.
end_fn: A callable that takes sample_ids and emits a boolean vector of shape [batch_size] indicating whether each sample is an end token.
next_inputs_fn: (Optional) A callable that takes sample_ids and returns the next batch of inputs. If not provided, sample_ids is used as the next batch of inputs.
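As an illustration, the three callables described above might look like this for greedy decoding. This is a pure-Python sketch, not TensorFlow code. Logits are plain nested lists of shape [batch_size, vocab_size]. VOCAB_SIZE, END_ID, and the one-hot embedding are hypothetical choices made only for this example. With real tensors, ops such as tf.argmax and tf.one_hot would typically fill the same roles.

```python
from typing import List

VOCAB_SIZE = 4  # hypothetical toy vocabulary size
END_ID = 3      # hypothetical id of the end-of-sequence token


def sample_fn(outputs: List[List[float]]) -> List[int]:
    """Greedy sampling: pick the highest-scoring id in each batch row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def end_fn(sample_ids: List[int]) -> List[bool]:
    """Return a [batch_size] boolean vector marking end tokens."""
    return [sid == END_ID for sid in sample_ids]


def next_inputs_fn(sample_ids: List[int]) -> List[List[float]]:
    """Hypothetical embedding: turn each sampled id into a one-hot vector."""
    return [[1.0 if i == sid else 0.0 for i in range(VOCAB_SIZE)] for sid in sample_ids]


logits = [[0.1, 2.0, 0.3, 0.0],
          [0.0, 0.1, 0.2, 5.0]]
ids = sample_fn(logits)
print(ids)                  # [1, 3]
print(end_fn(ids))          # [False, True]
print(next_inputs_fn(ids))  # [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
```

If next_inputs_fn were omitted, the ids [1, 3] themselves would be fed back as the next inputs.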
batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked, in which case ValueError is raised.
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType. The return value might not be available before initialize() is invoked.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. The return value might not be available before initialize() is invoked.
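The availability rules above can be modeled with a small hypothetical class. ToySampler is not part of TensorFlow Addons; it is a pure-Python sketch built on several assumptions. First, the sample shape and dtype come straight from the constructor, so they are known before initialize() runs. Second, the batch size is taken from the leading dimension of the start inputs. Third, initialize() returns an all-False "finished" vector together with the start inputs.

```python
from typing import List, Optional, Tuple


class ToySampler:
    """Hypothetical stand-in that mirrors the documented property contract."""

    def __init__(self, sample_shape: List[int], sample_dtype: str):
        # Known at construction time, so available before initialize().
        self.sample_shape = tuple(sample_shape)
        self.sample_dtype = sample_dtype
        self._batch_size: Optional[int] = None

    @property
    def batch_size(self) -> int:
        # Mirrors the documented behavior: ValueError before initialize().
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize() was called")
        return self._batch_size

    @property
    def sample_ids_shape(self) -> Tuple[int, ...]:
        return self.sample_shape

    @property
    def sample_ids_dtype(self) -> str:
        return self.sample_dtype

    def initialize(self, start_inputs: List[int]) -> Tuple[List[bool], List[int]]:
        # Assumption: batch size is the number of rows in start_inputs,
        # and no sequence starts out finished.
        self._batch_size = len(start_inputs)
        return [False] * self._batch_size, start_inputs


sampler = ToySampler(sample_shape=[], sample_dtype="int32")
print(sampler.sample_ids_shape)  # ()
try:
    sampler.batch_size
except ValueError as err:
    print("error:", err)
sampler.initialize([0, 0, 0])
print(sampler.batch_size)  # 3
```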
initialize(start_inputs)
Initializes the sampler with the input tensors.
This method is supposed to be invoked only once, before calling any other methods of the Sampler.
start_inputs: A (structure of) input tensors; it can be a nested tuple or a single tensor.
**kwargs: Other keyword arguments for initialization. These can include tensors, such as a mask for the inputs, or non-tensor parameters.
next_inputs(time, outputs, state, sample_ids)
Returns a tuple (finished, next_inputs, next_state).
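The following pure-Python sketch shows how next_inputs can combine the constructor callables, following the argument descriptions above. It works in three parts:

- end_fn decides which rows are finished.
- next_inputs_fn, when given, maps sample_ids to the next inputs.
- Without next_inputs_fn, sample_ids is passed through unchanged.

The function below is hypothetical, not the library method. The library method reads end_fn and next_inputs_fn from the sampler rather than taking them as parameters. Passing state through untouched is likewise an assumption of this sketch.

```python
from typing import Any, Callable, List, Optional, Tuple


def next_inputs(
    time: int,
    outputs: Any,
    state: Any,
    sample_ids: List[int],
    end_fn: Callable[[List[int]], List[bool]],
    next_inputs_fn: Optional[Callable[[List[int]], Any]] = None,
) -> Tuple[List[bool], Any, Any]:
    """Hypothetical sketch returning (finished, next_inputs, next_state)."""
    del time, outputs  # not needed in this sketch
    finished = end_fn(sample_ids)
    if next_inputs_fn is None:
        # Documented fallback: the sampled ids become the next inputs.
        inputs = sample_ids
    else:
        inputs = next_inputs_fn(sample_ids)
    # Assumption: the decoder state passes through unchanged.
    return finished, inputs, state


END_ID = 2  # hypothetical end-token id


def is_end(ids: List[int]) -> List[bool]:
    return [i == END_ID for i in ids]


print(next_inputs(0, None, "s0", [1, 2], is_end))
# ([False, True], [1, 2], 's0')
print(next_inputs(0, None, "s0", [1, 2], is_end, lambda ids: [i * 10 for i in ids]))
# ([False, True], [10, 20], 's0')
```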
sample(time, outputs, state)
Returns sample_ids.
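To see how sample and next_inputs fit together, the following pure-Python sketch runs a greedy decode loop. It is a simplified stand-in for the loop a decoder (for example tfa.seq2seq.BasicDecoder) drives, not its actual implementation. The decoder step fake_step and the FAKE_LOGITS table are hypothetical. The sketch assumes that sample simply returns sample_fn(outputs) and that no next_inputs_fn is supplied. The loop stops once end_fn has flagged every row.

```python
from typing import List

END_ID = 0  # hypothetical end-token id

# Hypothetical fixed logits per time step, shape [batch_size=2, vocab_size=3].
FAKE_LOGITS = [
    [[0.1, 0.9, 0.0], [0.0, 0.2, 0.8]],  # step 0 -> ids [1, 2]
    [[0.0, 0.1, 0.7], [0.9, 0.0, 0.1]],  # step 1 -> ids [2, 0]
    [[0.8, 0.1, 0.1], [0.5, 0.4, 0.1]],  # step 2 -> ids [0, 0]
]


def fake_step(time: int, inputs: List[int]) -> List[List[float]]:
    """Stand-in for a decoder cell; ignores its inputs and replays FAKE_LOGITS."""
    del inputs
    return FAKE_LOGITS[min(time, len(FAKE_LOGITS) - 1)]


def sample_fn(outputs: List[List[float]]) -> List[int]:
    """Greedy choice: index of the largest logit in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def end_fn(sample_ids: List[int]) -> List[bool]:
    return [sid == END_ID for sid in sample_ids]


def greedy_decode(start_inputs: List[int], max_steps: int = 10) -> List[List[int]]:
    finished = [False] * len(start_inputs)
    inputs = start_inputs
    history: List[List[int]] = []
    for time in range(max_steps):
        outputs = fake_step(time, inputs)
        sample_ids = sample_fn(outputs)  # the role of sample(time, outputs, state)
        history.append(sample_ids)
        # A row that has emitted END_ID stays finished.
        finished = [f or e for f, e in zip(finished, end_fn(sample_ids))]
        inputs = sample_ids  # no next_inputs_fn: ids feed the next step
        if all(finished):
            break
    return history


print(greedy_decode([1, 1]))  # [[1, 2], [2, 0], [0, 0]]
```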