This is an implementation of the network structure described in "ELECTRA:
Pre-training Text Encoders as Discriminators Rather Than Generators"
(https://arxiv.org/abs/2003.10555).
The ElectraPretrainer allows a user to pass in two transformer models, one
as the generator and the other as the discriminator, and it instantiates the
masked language model (on the generator side) and the classification
networks (on the discriminator side) that are used to create the training
objectives.
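For intuition, the mechanism behind those objectives can be sketched in a few
lines: replacement tokens are sampled from the generator's masked-LM
distribution, and the discriminator's target labels mark the positions whose
tokens changed. The snippet below is a self-contained illustration of that
idea, not the class's actual code; every tensor in it is a stand-in.

```python
import tensorflow as tf

# Toy shapes for illustration. In the real pipeline only the masked
# positions are replaced; here every position is sampled for brevity.
batch_size, seq_length, vocab_size = 2, 8, 100

original_ids = tf.random.uniform(
    (batch_size, seq_length), maxval=vocab_size, dtype=tf.int32)
# Stand-in for the generator's masked-LM logits.
generator_logits = tf.random.normal((batch_size, seq_length, vocab_size))

# Sample replacement tokens from the generator's output distribution.
sampled = tf.random.categorical(
    tf.reshape(generator_logits, (-1, vocab_size)), num_samples=1)
sampled_ids = tf.cast(tf.reshape(sampled, (batch_size, seq_length)), tf.int32)

# Discriminator target: 1 where the token was replaced, 0 where original.
disc_label = tf.cast(tf.not_equal(original_ids, sampled_ids), tf.int32)
```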
Args:
  generator_network: A transformer network for the generator. This network
    should output a sequence output and an optional classification output.
  discriminator_network: A transformer network for the discriminator. This
    network should output a sequence output.
  vocab_size: Size of the generator output vocabulary.
  num_classes: Number of classes to predict from the classification network
    for the generator network (currently unused).
  num_token_predictions: Number of tokens to predict from the masked LM.
  mlm_activation: The activation (if any) to use in the masked LM and
    classification networks. If None, no activation will be used.
  mlm_initializer: The initializer (if any) to use in the masked LM and
    classification networks. Defaults to a Glorot uniform initializer.
  output_type: The output style for this network. Can be either 'logits' or
    'predictions'.
  disallow_correct: Whether to disallow the generator from generating the
    exact token that appears in the original sentence.
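A minimal construction sketch, assuming the TF Model Garden layout in which
this class ships alongside a BertEncoder network; the import paths, encoder
sizes, and hyperparameter values below are illustrative assumptions, not
canonical settings.

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import ElectraPretrainer

# Illustrative sizes only: the generator is typically a fraction of the
# discriminator's width and depth.
generator = networks.BertEncoder(
    vocab_size=30522, num_layers=3, hidden_size=64, num_attention_heads=2)
discriminator = networks.BertEncoder(
    vocab_size=30522, num_layers=12, hidden_size=256, num_attention_heads=4)

pretrainer = ElectraPretrainer(
    generator_network=generator,
    discriminator_network=discriminator,
    vocab_size=30522,
    num_classes=2,
    num_token_predictions=20,
    mlm_activation=tf.nn.gelu,
    output_type='logits',
    disallow_correct=False)
```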
The checkpoint_items property returns a dictionary of items to be
additionally checkpointed.
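A plausible way to consume that dictionary, assuming the common Model Garden
pattern of passing the items as extra trackables to a tf.train.Checkpoint
(the property name and this usage pattern are assumptions here):

```python
import tensorflow as tf

# Track the full pretrainer plus the individually checkpointed sub-items.
checkpoint = tf.train.Checkpoint(
    model=pretrainer, **pretrainer.checkpoint_items)
checkpoint.save('/tmp/electra/ckpt')
```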
The forward pass takes a single argument, inputs: a dict of all model
inputs, the same as for the standard BERT model.
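For concreteness, a batch shaped like the standard BERT pretraining inputs
might look as follows; the key names are assumed from the usual BERT
pretraining convention and should be checked against the actual data
pipeline.

```python
import tensorflow as tf

batch_size, seq_length, num_token_predictions = 2, 128, 20
inputs = {
    'input_word_ids': tf.zeros((batch_size, seq_length), tf.int32),
    'input_mask': tf.ones((batch_size, seq_length), tf.int32),
    'input_type_ids': tf.zeros((batch_size, seq_length), tf.int32),
    'masked_lm_positions': tf.zeros(
        (batch_size, num_token_predictions), tf.int32),
    'masked_lm_ids': tf.zeros(
        (batch_size, num_token_predictions), tf.int32),
    'masked_lm_weights': tf.ones(
        (batch_size, num_token_predictions), tf.float32),
}
outputs = pretrainer(inputs)
```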
It returns outputs, a dict of pretrainer model outputs, including:
  (1) lm_outputs: A [batch_size, num_token_predictions, vocab_size] tensor
    with logits at the masked positions.
  (2) sentence_outputs: A [batch_size, num_classes] tensor with logits for
    the next sentence prediction (NSP) task.
  (3) disc_logits: A [batch_size, sequence_length] tensor with logits for
    the discriminator's replaced token detection task.
  (4) disc_label: A [batch_size, sequence_length] tensor with target labels
    for the discriminator's replaced token detection task.
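These outputs map directly onto the two ELECTRA losses: a masked-LM
cross-entropy on the generator side and a per-token binary cross-entropy on
the discriminator side, which the paper combines with a weight of 50 on the
discriminator term. A sketch of that combination (the loss computation is
assumed to live in the training loop, not in this class):

```python
import tensorflow as tf

# Generator masked-LM loss (real pipelines also apply masked_lm_weights).
lm_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        inputs['masked_lm_ids'], outputs['lm_outputs'], from_logits=True))

# Per-token replaced-token-detection loss on the discriminator side.
disc_per_token = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.cast(outputs['disc_label'], tf.float32),
    logits=outputs['disc_logits'])
disc_loss = tf.reduce_mean(disc_per_token)

# The ELECTRA paper weights the discriminator term by lambda = 50.
total_loss = lm_loss + 50.0 * disc_loss
```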