
tfm.nlp.models.BertTokenClassifier


Token classifier model based on a BERT-style transformer-based encoder.

This is an implementation of the network structure surrounding a transformer encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (https://arxiv.org/abs/1810.04805).

BertTokenClassifier lets a user pass in a transformer stack and builds a token classification network on top of it, with the number of output classes set by the num_classes argument.
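To illustrate what a token classification head computes, the pure-Python sketch below applies one shared dense projection to every token vector of an encoder's sequence output. It turns a [batch][seq_len][hidden] array into [batch][seq_len][num_classes] logits, so each token gets its own class scores. The function name `token_logits` and the toy weights are hypothetical, and the sketch stands in for the library's Keras layers rather than reproducing them.

```python
from typing import List

Matrix = List[List[float]]


def token_logits(sequence_output: List[Matrix], kernel: Matrix,
                 bias: List[float]) -> List[Matrix]:
    """Project every token vector to num_classes logits with one shared dense layer.

    sequence_output: [batch][seq_len][hidden]
    kernel:          [hidden][num_classes]
    bias:            [num_classes]
    returns:         [batch][seq_len][num_classes]
    """
    num_classes = len(bias)
    result = []
    for sequence in sequence_output:
        per_token = []
        for token in sequence:
            # logits[c] = sum_h token[h] * kernel[h][c] + bias[c]
            per_token.append([
                sum(token[h] * kernel[h][c] for h in range(len(token))) + bias[c]
                for c in range(num_classes)
            ])
        result.append(per_token)
    return result


# Toy example: batch of 1, two tokens, hidden size 2, num_classes = 3.
seq = [[[1.0, 0.0], [0.0, 2.0]]]
kernel = [[1.0, 0.0, -1.0],
          [0.0, 1.0, 0.5]]
bias = [0.0, 0.0, 0.0]
print(token_logits(seq, kernel, bias))  # → [[[1.0, 0.0, -1.0], [0.0, 2.0, 1.0]]]
```

Because the same kernel is shared across positions, the head adds only hidden × num_classes + num_classes parameters regardless of sequence length.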

Args
network A transformer network. This network should output a sequence output and a classification output. Furthermore, it should expose its embedding table via a get_embedding_table method.
num_classes Number of classes to predict from the classification network.
initializer The initializer (if any) to use in the classification network. Defaults to a Glorot uniform initializer.
output The output style for this network. Can be either 'logits' or 'predictions'.
dropout_rate The dropout probability of the token classification head.
output_encoder_outputs Whether to include the intermediate sequence output in the final output.
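The effect of the output and output_encoder_outputs options can be sketched as plain Python that assembles the head's results for one sequence. This page does not define the returned structure, so the following are assumptions: the results come back as a dict keyed 'logits' or 'predictions', 'predictions' holds per-token log-softmax values, and an 'encoder_outputs' entry is added when output_encoder_outputs is True. The helper names `log_softmax` and `assemble_outputs` are hypothetical.

```python
import math
from typing import Dict, List

Vector = List[float]


def log_softmax(logits: Vector) -> Vector:
    """Numerically stable log-softmax over one token's class logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def assemble_outputs(logits: List[Vector], sequence_output: List[Vector],
                     output: str = "logits",
                     output_encoder_outputs: bool = False) -> Dict[str, list]:
    """Build the output dict for one sequence (hypothetical helper).

    logits:          [seq_len][num_classes]
    sequence_output: [seq_len][hidden], the encoder's sequence output
    """
    if output == "logits":
        outputs = {"logits": logits}
    elif output == "predictions":
        # Assumption: 'predictions' are per-token log-probabilities.
        outputs = {"predictions": [log_softmax(t) for t in logits]}
    else:
        raise ValueError(f"Unknown output style: {output!r}")
    if output_encoder_outputs:
        # Assumption: key name for the intermediate sequence output.
        outputs["encoder_outputs"] = sequence_output
    return outputs


logits = [[1.0, 0.0, -1.0], [0.0, 2.0, 1.0]]
seq_out = [[1.0, 0.0], [0.0, 2.0]]
preds = assemble_outputs(logits, seq_out, output="predictions",
                         output_encoder_outputs=True)
print(sorted(preds))  # → ['encoder_outputs', 'predictions']
print([round(sum(math.exp(p) for p in t), 6) for t in preds["predictions"]])  # → [1.0, 1.0]
```

Returning log-probabilities rather than raw probabilities keeps the values ready for a negative log-likelihood loss and avoids underflow for very unlikely classes.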

Attributes
checkpoint_items

Methods

call

Calls the model on new inputs and returns the outputs as tensors.

In this case, call() simply reapplies all ops in the graph to the new inputs (i.e. it builds a new computational graph from the provided inputs).

Args
inputs Input tensor, or dict/list/tuple of input tensors.
training Boolean or boolean scalar tensor, indicating whether to run the network in training mode or inference mode.
mask A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the Keras guide on masking and padding.
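One place the training argument matters is the head's dropout (dropout_rate). Keras dropout layers use inverted dropout: in training mode they zero each activation with probability rate and scale the survivors by 1 / (1 - rate), while in inference mode they pass inputs through unchanged. The pure-Python sketch below shows that behavior on a single vector; the function name `dropout` is hypothetical and it illustrates the technique rather than reproducing tf.keras.layers.Dropout.

```python
import random
from typing import List, Optional


def dropout(values: List[float], rate: float, training: bool,
            rng: Optional[random.Random] = None) -> List[float]:
    """Inverted dropout: identity at inference, random zeroing plus rescaling in training."""
    if not training or rate == 0.0:
        return list(values)
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    rng = rng or random.Random()
    keep_scale = 1.0 / (1.0 - rate)
    # Each element is dropped with probability `rate`; survivors are scaled up
    # so the expected value of each element is unchanged.
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


x = [1.0, 2.0, 3.0, 4.0]
print(dropout(x, rate=0.1, training=False))  # → [1.0, 2.0, 3.0, 4.0]
print(dropout(x, rate=0.5, training=True, rng=random.Random(0)))
```

Rescaling during training, rather than at inference, is what lets the inference path be a plain identity.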

Returns
A tensor if there is a single output, or a list of tensors if there is more than one output.