tfa.seq2seq.CustomSampler

A sampler that lets the user customize sampling by supplying their own callables.

Inherits From: Sampler

Args:

  • initialize_fn: callable that returns (finished, next_inputs) for the first iteration.
  • sample_fn: callable that takes (time, outputs, state) and returns the sample_ids tensor.
  • next_inputs_fn: callable that takes (time, outputs, state, sample_ids) and returns (finished, next_inputs, next_state).
  • sample_ids_shape: The shape of each value in the sample_ids batch, given as either a list of integers or a 1-D int32 Tensor. Defaults to a scalar.
  • sample_ids_dtype: The dtype of the sample_ids tensor. Defaults to int32.
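As an illustration, the three callables for a greedy, token-by-token sampler might look like the sketch below. It uses plain Python lists in place of tensors so that it runs without TensorFlow; START_TOKEN, END_TOKEN, and the list-based input and output layouts are assumptions made for this example, and a real implementation would operate on batched tensors (for instance with tf.argmax instead of max).

```python
# Hypothetical special token ids, chosen only for this sketch.
START_TOKEN = 1
END_TOKEN = 2


def initialize_fn(inputs):
    """Return (finished, next_inputs) for the first iteration.

    `inputs` is assumed to be a list with one entry per batch element.
    """
    finished = [False] * len(inputs)
    next_inputs = [START_TOKEN] * len(inputs)
    return finished, next_inputs


def sample_fn(time, outputs, state):
    """Greedy sampling: return the index of the highest score per example.

    `outputs` is assumed to be a list of per-example score lists (logits).
    """
    return [max(range(len(scores)), key=scores.__getitem__) for scores in outputs]


def next_inputs_fn(time, outputs, state, sample_ids):
    """Return (finished, next_inputs, next_state).

    Sampled ids are fed back as the next inputs, and an example is
    finished once it emits END_TOKEN. The state is passed through unchanged.
    """
    finished = [sample_id == END_TOKEN for sample_id in sample_ids]
    return finished, list(sample_ids), state


print(initialize_fn(["a", "b"]))  # ([False, False], [1, 1])
print(sample_fn(0, [[0.1, 0.7, 0.2], [0.0, 0.1, 0.9]], None))  # [1, 2]
print(next_inputs_fn(1, None, "s", [1, 2]))  # ([False, True], [1, 2], 's')
```

The parameter names time, outputs, state, and sample_ids follow the argument lists above; keeping them unchanged is a safe choice in case the sampler passes these values by keyword.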

Attributes:

  • batch_size: Batch size of tensor returned by sample.

    Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in this case, a ValueError is raised.

  • sample_ids_dtype: DType of tensor returned by sample.

    Returns a DType. The return value might not be available before initialize() is invoked.

  • sample_ids_shape: Shape of tensor returned by sample, excluding the batch dimension.

    Returns a TensorShape. The return value might not be available before initialize() is invoked.
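The sketch below mimics the attribute behavior described above without TensorFlow: batch_size is unknown until initialize() has run, so reading it earlier raises ValueError, while the shape and dtype fall back to a scalar and int32. ToySampler is a hypothetical name, a tuple stands in for TensorShape and a string for DType, and inferring the batch size from the number of finished flags is an assumption of this sketch.

```python
class ToySampler:
    """Hypothetical stand-in mirroring the documented attribute behavior."""

    def __init__(self, initialize_fn, sample_ids_shape=None, sample_ids_dtype=None):
        self._initialize_fn = initialize_fn
        self._batch_size = None
        # Defaults follow the Args above: a scalar shape and an int32 dtype.
        self._sample_ids_shape = tuple(sample_ids_shape or ())
        self._sample_ids_dtype = sample_ids_dtype or "int32"

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize() was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self._sample_ids_shape

    @property
    def sample_ids_dtype(self):
        return self._sample_ids_dtype

    def initialize(self, inputs, **kwargs):
        finished, next_inputs = self._initialize_fn(inputs, **kwargs)
        # In this sketch, the batch size is the number of finished flags.
        if self._batch_size is None:
            self._batch_size = len(finished)
        return finished, next_inputs


sampler = ToySampler(lambda inputs: ([False] * len(inputs), list(inputs)))
try:
    sampler.batch_size
except ValueError as err:
    print(err)  # batch_size accessed before initialize() was called
sampler.initialize([10, 20, 30])
print(sampler.batch_size)  # 3
```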

Methods

initialize

Initialize the sampler with the input tensors.

This method is supposed to be invoked only once, before any other method of the Sampler is called.

Args:

  • inputs: A (possibly nested) structure of input tensors; it can be a nested tuple or a single tensor.
  • **kwargs: Other keyword arguments for initialization. These can include tensors, such as a mask for the inputs, or non-tensor parameters.

Returns:

(initial_finished, initial_inputs).
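To illustrate how an extra keyword argument can shape the initial state, here is a sketch of a teacher-forcing style initializer that accepts an optional sequence_length keyword: an example with no valid steps is finished immediately, and every other example is fed its first element. Assuming initialize() forwards its **kwargs to initialize_fn, a call such as sampler.initialize(inputs, sequence_length=lengths) would reach this function. The function name, PAD value, and list-based inputs are hypothetical choices for this sketch.

```python
PAD = 0  # hypothetical placeholder input for sequences that are already empty


def teacher_forcing_initialize_fn(inputs, sequence_length=None):
    """Return (initial_finished, initial_inputs).

    `inputs` is assumed to be a list of per-example sequences, and
    `sequence_length` an optional list of their valid lengths.
    """
    if sequence_length is None:
        sequence_length = [len(seq) for seq in inputs]
    # An example with no valid steps is finished before decoding starts.
    initial_finished = [length == 0 for length in sequence_length]
    # Otherwise the first step is fed the example's first element.
    initial_inputs = [
        seq[0] if length > 0 else PAD
        for seq, length in zip(inputs, sequence_length)
    ]
    return initial_finished, initial_inputs


print(teacher_forcing_initialize_fn([[5, 6], [7]]))
# ([False, False], [5, 7])
print(teacher_forcing_initialize_fn([[5, 6], [7]], sequence_length=[2, 0]))
# ([False, True], [5, 0])
```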

next_inputs

Takes (time, outputs, state, sample_ids) and returns (finished, next_inputs, next_state) for the next decoding step.

sample

Takes (time, outputs, state) and returns sample_ids.
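The loop below is a simplified, framework-free sketch of the order in which a decoder typically drives a sampler: initialize() once, then at each step run the cell, call sample() on its outputs, and pass the resulting sample_ids to next_inputs(). run_decode_loop, identity_cell, and CountingSampler are hypothetical; this is not the TFA decoder, which also handles tensor shapes and the outputs of examples that have already finished.

```python
def run_decode_loop(sampler, step_fn, inputs, initial_state, max_steps):
    """Drive any object exposing initialize(), sample(), and next_inputs().

    step_fn(step_inputs, state) -> (outputs, new_state) stands in for an RNN cell.
    Returns the list of sample_ids produced at each step.
    """
    finished, step_inputs = sampler.initialize(inputs)
    state = initial_state
    all_sample_ids = []
    for time in range(max_steps):
        if all(finished):
            break
        # 1. Run the "cell" on the current inputs.
        outputs, state = step_fn(step_inputs, state)
        # 2. Turn the cell outputs into sample ids.
        sample_ids = sampler.sample(time=time, outputs=outputs, state=state)
        # 3. Compute the next inputs, state, and per-example finished flags.
        step_finished, step_inputs, state = sampler.next_inputs(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )
        # Once an example is finished, it stays finished.
        finished = [old or new for old, new in zip(finished, step_finished)]
        all_sample_ids.append(sample_ids)
    return all_sample_ids


class CountingSampler:
    """Toy sampler: each id is the cell output plus one; an example stops at 3."""

    def initialize(self, inputs):
        return [False] * len(inputs), list(inputs)

    def sample(self, time, outputs, state):
        return [value + 1 for value in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = [sample_id >= 3 for sample_id in sample_ids]
        return finished, list(sample_ids), state


def identity_cell(step_inputs, state):
    # A trivial "cell" that echoes its inputs and leaves the state unchanged.
    return list(step_inputs), state


print(run_decode_loop(CountingSampler(), identity_cell, [0, 1], None, max_steps=10))
# [[1, 2], [2, 3], [3, 4]]
```

In this sketch the finished flags are OR-ed across steps, and the loop stops when every example is done or max_steps is reached.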