Module: tfa.seq2seq


Additional ops for building neural network sequence to sequence decoders and losses.

Modules

attention_wrapper module: A powerful dynamic attention wrapper object.

basic_decoder module: A class of Decoders that may sample to generate the next input.

beam_search_decoder module: A decoder that performs beam search.

decoder module: Seq2seq layer operations for use in neural networks.

loss module: Seq2seq loss operations for use in sequence models.

sampler module: A library of samplers for use with SamplingDecoders.
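
The loss module's summary does not say how per-token losses are combined into one number. The following is a minimal NumPy sketch of weighted sequence cross-entropy, not the tfa function: it assumes the default behavior averages over both batch and time by dividing the weighted sum of token losses by the sum of the weights, and `sequence_loss_sketch` is a hypothetical name used only for illustration.

```python
import numpy as np


def sequence_loss_sketch(logits, targets, weights):
    """Weighted token-level cross-entropy, averaged over batch and time.

    logits:  [batch, time, vocab] unnormalized scores
    targets: [batch, time] integer token ids
    weights: [batch, time] per-token weights, e.g. 0.0 on padding
    Returns sum(weights * crossent) / sum(weights).
    """
    # Stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each target token: [batch, time].
    target_log_probs = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    crossent = -target_log_probs * weights
    total_weight = weights.sum()
    return crossent.sum() / total_weight if total_weight > 0 else 0.0


logits = np.zeros((2, 3, 4))        # uniform predictions over a vocabulary of 4
logits[1, 2, 0] = 10.0              # confident and wrong, but at a padded position
targets = np.array([[0, 1, 2], [3, 0, 3]])
weights = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])  # last token of row 2 is padding

loss = sequence_loss_sketch(logits, targets, weights)
print(loss, np.log(4.0))  # the masked position does not affect the loss
```

Because the padded position carries weight 0.0, its large cross-entropy is excluded from both the numerator and the denominator, so the loss equals the uniform-prediction value log(4).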

Classes

class AttentionMechanism

class AttentionWrapper: Wraps another RNNCell with attention.

class AttentionWrapperState: namedtuple storing the state of an AttentionWrapper.

class BahdanauAttention: Implements Bahdanau-style (additive) attention.

class BahdanauMonotonicAttention: Monotonic attention mechanism with Bahdanau-style energy function.

class BaseDecoder: An RNN Decoder that is based on a Keras layer.

class BasicDecoder: Basic sampling decoder.

class BasicDecoderOutput: Outputs of a BasicDecoder step.

class BeamSearchDecoder: BeamSearch sampling decoder.

class BeamSearchDecoderOutput: Outputs of a BeamSearchDecoder step.

class BeamSearchDecoderState: State of a BeamSearchDecoder.

class CustomSampler: Base abstract class that allows the user to customize sampling.

class Decoder: An RNN Decoder abstract interface object.

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.

class GreedyEmbeddingSampler: A sampler for use during inference that takes the argmax of the output logits and embeds it as the next input.

class InferenceSampler: A helper to use during inference with a custom sampling function.

class LuongAttention: Implements Luong-style (multiplicative) attention scoring.

class LuongMonotonicAttention: Monotonic attention mechanism with Luong-style energy function.

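class SampleEmbeddingSampler: A sampler for use during inference that samples from the output distribution (instead of taking the argmax) and embeds the result as the next input.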

class Sampler: Interface for implementing sampling in seq2seq decoders.

class ScheduledEmbeddingTrainingSampler: A training sampler that adds scheduled sampling.

class ScheduledOutputTrainingSampler: A training sampler that adds scheduled sampling directly to outputs.

class SequenceLoss: Weighted cross-entropy loss for a sequence of logits.

class TrainingSampler: A Sampler for use during training.
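
BahdanauAttention and LuongAttention differ mainly in the score that compares the decoder query with each memory (encoder) position: additive scoring passes a sum of projections through tanh, while multiplicative scoring takes a dot product. The NumPy sketch below illustrates both rules under stated assumptions and is not the tfa implementation: random matrices stand in for the learned layers, the optional scaling and normalization variants are omitted, and `luong_score`, `bahdanau_score`, `w_keys`, `w_query`, and `v` are hypothetical names.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def luong_score(query, keys, w_keys):
    """Multiplicative score q . (W k_j) for every memory position j.

    query: [batch, units]; keys: [batch, max_time, depth]; w_keys: [depth, units].
    """
    processed_keys = keys @ w_keys                      # [batch, max_time, units]
    return np.einsum("bu,btu->bt", query, processed_keys)


def bahdanau_score(query, keys, w_query, w_keys, v):
    """Additive score v . tanh(W1 k_j + W2 q) for every memory position j."""
    processed_keys = keys @ w_keys                      # [batch, max_time, units]
    processed_query = query @ w_query                   # [batch, units]
    hidden = np.tanh(processed_keys + processed_query[:, None, :])
    return hidden @ v                                   # [batch, max_time]


rng = np.random.default_rng(0)
batch, max_time, depth, units = 2, 5, 4, 3
memory = rng.normal(size=(batch, max_time, depth))     # hypothetical encoder outputs
query = rng.normal(size=(batch, units))                 # hypothetical decoder state
w_keys = rng.normal(size=(depth, units))                # stand-ins for learned weights
w_query = rng.normal(size=(units, units))
v = rng.normal(size=(units,))

luong_alignments = softmax(luong_score(query, memory, w_keys))
bahdanau_alignments = softmax(bahdanau_score(query, memory, w_query, w_keys, v))

# The context vector is the alignment-weighted sum of the memory.
luong_context = np.einsum("bt,btd->bd", luong_alignments, memory)
bahdanau_context = np.einsum("bt,btd->bd", bahdanau_alignments, memory)
print(luong_alignments.shape, luong_context.shape)  # (2, 5) (2, 4)
```

In both cases a softmax over the scores gives the alignments and the context vector is computed the same way; only the scoring rule changes.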

Functions

dynamic_decode(...): Perform dynamic decoding with decoder.

gather_tree(...): Calculates the full beams from the per-step ids and parent beam ids.

gather_tree_from_array(...): Calculates the full beams for TensorArrays.

hardmax(...): Returns batched one-hot vectors.

monotonic_attention(...): Compute monotonic attention distribution from choosing probabilities.

safe_cumprod(...): Computes cumprod of x in logspace using cumsum to avoid underflow.

sequence_loss(...): Weighted cross-entropy loss for a sequence of logits.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s).
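
gather_tree is needed because beam search records, at each step, only the chosen token id and the index of the parent beam it extends. A plain NumPy sketch of the backtracking, assuming `[max_time, batch, beam_width]` inputs; `gather_tree_sketch` is a hypothetical helper that omits any sequence-length and end-token handling the library op performs.

```python
import numpy as np


def gather_tree_sketch(step_ids, parent_ids):
    """Backtrack beam-search pointers into full sequences.

    step_ids, parent_ids: int arrays of shape [max_time, batch, beam_width].
    Returns an array of the same shape where beams[:, b, k] is the full
    token sequence that ends in final beam k of batch entry b.
    """
    max_time, batch, beam_width = step_ids.shape
    beams = np.zeros_like(step_ids)
    for b in range(batch):
        for k in range(beam_width):
            parent = k  # start from final beam k and walk backwards in time
            for t in range(max_time - 1, -1, -1):
                beams[t, b, k] = step_ids[t, b, parent]
                parent = parent_ids[t, b, parent]
    return beams


# One batch entry, two beams, three steps.
step_ids = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])
parent_ids = np.array([[[0, 0]], [[1, 0]], [[1, 1]]])
beams = gather_tree_sketch(step_ids, parent_ids)
print(beams[:, 0, :].T.tolist())  # [[1, 4, 5], [1, 4, 6]]
```

Both final beams descend from beam 1 at the middle step, so token 3 (a pruned hypothesis) never appears in the reconstructed sequences.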
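
safe_cumprod relies on the identity cumprod(x) = exp(cumsum(log(x))) for positive x. The NumPy sketch below clips inputs to [tiny, 1] before taking the log, so an exact zero becomes the smallest positive normal float instead of producing -inf. The clipping range assumes probability-valued inputs, such as choosing probabilities, and `safe_cumprod_sketch` is a hypothetical helper that only works along the last axis.

```python
import numpy as np


def safe_cumprod_sketch(x, exclusive=False):
    """Cumulative product along the last axis, computed as exp(cumsum(log(x))).

    x is assumed to hold values in [0, 1]; they are clipped to [tiny, 1]
    so that log() never receives an exact zero.
    """
    x = np.asarray(x, dtype=np.float64)
    tiny = np.finfo(x.dtype).tiny
    log_cumsum = np.cumsum(np.log(np.clip(x, tiny, 1.0)), axis=-1)
    if exclusive:
        # Shift right by one: element i becomes the product of x[..., :i].
        log_cumsum = np.concatenate(
            [np.zeros_like(log_cumsum[..., :1]), log_cumsum[..., :-1]], axis=-1
        )
    return np.exp(log_cumsum)


x = np.array([0.9, 0.5, 0.0, 0.8])
print(np.cumprod(x))                           # direct product
print(safe_cumprod_sketch(x))                  # log-space product, no log(0)
print(safe_cumprod_sketch(x, exclusive=True))  # starts at 1.0
```

The results agree with the direct product to within floating-point tolerance, while every intermediate log stays finite.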
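
tile_batch is typically applied to encoder outputs and states before beam search so that each batch entry gets one copy per beam. A NumPy sketch of the repetition order, assuming plain dicts, lists, and tuples as the nested structure; `tile_batch_sketch` is hypothetical and does not handle namedtuples.

```python
import numpy as np


def tile_batch_sketch(t, multiplier):
    """Repeat every batch entry `multiplier` times along axis 0.

    A [batch, ...] array becomes [batch * multiplier, ...] ordered as
    t[0], t[0], ..., t[1], t[1], ... (entries are repeated in place, not
    cycled). Dicts, lists, and tuples are processed recursively.
    """
    if isinstance(t, dict):
        return {k: tile_batch_sketch(v, multiplier) for k, v in t.items()}
    if isinstance(t, (list, tuple)):
        return type(t)(tile_batch_sketch(v, multiplier) for v in t)
    return np.repeat(np.asarray(t), multiplier, axis=0)


encoder_outputs = np.array([[1, 2], [3, 4]])              # batch of 2
state = (np.array([10, 20]), {"h": np.array([[5], [6]])})
tiled = tile_batch_sketch({"outputs": encoder_outputs, "state": state}, multiplier=3)
print(tiled["outputs"].tolist())  # [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
```

Keeping the copies of each entry adjacent means a reshape to `[batch, multiplier, ...]` recovers one row per original entry.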