tfm.nlp.models.BertClassifier

Classifier model based on a BERT-style transformer-based encoder.

This is an implementation of the network structure surrounding a transformer encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (https://arxiv.org/abs/1810.04805).

The BertClassifier allows a user to pass in a transformer stack, and instantiates a classification network based on the passed num_classes argument. If num_classes is set to 1, a regression network is instantiated.
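
As a rough illustration of how num_classes shapes the output, the following NumPy sketch applies dropout and a linear projection to a stand-in for the encoder's classification output. It is not the library's implementation: the function names, the shapes, and the single-layer head are assumptions made for illustration only.

```python
import numpy as np


def classification_head(pooled, weights, bias, dropout_rate=0.0,
                        training=False, rng=None):
    """Hypothetical head: optional dropout, then a linear projection."""
    if training and dropout_rate > 0.0:
        if rng is None:
            rng = np.random.default_rng()
        # Inverted dropout: zero some units and rescale the survivors.
        keep = rng.random(pooled.shape) >= dropout_rate
        pooled = pooled * keep / (1.0 - dropout_rate)
    return pooled @ weights + bias


rng = np.random.default_rng(0)
batch_size, hidden_size = 4, 8

# Stand-in for the encoder's classification (pooled) output.
pooled_output = rng.normal(size=(batch_size, hidden_size))


def make_head_params(num_classes):
    """Glorot-uniform weights and zero bias for a (hidden, classes) layer."""
    limit = np.sqrt(6.0 / (hidden_size + num_classes))
    weights = rng.uniform(-limit, limit, size=(hidden_size, num_classes))
    bias = np.zeros(num_classes)
    return weights, bias


# num_classes=3: one logit per class for each example.
w3, b3 = make_head_params(3)
class_logits = classification_head(pooled_output, w3, b3)

# num_classes=1: a single value per example, usable as a regression output.
w1, b1 = make_head_params(1)
regression_output = classification_head(pooled_output, w1, b1)

print(class_logits.shape, regression_output.shape)
```

With these assumed shapes, the classification case yields one row of logits per example, while the num_classes=1 case yields a single column that can be read as a regression prediction.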

Args
network A transformer network. This network should output a sequence output and a classification output. Furthermore, it should expose its embedding table via a "get_embedding_table" method.
num_classes Number of classes to predict from the classification network.
initializer The initializer (if any) to use in the classification networks. Defaults to a Glorot uniform initializer.
dropout_rate The dropout probability of the cls head.
use_encoder_pooler Whether to use the pooler layer pre-defined inside the encoder.
head_name Name of the classification head.
cls_head (Optional) The layer instance to use for the classifier head. It should take in the output from network and produce the final logits. If set, the arguments ('num_classes', 'initializer', 'dropout_rate', 'use_encoder_pooler', 'head_name') will be ignored.

Attributes
checkpoint_items

Methods

call

Calls the model on new inputs and returns the outputs as tensors.

In this case, call() just reapplies all ops in the graph to the new inputs (i.e., it builds a new computational graph from the provided inputs).

Args
inputs Input tensor, or dict/list/tuple of input tensors.
training Boolean or boolean scalar tensor, indicating whether to run the Network in training mode or inference mode.
mask A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the Keras guide on masking and padding.

Returns
A tensor if there is a single output, or a list of tensors if there is more than one output.