


ELECTRA network training model.

This is an implementation of the network structure described in "ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators".

The ElectraPretrainer allows a user to pass in two transformer models, one acting as the generator and the other as the discriminator. It instantiates the masked language model (on the generator side) and the classification network (on the discriminator side) that are used to create the training objectives.

generator_network A transformer network for the generator. This network should output a sequence output and an optional classification output.
discriminator_network A transformer network for the discriminator. This network should output a sequence output.
vocab_size Size of the generator output vocabulary.
num_classes Number of classes to predict from the classification network for the generator network (not used now).
num_token_predictions Number of tokens to predict from the masked LM.
mlm_activation The activation (if any) to use in the masked LM and classification networks. If None, no activation will be used.
mlm_initializer The initializer (if any) to use in the masked LM and classification networks. Defaults to a Glorot uniform initializer.
output_type The output style for this network. Can be either logits or predictions.
disallow_correct Whether to disallow the generator from producing the exact token that appears at a masked position in the original sentence.
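The interplay of these arguments can be illustrated conceptually. The sketch below, using plain NumPy rather than the actual library code, shows how generator samples at masked positions produce a corrupted sequence and how the discriminator's replaced-token-detection labels follow from it; the greedy argmax stands in for the real sampling step, and all tensor values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

original_ids = np.array([[5, 9, 2, 7]])   # [batch=1, seq_length=4]
masked_positions = np.array([[1, 3]])     # num_token_predictions=2

# Generator logits at the masked positions: [batch, num_token_predictions, vocab_size].
gen_logits = rng.normal(size=(1, 2, 12))

# Greedy stand-in for sampling a replacement token at each masked position.
sampled_ids = gen_logits.argmax(-1)

# Build the corrupted sequence the discriminator will see.
corrupted_ids = original_ids.copy()
np.put_along_axis(corrupted_ids, masked_positions, sampled_ids, axis=1)

# Discriminator target: 1 where the token was actually replaced, else 0.
# Note that if the generator happens to sample the original token, the
# label stays 0 — disallow_correct exists to prevent that case.
disc_label = (corrupted_ids != original_ids).astype(np.int32)
```

Unmasked positions always receive label 0, since they were never candidates for replacement.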

checkpoint_items Returns a dictionary of items to be additionally checkpointed.




ELECTRA forward pass.

inputs A dict of all inputs, same as the standard BERT model.
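As a rough sketch of what such an input dict might look like, the snippet below builds dummy NumPy arrays; the key names follow the standard BERT pretraining input format and are an assumption here, as are all shapes and values.

```python
import numpy as np

batch_size, seq_length, num_preds = 2, 8, 2

# Hypothetical input dict mirroring the standard BERT pretraining inputs.
inputs = {
    "input_word_ids": np.ones((batch_size, seq_length), dtype=np.int32),
    "input_mask": np.ones((batch_size, seq_length), dtype=np.int32),
    "input_type_ids": np.zeros((batch_size, seq_length), dtype=np.int32),
    "masked_lm_positions": np.array([[1, 3], [2, 5]], dtype=np.int32),
    "masked_lm_ids": np.array([[11, 12], [13, 14]], dtype=np.int32),
    "masked_lm_weights": np.ones((batch_size, num_preds), dtype=np.float32),
}
```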

outputs A dict of pretrainer model outputs, including:
(1) lm_outputs: A [batch_size, num_token_predictions, vocab_size] tensor of logits at masked positions.
(2) sentence_outputs: A [batch_size, num_classes] tensor of logits for the NSP task.
(3) disc_logits: A [batch_size, sequence_length] tensor of logits for the discriminator's replaced-token-detection task.
(4) disc_label: A [batch_size, sequence_length] tensor of target labels for the discriminator's replaced-token-detection task.
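The disc_logits and disc_label outputs pair up in a per-token binary classification objective. The sketch below computes a numerically stable sigmoid cross-entropy over them in plain NumPy; this is an illustration of how the two outputs relate, not the library's own loss code, and the logit/label values are invented.

```python
import numpy as np

# Hypothetical discriminator outputs with the documented shapes:
# [batch_size=1, sequence_length=4].
disc_logits = np.array([[2.0, -1.5, 0.3, -3.0]])
disc_label = np.array([[1.0, 0.0, 0.0, 0.0]])

# Numerically stable sigmoid cross-entropy: log(1 + e^x) - y * x.
per_token = np.logaddexp(0.0, disc_logits) - disc_label * disc_logits

# Mean over all tokens; a real pipeline might mask out padding first.
loss = per_token.mean()
```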