Masked language model network head for BERT modeling.
```python
tfm.nlp.layers.MaskedLM(
    embedding_table,
    activation=None,
    initializer='glorot_uniform',
    output='logits',
    name=None,
    **kwargs
)
```
This layer implements a masked language model head on top of the provided Transformer-based encoder. It assumes that the encoder network has a `get_embedding_table()` method, whose result is passed to this layer as `embedding_table`.
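The computation such a head performs can be sketched in NumPy. This is a minimal sketch, not the `tfm` source: it assumes the standard BERT head design (gather the hidden states at the masked positions, apply a dense transform with the optional `activation`, layer-normalize, project onto the transposed embedding table, and add a per-token output bias). The function `masked_lm_head` and its weight arguments are hypothetical names chosen for illustration.

```python
import numpy as np

def masked_lm_head(sequence_data, masked_positions, embedding_table,
                   dense_kernel, dense_bias, ln_gamma, ln_beta, output_bias,
                   activation=None, output="logits", eps=1e-12):
    """Hypothetical NumPy sketch of a BERT-style masked LM head."""
    # Gather hidden vectors at the masked positions -> [batch, num_masked, hidden].
    batch_index = np.arange(sequence_data.shape[0])[:, None]
    x = sequence_data[batch_index, masked_positions]
    # Transform: dense layer with optional activation, then layer normalization.
    x = x @ dense_kernel + dense_bias
    if activation is not None:
        x = activation(x)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps) * ln_gamma + ln_beta
    # Tied projection: reuse the [vocab, hidden] embedding table, plus a bias.
    logits = x @ embedding_table.T + output_bias
    if output == "logits":
        return logits
    # Otherwise return log-probabilities over the vocabulary (log-softmax).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
batch_size, seq_len, hidden, vocab = 2, 6, 4, 10
embedding_table = rng.normal(size=(vocab, hidden))
sequence_data = rng.normal(size=(batch_size, seq_len, hidden))
masked_positions = np.array([[1, 4], [0, 5]])  # two masked tokens per example
params = dict(dense_kernel=rng.normal(size=(hidden, hidden)),
              dense_bias=np.zeros(hidden), ln_gamma=np.ones(hidden),
              ln_beta=np.zeros(hidden), output_bias=np.zeros(vocab))
logits = masked_lm_head(sequence_data, masked_positions, embedding_table, **params)
print(logits.shape)  # (2, 2, 10)
```

The vocabulary size of the output comes entirely from `embedding_table`, which is why the layer needs the encoder's table rather than its own output matrix.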
Example:

```python
# Build an encoder, then tie the masked LM head to its embedding table.
encoder = modeling.networks.BertEncoder(...)
lm_layer = MaskedLM(embedding_table=encoder.get_embedding_table())
```
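Passing `encoder.get_embedding_table()` ties the head's output projection to the encoder's input embeddings, so the head does not need its own `[vocab_size, hidden_size]` matrix. The arithmetic below is a sketch that assumes a BERT-style head (dense transform, layer normalization, per-token output bias) and BERT-base-like sizes (30522 tokens, 768 hidden units); `head_param_count` is a hypothetical helper, not part of `tfm`.

```python
def head_param_count(vocab_size, hidden_size, tied=True):
    """Hypothetical helper: trainable parameters of a BERT-style MLM head."""
    dense = hidden_size * hidden_size + hidden_size  # transform kernel + bias
    layer_norm = 2 * hidden_size                     # gamma + beta
    output_bias = vocab_size                         # one bias per vocabulary token
    # When tied, the projection reuses the encoder's embedding table.
    projection = 0 if tied else vocab_size * hidden_size
    return dense + layer_norm + output_bias + projection

# Assumed BERT-base-like sizes.
print(head_param_count(30522, 768))              # 622650
print(head_param_count(30522, 768, tied=False))  # 24063546
```

Under these assumptions, tying saves about 23.4M parameters in the head.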
Args | |
---|---|
`embedding_table` | The embedding table from the encoder network. |
`activation` | The activation, if any, for the dense layer. |
`initializer` | The initializer for the dense layer. Defaults to a Glorot uniform initializer. |
`output` | The output style for this layer. Can be either `'logits'` or `'predictions'`. |
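The two `output` styles differ only in a final normalization step. The sketch below assumes that `'predictions'` means log-softmax probabilities over the vocabulary (our reading of the Model Garden implementation, not something this page states); `log_softmax` is a local helper, not a `tfm` function.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the last (vocabulary) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

logits = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]])  # [num_masked, vocab]
predictions = log_softmax(logits)

# Each row of exp(predictions) is a probability distribution over the vocabulary.
print(np.exp(predictions).sum(axis=-1))
# The most likely token is the same under both output styles.
print(logits.argmax(-1), predictions.argmax(-1))  # [0 2] [0 2]
```

Raw logits are usually what a cross-entropy loss expects, while log-probabilities are convenient when scoring or sampling tokens directly.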
Methods
call
```python
call(
    sequence_data, masked_positions
)
```
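`masked_positions` selects, for each example, which token positions of `sequence_data` the head should score. A common way to gather them is to flatten the batch and offset each example's positions by `example_index * seq_len`. The NumPy sketch below assumes `sequence_data` is `[batch, seq_len, hidden]` and `masked_positions` is an integer array `[batch, num_masked]`; `gather_masked` is a hypothetical name.

```python
import numpy as np

def gather_masked(sequence_data, masked_positions):
    """Return the hidden vectors at masked positions: [batch, num_masked, hidden]."""
    batch_size, seq_len, hidden = sequence_data.shape
    # Offset each example's positions into a flattened [batch * seq_len] index space.
    offsets = (np.arange(batch_size) * seq_len)[:, None]
    flat_positions = (masked_positions + offsets).reshape(-1)
    flat_sequence = sequence_data.reshape(batch_size * seq_len, hidden)
    gathered = flat_sequence[flat_positions]
    return gathered.reshape(batch_size, -1, hidden)

sequence_data = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # [batch=2, seq_len=4, hidden=3]
masked_positions = np.array([[0, 2], [1, 3]])
print(gather_masked(sequence_data, masked_positions)[1, 0])  # [15 16 17]
```

Only the gathered vectors go through the rest of the head, so the cost of the vocabulary projection scales with the number of masked tokens rather than the full sequence length.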
This is where the layer's logic lives.

The `call()` method may not create state (except in its first invocation, wrapping the creation of variables or other resources in `tf.init_scope()`). It is recommended to create state in `__init__()` or in the `build()` method, which is called automatically before `call()` executes for the first time.
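The build-before-first-call pattern can be illustrated without TensorFlow. The class below is a framework-free analogy, not Keras itself: `LazyDense` is a hypothetical name, and it only mimics how `__call__` runs `build()` once, when the input size is first known, and then reuses that state on later calls.

```python
import random

class LazyDense:
    """Framework-free sketch of Keras-style deferred weight creation."""

    def __init__(self, units):
        self.units = units  # configuration only; no weights yet
        self.kernel = None
        self.built = False

    def build(self, input_dim):
        # Create state once, when the input size is first known.
        self.kernel = [[random.gauss(0.0, 0.1) for _ in range(self.units)]
                       for _ in range(input_dim)]
        self.built = True

    def __call__(self, x):
        # Mirror Keras: build automatically before the first call executes.
        if not self.built:
            self.build(len(x))
        # Plain matrix-vector product: output[j] = sum_i x[i] * kernel[i][j].
        return [sum(xi * row[j] for xi, row in zip(x, self.kernel))
                for j in range(self.units)]

layer = LazyDense(units=2)
layer([1.0, 2.0, 3.0])
first_kernel = layer.kernel
layer([0.5, 0.5, 0.5])
print(layer.kernel is first_kernel)  # True: the second call reuses existing state
```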
Args | |
---|---|
`inputs` | Input tensor, or dict/list/tuple of input tensors. The first positional `inputs` argument is subject to special rules. |
`*args` | Additional positional arguments. May contain tensors, although this is not recommended, for the reasons above. |
`**kwargs` | Additional keyword arguments. May contain tensors, although this is not recommended, for the reasons above. The following optional keyword arguments are reserved: `training`: Boolean scalar tensor or Python boolean indicating whether the call is meant for training or inference. `mask`: Boolean input mask. If the layer's `call()` method takes a `mask` argument, its default value will be set to the mask generated for `inputs` by the previous layer (if `inputs` did come from a layer that generated a corresponding mask, i.e. if it came from a Keras layer with masking support). |
Returns | |
---|---|
A tensor or list/tuple of tensors. |