
Module: tfa.seq2seq


Ops for building neural network sequence-to-sequence decoders and losses.

Modules

attention_wrapper module: A powerful dynamic attention wrapper object.

basic_decoder module: A class of Decoders that may sample to generate the next input.

beam_search_decoder module: A decoder that performs beam search.

decoder module: Seq2seq layer operations for use in neural networks.

loss module: Seq2seq loss operations for use in sequence models.

sampler module: A library of samplers for use with sampling decoders.
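
As a rough illustration of how these modules fit together (assuming TensorFlow 2.x with the tensorflow_addons package installed, and with arbitrarily chosen shapes), the sketch below builds a BasicDecoder from the basic_decoder module, drives it with a TrainingSampler from the sampler module, and reads logits off the resulting BasicDecoderOutput:

```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size, max_time = 4, 7
vocab_size, embedding_size, units = 100, 32, 64

# Embed a batch of target token ids to use as decoder inputs (teacher forcing).
embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
decoder_inputs = embedding(
    tf.random.uniform([batch_size, max_time], maxval=vocab_size, dtype=tf.int64))

# A BasicDecoder pairs an RNN cell with a Sampler and an optional output projection.
cell = tf.keras.layers.LSTMCell(units)
sampler = tfa.seq2seq.TrainingSampler()   # feeds the ground-truth input at each step
output_layer = tf.keras.layers.Dense(vocab_size)
decoder = tfa.seq2seq.BasicDecoder(cell, sampler, output_layer=output_layer)

initial_state = cell.get_initial_state(decoder_inputs)
outputs, final_state, lengths = decoder(
    decoder_inputs,
    initial_state=initial_state,
    sequence_length=tf.fill([batch_size], max_time))

logits = outputs.rnn_output               # [batch_size, max_time, vocab_size]
```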

Classes

class AttentionMechanism

class AttentionWrapper: Wraps another RNNCell with attention.

class AttentionWrapperState: A namedtuple storing the state of an AttentionWrapper.

class BahdanauAttention: Implements Bahdanau-style (additive) attention.

class BahdanauMonotonicAttention: Monotonic attention mechanism with Bahdanau-style energy function.

class BaseDecoder: An RNN Decoder that is based on a Keras layer.

class BasicDecoder: Basic sampling decoder.

class BasicDecoderOutput: BasicDecoderOutput(rnn_output, sample_id)

class BeamSearchDecoder: BeamSearch sampling decoder.

class BeamSearchDecoderOutput: BeamSearchDecoderOutput(scores, predicted_ids, parent_ids)

class BeamSearchDecoderState: BeamSearchDecoderState(cell_state, log_probs, finished, lengths, accumulated_attention_probs)

class CustomSampler: Base abstract class that allows the user to customize sampling.

class Decoder: An RNN Decoder abstract interface object.

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.

class GreedyEmbeddingSampler: A sampler for use during inference.

class InferenceSampler: A helper to use during inference with a custom sampling function.

class LuongAttention: Implements Luong-style (multiplicative) attention scoring.

class LuongMonotonicAttention: Monotonic attention mechanism with Luong-style energy function.

class SampleEmbeddingSampler: A sampler for use during inference.

class Sampler: Interface for implementing sampling in seq2seq decoders.

class ScheduledEmbeddingTrainingSampler: A training sampler that adds scheduled sampling.

class ScheduledOutputTrainingSampler: A training sampler that adds scheduled sampling directly to outputs.

class SequenceLoss: Weighted cross-entropy loss for a sequence of logits.

class TrainingSampler: A Sampler for use during training.
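
For example, a decoder cell can attend over encoder outputs by wrapping it in an AttentionWrapper. The sketch below is a rough illustration with made-up shapes and random stand-in encoder outputs, not canonical usage; it pairs a BahdanauAttention mechanism with an LSTMCell, after which the wrapped cell can be used in a BasicDecoder or BeamSearchDecoder exactly like a plain cell:

```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size, encoder_time, units = 4, 10, 64

# Stand-in encoder outputs; in practice these come from an encoder RNN.
encoder_outputs = tf.random.normal([batch_size, encoder_time, units])
encoder_lengths = tf.fill([batch_size], encoder_time)

# Additive (Bahdanau-style) attention over the encoder outputs.
attention = tfa.seq2seq.BahdanauAttention(
    units, memory=encoder_outputs, memory_sequence_length=encoder_lengths)

# Wrap a plain LSTMCell so that each decoding step attends to the memory.
attn_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(units), attention, attention_layer_size=units)

# The wrapped cell exposes the usual RNN cell interface; its state is an
# AttentionWrapperState rather than a bare LSTM state.
initial_state = attn_cell.get_initial_state(
    batch_size=batch_size, dtype=tf.float32)
```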

Functions

dynamic_decode(...): Performs dynamic decoding with the given decoder.

gather_tree_from_array(...): Calculates the full beams for TensorArrays.

hardmax(...): Returns batched one-hot vectors.

monotonic_attention(...): Compute monotonic attention distribution from choosing probabilities.

safe_cumprod(...): Computes cumprod of x in logspace using cumsum to avoid underflow.

sequence_loss(...): Weighted cross-entropy loss for a sequence of logits.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s).
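
As a rough sketch of a few of these helpers (shapes are arbitrary and the inputs are random), sequence_loss computes a weighted cross-entropy over a batch of logit sequences, hardmax converts logits to one-hot vectors, and tile_batch repeats each batch entry for use with a BeamSearchDecoder:

```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size, max_time, vocab_size = 2, 5, 12
logits = tf.random.normal([batch_size, max_time, vocab_size])
targets = tf.random.uniform([batch_size, max_time], maxval=vocab_size, dtype=tf.int64)

# Per-position weights; set entries to 0.0 to mask padding steps.
weights = tf.ones([batch_size, max_time])

# Averages over both batch and time by default, giving a scalar loss.
loss = tfa.seq2seq.sequence_loss(logits, targets, weights)

# One-hot vectors marking the argmax of each logit vector.
one_hot = tfa.seq2seq.hardmax(logits)      # [batch_size, max_time, vocab_size]

# Repeat every batch entry beam_width times, as expected when feeding
# encoder results and states into a BeamSearchDecoder.
beam_width = 3
tiled = tfa.seq2seq.tile_batch(logits, multiplier=beam_width)  # [batch_size * beam_width, max_time, vocab_size]
```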