Aprendizaje profundo de idiomas consciente de la incertidumbre con BERT-SNGP

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar libreta Ver modelo TF Hub

En el tutorial de SNGP , aprendió a construir un modelo SNGP sobre una red residual profunda para mejorar su capacidad de cuantificar su incertidumbre. En este tutorial, aplicará SNGP a una tarea de comprensión del lenguaje natural (NLU) construyéndolo sobre un codificador BERT profundo para mejorar la capacidad del modelo NLU profundo para detectar consultas fuera del alcance.

Específicamente, usted:

  • Cree BERT-SNGP, un modelo BERT aumentado con SNGP.
  • Cargue el conjunto de datos de detección de intenciones fuera de alcance (OOS) de CLINC .
  • Entrena el modelo BERT-SNGP.
  • Evalúe el rendimiento del modelo BERT-SNGP en calibración de incertidumbre y detección fuera de dominio.

Más allá de CLINC OOS, el modelo SNGP se ha aplicado a conjuntos de datos a gran escala, como la detección de toxicidad de Jigsaw , y a conjuntos de datos de imágenes, como CIFAR-100 e ImageNet . Para obtener resultados de referencia de SNGP y otros métodos de incertidumbre, así como una implementación de alta calidad con scripts de capacitación/evaluación de extremo a extremo, puede consultar la referencia de referencia de incertidumbre .

Configuración

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Este tutorial necesita la GPU para ejecutarse de manera eficiente. Compruebe si la GPU está disponible.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Primero implemente un clasificador BERT estándar siguiendo el tutorial de clasificación de texto con BERT . Usaremos el codificador basado en BERT y el ClassificationHead integrado como clasificador.

Modelo BERT estándar

Construir modelo SNGP

Para implementar un modelo BERT-SNGP, solo necesita reemplazar el ClassificationHead con el GaussianProcessClassificationHead incorporado. La normalización espectral ya está preempaquetada en este cabezal de clasificación. Al igual que en el tutorial de SNGP , agregue una devolución de llamada de restablecimiento de covarianza al modelo, de modo que el modelo restablezca automáticamente el estimador de covarianza al comienzo de una nueva época para evitar contar los mismos datos dos veces.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Cargar conjunto de datos CLINC OOS

Ahora cargue el conjunto de datos de detección de intenciones de CLINC OOS . Este conjunto de datos contiene 15000 consultas habladas de usuarios recopiladas en más de 150 clases de intención, también contiene 1000 oraciones fuera del dominio (OOD) que no están cubiertas por ninguna de las clases conocidas.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Haz el tren y prueba los datos.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Cree un conjunto de datos de evaluación OOD. Para esto, combine los datos de prueba en el dominio clinc_test y los datos fuera del dominio clinc_test_oos . También asignaremos la etiqueta 0 a los ejemplos dentro del dominio y la etiqueta 1 a los ejemplos fuera del dominio.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Formar y evaluar

Primero configure las configuraciones básicas de entrenamiento.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Evaluar el rendimiento de OOD

Evalúe qué tan bien el modelo puede detectar las consultas desconocidas fuera del dominio. Para una evaluación rigurosa, utilice el conjunto de datos de evaluación OOD ood_eval_dataset creado anteriormente.

Calcula las probabilidades OOD como \(1 - p(x)\), donde \(p(x)=softmax(logit(x))\) es la probabilidad predictiva.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Ahora evalúe qué tan bien la puntuación de incertidumbre del modelo ood_probs predice la etiqueta fuera del dominio. Primero calcule el área bajo la curva de recuperación de precisión (AUPRC) para la probabilidad de OOD frente a la precisión de detección de OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Esto coincide con el rendimiento de SNGP informado en el punto de referencia CLINC OOS en las líneas de base de incertidumbre .

A continuación, examine la calidad del modelo en la calibración de la incertidumbre , es decir, si la probabilidad predictiva del modelo se corresponde con su precisión predictiva. Un modelo bien calibrado se considera digno de confianza, ya que, por ejemplo, su probabilidad predictiva \(p(x)=0.8\) significa que el modelo es correcto el 80% de las veces.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Recursos y lecturas adicionales