Gaussian process-based pooling head for sentence classification.

Inherits From: ClassificationHead

This class implements a classifier head for a BERT encoder based on the spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple method that improves a neural network's uncertainty quantification without sacrificing accuracy or latency. It applies spectral normalization to the hidden pooler layer and replaces the dense output layer with a Gaussian process.
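
To make the spectral-normalization step concrete, here is a minimal NumPy sketch. It assumes the SNGP-style rule of rescaling a kernel W to W / max(1, σ(W)/c), where σ(W) is the largest singular value, estimated by power iteration, and c is a norm bound. The function name `spectral_normalize` and the bound c = 0.95 are illustrative choices, not this class's API.

```python
import numpy as np

def spectral_normalize(w, norm_bound=0.95, num_iters=50, seed=0):
    """Rescale kernel `w` so its largest singular value is at most `norm_bound`.

    Hypothetical sketch: estimate the spectral norm with power iteration,
    then shrink `w` only if the estimate exceeds the bound.
    """
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    u /= np.linalg.norm(u)
    for _ in range(num_iters):
        v = w @ u              # left singular vector estimate
        v /= np.linalg.norm(v)
        u = w.T @ v            # right singular vector estimate
        u /= np.linalg.norm(u)
    sigma = v @ w @ u          # estimated largest singular value
    # Leave kernels that already satisfy the bound unchanged.
    return w / max(1.0, sigma / norm_bound)
```

Because the kernel is only shrunk when σ(W) exceeds the bound, well-conditioned layers keep their original weights. The bound caps how much the pooler can stretch distances between inputs, which is what SNGP relies on for distance awareness.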

[1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness. In Neural Information Processing Systems, 2020.

Args:
inner_dim: The dimensionality of the inner projection layer. If 0 or None, only the output projection layer is created.
num_classes: Number of output classes.
cls_token_idx: The index of the token inside the sequence to pool.
activation: Activation function of the dense layer.
dropout_rate: Dropout probability.
initializer: Initializer for the dense layer kernels.
use_spec_norm: Whether to apply spectral normalization to the pooler layer.
use_gp_layer: Whether to use a Gaussian process as the output layer.
temperature: The temperature parameter used for the mean-field approximation during inference. If None, no mean-field adjustment is applied.
**kwargs: Additional keyword arguments.
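
The mean-field adjustment controlled by `temperature` can be sketched as follows. This assumes the SNGP paper's rule logits_i / sqrt(1 + λ·σ²_i), where σ²_i is the i-th diagonal entry of the GP covariance and `temperature` plays the role of λ. The function name `mean_field_logits` is used for illustration only.

```python
import numpy as np

def mean_field_logits(logits, covmat, temperature):
    """Mean-field approximation of the posterior predictive (sketch).

    Scales each example's logits by 1 / sqrt(1 + temperature * var_i), where
    var_i is the i-th diagonal entry of the GP covariance. With
    temperature=None the logits are returned unchanged.
    """
    if temperature is None:
        return logits
    variances = np.diag(covmat)                      # shape (batch_size,)
    scale = np.sqrt(1.0 + temperature * variances)   # shape (batch_size,)
    return logits / scale[:, None]                   # broadcast over classes
```

Examples with higher predictive variance get their logits pulled toward zero, so their softmax output is closer to uniform. A value of π/8 is a common choice for λ in this approximation.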

Returns model output.

During training, the model returns raw logits. During evaluation, it returns uncertainty-adjusted logits and, optionally, the covariance matrix.

Args:
features: A tensor of input features, shape (batch_size, feature_dim).
training: Whether the model is in training mode.
return_covmat: Whether the model should also return the covariance matrix if use_gp_layer=True. During training, it is recommended to set return_covmat=False to be compatible with the standard Keras pipelines (e.g., ...).

Returns:
logits: Uncertainty-adjusted predictive logits, shape (batch_size, num_classes).
covmat: (Optional) Covariance matrix, shape (batch_size, batch_size). Returned only when return_covmat=True.
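
The shapes above follow from how SNGP realizes the GP output layer: random Fourier features plus a Laplace approximation of the posterior. The sketch below assumes that formulation, with:

- features φ(x) = sqrt(2/D)·cos(xW + b),
- logits φ(x)β,
- covmat = ridge · Φ P⁻¹ Φᵀ for a precision matrix P.

`RandomFeatureGP`, `num_inducing`, and `ridge` are hypothetical names. The output weights β are random here rather than trained, P is left at its prior value ridge·I instead of being accumulated during training, and details of the actual layer, such as input normalization, are omitted.

```python
import numpy as np

class RandomFeatureGP:
    """Hypothetical minimal random-feature GP output layer (sketch)."""

    def __init__(self, feature_dim, num_classes, num_inducing=128, ridge=1.0, seed=0):
        rng = np.random.default_rng(seed)
        # Fixed (untrained) random projection for the Fourier features.
        self.w = rng.normal(size=(feature_dim, num_inducing))
        self.b = rng.uniform(0.0, 2.0 * np.pi, size=num_inducing)
        # Output weights: trained in practice, random here for illustration.
        self.beta = rng.normal(size=(num_inducing, num_classes)) / np.sqrt(num_inducing)
        self.ridge = ridge
        self.num_inducing = num_inducing
        # Prior precision; SNGP accumulates training statistics into it.
        self.precision = ridge * np.eye(num_inducing)

    def features(self, x):
        # Random Fourier features approximating an RBF kernel.
        return np.sqrt(2.0 / self.num_inducing) * np.cos(x @ self.w + self.b)

    def __call__(self, x, return_covmat=False):
        phi = self.features(x)        # (batch_size, num_inducing)
        logits = phi @ self.beta      # (batch_size, num_classes)
        if not return_covmat:
            return logits
        # covmat = ridge * phi P^{-1} phi^T, shape (batch_size, batch_size).
        covmat = self.ridge * phi @ np.linalg.solve(self.precision, phi.T)
        return logits, covmat
```

Because the covariance is a product over the whole batch, its shape is (batch_size, batch_size) rather than per-class. Only its diagonal is needed for the mean-field adjustment of the logits.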

Resets covariance matrix of the Gaussian process layer.
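
Under the Laplace approximation used by SNGP, the covariance is derived from a precision matrix accumulated from the random features of training batches. Those features change as the encoder trains, so SNGP examples reset this matrix at the start of each epoch. The final estimate then reflects only the last epoch's features.

The hypothetical sketch below shows that bookkeeping. It assumes exact accumulation without momentum and the simpler update P ← P + ΦᵀΦ; the real layer may weight the update differently for classification. Class and method names are illustrative.

```python
import numpy as np

class LaplacePrecision:
    """Hypothetical sketch of the precision-matrix bookkeeping behind the GP covariance."""

    def __init__(self, num_inducing, ridge=1.0):
        self.ridge = ridge
        self.num_inducing = num_inducing
        self.reset()

    def reset(self):
        # Back to the prior precision ridge * I, discarding accumulated statistics.
        self.matrix = self.ridge * np.eye(self.num_inducing)

    def update(self, phi):
        # phi: random features of one training batch, shape (batch_size, num_inducing).
        self.matrix = self.matrix + phi.T @ phi
```

In a Keras training loop, this reset would typically be triggered from a callback's `on_epoch_begin`, so that statistics from earlier epochs do not leak into the final covariance estimate.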