Pooling head for sentence-level classification tasks.
tfm.nlp.layers.ClassificationHead(
    inner_dim,
    num_classes,
    cls_token_idx=0,
    activation='tanh',
    dropout_rate=0.0,
    initializer='glorot_uniform',
    **kwargs
)
Args:
  inner_dim: The dimensionality of the inner projection layer. If 0 or None,
    only the output projection layer is created.
  num_classes: Number of output classes.
  cls_token_idx: The index of the token in the sequence to pool (0 by
    default, i.e. the first token).
  activation: Activation of the inner dense layer.
  dropout_rate: Dropout probability, i.e. the fraction of units dropped
    during training.
  initializer: Initializer for the dense layer kernels.
  **kwargs: Additional keyword arguments passed to the base Keras layer.
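To make the effect of inner_dim concrete, here is a minimal NumPy sketch (not the library's code) of the weights a head built with these arguments would hold. It assumes the incoming feature width, hidden_size, is known up front, whereas a Keras layer typically infers it from the input on first call. The helper build_head_weights, the weight key names, and zero-initialized biases (the Keras Dense default) are assumptions for illustration. Kernels follow the Glorot uniform rule named by the default initializer: samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).

```python
import math

import numpy as np


def glorot_uniform(fan_in, fan_out, rng):
    """Sample a (fan_in, fan_out) kernel from U(-limit, limit).

    limit = sqrt(6 / (fan_in + fan_out)), the Glorot/Xavier uniform rule.
    """
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


def build_head_weights(hidden_size, inner_dim, num_classes, seed=0):
    """Hypothetical helper: weights a ClassificationHead-style layer would hold.

    hidden_size (width of the incoming features) is an assumption; a Keras
    layer would normally infer it from the input when first called.
    Key names are illustrative, not the library's variable names.
    """
    rng = np.random.default_rng(seed)
    weights = {}
    in_dim = hidden_size
    if inner_dim:  # 0 or None: no inner projection layer at all
        weights["inner/kernel"] = glorot_uniform(hidden_size, inner_dim, rng)
        weights["inner/bias"] = np.zeros(inner_dim)  # assumed zeros, as in Keras Dense
        in_dim = inner_dim
    # The output projection to num_classes logits is always created.
    weights["output/kernel"] = glorot_uniform(in_dim, num_classes, rng)
    weights["output/bias"] = np.zeros(num_classes)
    return weights
```

For example, build_head_weights(hidden_size=768, inner_dim=0, num_classes=2) returns only the output kernel and bias, and the kernel has shape (768, 2), because the features feed the output projection directly.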
Attributes:
  checkpoint_items
Methods
call
call(
    features: tf.Tensor, only_project: bool = False
)
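Implements call(): computes class logits from features, or returns only the intermediate projection when only_project is True.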
Args:
  features: A rank-3 Tensor of shape [batch size, sequence length, hidden
    size] when inner_dim is set; otherwise a rank-2 Tensor of shape
    [batch size, hidden size].
  only_project: A boolean. If True, return the intermediate Tensor instead
    of projecting it to class logits.
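Returns:
  A Tensor of shape [batch size, num classes] holding the class logits if
  only_project is False. If only_project is True, the intermediate Tensor
  of shape [batch size, inner_dim], or [batch size, hidden size] when
  inner_dim is 0 or None.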
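The call contract above can be sketched in NumPy as follows. This is an illustrative reimplementation under stated assumptions, not tfm's code: classification_head_call and the weight-dict keys are hypothetical, activation is passed as a callable rather than the string 'tanh', and dropout is assumed to sit between the inner projection and the output projection, active only when training is True and using inverted scaling as Keras's Dropout layer does.

```python
import numpy as np


def classification_head_call(features, weights, cls_token_idx=0,
                             activation=np.tanh, dropout_rate=0.0,
                             training=False, only_project=False, rng=None):
    """Hypothetical NumPy sketch of the head's forward pass.

    weights uses illustrative keys: "inner/kernel" and "inner/bias" exist
    only when inner_dim is set; "output/kernel" and "output/bias" always do.
    """
    if not 0.0 <= dropout_rate < 1.0:
        raise ValueError("dropout_rate must be in [0, 1)")
    if "inner/kernel" in weights:
        # [batch, seq_len, hidden] -> take one token -> [batch, hidden]
        x = features[:, cls_token_idx, :]
        # Inner projection plus activation -> [batch, inner_dim]
        x = activation(x @ weights["inner/kernel"] + weights["inner/bias"])
    else:
        # No inner projection: features are already [batch, hidden].
        x = features
    if only_project:
        return x  # intermediate Tensor, before the class-logit projection
    if training and dropout_rate > 0.0:
        # Assumed inverted dropout: zero each unit with probability
        # dropout_rate and rescale survivors by 1 / (1 - dropout_rate).
        if rng is None:
            rng = np.random.default_rng()
        keep = rng.random(x.shape) >= dropout_rate
        x = np.where(keep, x / (1.0 - dropout_rate), 0.0)
    # Output projection -> [batch, num_classes] logits
    return x @ weights["output/kernel"] + weights["output/bias"]
```

With an inner projection of width 4 and 3 classes, a [2, 5, 8] input yields logits of shape [2, 3], and only_project=True yields the [2, 4] intermediate Tensor. Without the inner weights, only_project=True hands back the rank-2 features unchanged, which is why the Returns shape falls back to [batch size, hidden size].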