BERT pretraining model.
tfm.nlp.models.BertPretrainer(
    network,
    num_classes,
    num_token_predictions,
    embedding_table=None,
    activation=None,
    initializer='glorot_uniform',
    output='logits',
    **kwargs
)
[Note] Please use the new BertPretrainerV2 for your projects.
The BertPretrainer allows a user to pass in a transformer stack and instantiates the masked language model and classification networks that are used to create the training objectives.
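For example, a small transformer stack can be wrapped as follows. This is a minimal sketch: the `BertEncoder` hyperparameters and head sizes below are illustrative assumptions, not recommended values.

```python
import tensorflow_models as tfm

# A minimal sketch: wrap a small transformer stack in a BertPretrainer.
# Encoder hyperparameters are deliberately tiny, for illustration only.
encoder = tfm.nlp.networks.BertEncoder(
    vocab_size=30522,         # standard BERT WordPiece vocabulary size
    num_layers=2,
    hidden_size=128,
    num_attention_heads=2)

pretrainer = tfm.nlp.models.BertPretrainer(
    network=encoder,
    num_classes=2,            # e.g. a next-sentence-prediction head
    num_token_predictions=8)  # masked positions predicted per sequence
```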
Methods
call
call(
    inputs, training=None, mask=None
)
Calls the model on new inputs and returns the outputs as tensors. In this case, `call()` just reapplies all ops in the graph to the new inputs (i.e., it builds a new computational graph from the provided inputs). A usage sketch follows the tables below.
Args | |
---|---|
`inputs` | Input tensor, or dict/list/tuple of input tensors. |
`training` | Boolean or boolean scalar tensor, indicating whether to run the Network in training mode or inference mode. |
`mask` | A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the Keras masking guide. |
Returns | |
---|---|
A tensor if there is a single output, or a list of tensors if there is more than one output. |
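As a usage sketch for `call()`, the example below invokes the pretrainer built above on dummy inputs. The input ordering (the encoder's inputs followed by the masked LM positions) and the structure of the returned outputs are assumptions that may vary by version; inspect the model's `inputs` and `outputs` attributes to confirm.

```python
import tensorflow as tf

batch_size, seq_length = 2, 16

# Dummy inputs; shapes are assumptions for illustration only.
word_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
input_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
type_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
lm_positions = tf.zeros((batch_size, 8), dtype=tf.int32)  # matches num_token_predictions=8

# Assumed input order: encoder inputs, then masked LM positions.
outputs = pretrainer([word_ids, input_mask, type_ids, lm_positions])
print(outputs)  # masked LM and classification outputs (structure may vary by version)
```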