Module: tfm.nlp.layers

Layers are the fundamental building blocks for NLP models.

They can be used to assemble new tf.keras layers or models.

Classes

class BertPackInputs: Packs tokens into model inputs for BERT.

class BertTokenizer: Wraps TF.Text's BertTokenizer with pre-defined vocab as a Keras Layer.

class BigBirdAttention: BigBird, a sparse attention mechanism.

class BigBirdMasks: Creates BigBird attention masks.

class BlockDiagFeedforward: Block diagonal feedforward layer.

class CachedAttention: Attention layer with cache used for autoregressive decoding.

class ClassificationHead: Pooling head for sentence-level classification tasks.

class CompiledTransformer: Transformer layer.

class FastWordpieceBertTokenizer: A BERT tokenizer Keras layer using text.FastWordpieceTokenizer.

class GatedFeedforward: Gated linear feedforward layer.

class GaussianProcessClassificationHead: Gaussian process-based pooling head for sentence classification.

class KernelAttention: A variant of efficient transformers which replaces softmax with kernels.

class KernelMask: Creates kernel attention mask.

class MaskedLM: Masked language model network head for BERT modeling.

class MaskedSoftmax: Performs a softmax with optional masking on a tensor.

class MatMulWithMargin: This layer computes a dot-product matrix given two encoded inputs.

class MobileBertEmbedding: Performs an embedding lookup for MobileBERT.

class MobileBertMaskedLM: Masked language model network head for BERT modeling.

class MobileBertTransformer: Transformer block for MobileBERT.

class MultiChannelAttention: Multi-channel Attention layer.

class MultiClsHeads: Pooling heads sharing the same pooling stem.

class MultiHeadRelativeAttention: A multi-head attention layer with relative attention + position encoding.

class OnDeviceEmbedding: Performs an embedding lookup suitable for accelerator devices.

class PackBertEmbeddings: Performs packing tricks for BERT inputs to improve TPU utilization.

class PerDimScaleAttention: Learns scales for individual dimensions.

class PerQueryDenseHead: Pooling head used for EncT5 style models.

class PositionEmbedding: Creates a learned positional embedding.

class RandomFeatureGaussianProcess: Gaussian process layer with random feature approximation [1].

class ReZeroTransformer: Transformer layer with ReZero.

class RelativePositionBias: Relative position embedding via per-head bias in T5 style.

class RelativePositionEmbedding: Creates a sinusoidal (non-learned) positional embedding.

class ReuseMultiHeadAttention: MultiHeadAttention layer.

class ReuseTransformer: Transformer layer.

class SelectTopK: Selects top-k plus random-k tokens according to importance.

class SelfAttentionMask: Creates a 3D attention mask from a 2D tensor mask.

class SentencepieceTokenizer: Wraps tf_text.SentencepieceTokenizer as a Keras Layer.

class SpectralNormalization: Implements spectral normalization for Dense layers.

class SpectralNormalizationConv2D: Implements spectral normalization for Conv2D layers, based on [3].

class StridedTransformerEncoderBlock: Transformer layer for packing optimization to stride over inputs.

class StridedTransformerScaffold: TransformerScaffold for packing optimization to stride over inputs.

class TNTransformerExpandCondense: Transformer layer using tensor network Expand-Condense layer.

class TalkingHeadsAttention: Implements Talking-Heads Attention.

class TokenImportanceWithMovingAvg: Routing based on per-token importance value.

class Transformer: Transformer layer.

class TransformerDecoderBlock: Single Transformer layer for a decoder.

class TransformerEncoderBlock: Transformer encoder block from "Attention Is All You Need".

class TransformerScaffold: Transformer scaffold layer.

class TransformerXL: Transformer XL.

class TransformerXLBlock: Transformer XL block.

class TwoStreamRelativeAttention: Two-stream relative self-attention for XLNet.

class VotingAttention: Voting Attention layer.
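SelfAttentionMask and MaskedSoftmax can be combined: the first broadcasts a 2D padding mask to a 3D [batch, from_seq, to_seq] attention mask, and the second keeps masked positions from receiving attention. The following is a minimal pure-Python sketch for a single batch element, assuming 1 marks a real token and 0 marks padding, and assuming masked logits are pushed toward negative infinity by adding a large negative constant. `self_attention_mask` and `masked_softmax` are hypothetical helpers on nested lists, not the library layers.

```python
import math
from typing import List

Matrix = List[List[float]]


def self_attention_mask(to_mask: List[int], from_seq_length: int) -> List[List[int]]:
    """Hypothetical helper: broadcast a [to_seq] padding mask to [from_seq, to_seq].

    Every query position may attend to exactly the valid (mask == 1) keys.
    """
    return [list(to_mask) for _ in range(from_seq_length)]


def masked_softmax(scores: Matrix, mask: List[List[int]],
                   large_negative: float = -1e9) -> Matrix:
    """Hypothetical helper: row-wise softmax where mask == 0 gets ~zero probability."""
    out = []
    for row, mask_row in zip(scores, mask):
        # Assumed convention: add a large negative value to masked logits.
        logits = [s if m else s + large_negative for s, m in zip(row, mask_row)]
        peak = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - peak) for x in logits]  # masked entries underflow to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# One sequence of length 3 whose last token is padding.
mask = self_attention_mask([1, 1, 0], from_seq_length=3)
probs = masked_softmax([[1.0, 1.0, 5.0]] * 3, mask)
print(probs[0])  # [0.5, 0.5, 0.0]
```

Even though the padded position has the largest raw score, it receives zero probability once the mask is applied.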
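For a fixed (non-learned) positional embedding, the sine/cosine encoding from "Attention Is All You Need" is one common construction. The sketch below assumes a layout of all sine features followed by all cosine features, with inverse timescales shrinking geometrically between defaults of 1.0 and 1.0e4. Treat these as assumptions about how a layer such as RelativePositionEmbedding could build its table, not a copy of it. `sinusoidal_position_encoding` is a hypothetical helper.

```python
import math
from typing import List


def sinusoidal_position_encoding(length: int, hidden_size: int,
                                 min_timescale: float = 1.0,
                                 max_timescale: float = 1.0e4) -> List[List[float]]:
    """Hypothetical helper: a [length, hidden_size] table, sin features then cos features."""
    num_timescales = hidden_size // 2
    # Inverse timescales shrink geometrically; with the defaults they run
    # from 1 down to 1 / max_timescale.
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = [min_timescale * math.exp(-i * log_increment)
                      for i in range(num_timescales)]
    table = []
    for pos in range(length):
        scaled = [pos * inv for inv in inv_timescales]
        table.append([math.sin(x) for x in scaled] + [math.cos(x) for x in scaled])
    return table


enc = sinusoidal_position_encoding(length=4, hidden_size=6)
print(len(enc), len(enc[0]))  # 4 6
print(enc[0])  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Because the table is a deterministic function of position, it needs no trainable weights and can be computed for any sequence length.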

Functions

extract_gp_layer_kwargs(...): Extracts Gaussian process layer configs from a given dictionary of keyword arguments.

extract_spec_norm_kwargs(...): Extracts spectral normalization configs from a given dictionary of keyword arguments.

tf_function_if_eager(...): Applies the @tf.function decorator only if running in eager mode.
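The conditional-decorator idea behind tf_function_if_eager can be sketched without TensorFlow. In this sketch `function_if` is a hypothetical helper, the global `eager_mode` stands in for tf.executing_eagerly(), and `fake_tf_function` stands in for tf.function. The condition is checked once, when the function is decorated, not on every call.

```python
import functools
from typing import Any, Callable


def function_if(condition: Callable[[], bool],
                decorator: Callable[[Callable], Callable]) -> Callable:
    """Hypothetical helper: apply `decorator` only if `condition()` holds at decoration time."""
    def wrapper(func: Callable) -> Callable:
        if condition():
            return decorator(func)
        return func  # leave the plain Python function untouched
    return wrapper


def fake_tf_function(func: Callable) -> Callable:
    """Stand-in for tf.function: wraps `func` and marks it so the effect is observable."""
    @functools.wraps(func)
    def compiled(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)
    compiled.is_compiled = True
    return compiled


eager_mode = True  # stand-in for tf.executing_eagerly()


@function_if(lambda: eager_mode, fake_tf_function)
def add(a, b):
    return a + b


eager_mode = False


@function_if(lambda: eager_mode, fake_tf_function)
def sub(a, b):
    return a - b


print(getattr(add, "is_compiled", False), getattr(sub, "is_compiled", False))  # True False
```

Checking at decoration time avoids per-call overhead. The trade-off is that changing modes later does not re-wrap functions that were already decorated.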
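extract_gp_layer_kwargs and extract_spec_norm_kwargs let a layer accept Gaussian process or spectral normalization options alongside its own arguments. The following is a minimal sketch of that pattern. It assumes the helper pops its keys out of the caller's keyword-argument dictionary, so the remaining entries can be forwarded to the parent layer. The function name `extract_spec_norm_kwargs_sketch` and the keys `iteration` and `norm_multiplier`, along with their defaults, are illustrative assumptions, not the library's documented signature.

```python
from typing import Any, Dict


def extract_spec_norm_kwargs_sketch(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: pop spectral-norm options out of `kwargs`, filling in defaults.

    Mutates `kwargs` in place so only the remaining options are left behind.
    """
    return dict(
        # Assumed option names and defaults, for illustration only.
        iteration=kwargs.pop("iteration", 1),
        norm_multiplier=kwargs.pop("norm_multiplier", 0.99),
    )


layer_kwargs = {"name": "cls_head", "norm_multiplier": 0.95}
spec_norm_kwargs = extract_spec_norm_kwargs_sketch(layer_kwargs)
print(spec_norm_kwargs)  # {'iteration': 1, 'norm_multiplier': 0.95}
print(layer_kwargs)      # {'name': 'cls_head'}
```

Popping instead of reading means the leftover dictionary contains no unexpected keys when it is passed on to a base class constructor.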