BERT-SNGP를 사용한 불확실성 인식 심층 언어 학습

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 보기 노트북 다운로드 TF Hub 모델 보기

SNGP 튜토리얼 에서는 불확실성을 정량화하는 능력을 향상시키기 위해 깊은 잔차 네트워크 위에 SNGP 모델을 구축하는 방법을 배웠습니다. 이 자습서에서는 SNGP를 심층 BERT 인코더 위에 구축하여 NLU(자연어 이해) 작업에 적용하여 범위 밖 쿼리를 감지하는 심층 NLU 모델의 기능을 향상시킵니다.

구체적으로 다음을 수행합니다.

  • SNGP가 강화된 BERT 모델인 BERT-SNGP를 빌드합니다.
  • CLINC 범위 밖(OOS) 의도 감지 데이터 세트를 로드합니다.
  • BERT-SNGP 모델을 훈련시킵니다.
  • 불확실성 보정 및 도메인 외 감지에서 BERT-SNGP 모델의 성능을 평가합니다.

CLINC OOS 외에도 SNGP 모델은 Jigsaw 독성 검출 과 같은 대규모 데이터 세트와 CIFAR-100ImageNet 과 같은 이미지 데이터 세트에 적용되었습니다. SNGP 및 기타 불확실성 방법의 벤치마크 결과와 종단 간 교육/평가 스크립트를 사용한 고품질 구현은 Uncertainty Baselines 벤치마크를 확인할 수 있습니다.

설정

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

이 튜토리얼을 효율적으로 실행하려면 GPU가 필요합니다. GPU를 사용할 수 있는지 확인하십시오.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

먼저 BERT 자습서로 텍스트 분류에 따라 표준 BERT 분류기를 구현합니다. BERT 기반 인코더와 내장 ClassificationHead 를 분류기로 사용할 것입니다.

표준 BERT 모델

SNGP 모델 구축

BERT-SNGP 모델을 구현하려면 ClassificationHead 를 내장된 GaussianProcessClassificationHead 로 교체하기만 하면 됩니다. 스펙트럼 정규화는 이미 이 분류 헤드에 미리 패키징되어 있습니다. SNGP 튜토리얼 에서와 같이 모델에 공분산 재설정 콜백을 추가하면 동일한 데이터를 두 번 계산하는 것을 방지하기 위해 모델이 새 시대 시작 시 공분산 추정기를 자동으로 재설정합니다.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

CLINC OOS 데이터세트 로드

이제 CLINC OOS 의도 감지 데이터 세트를 로드합니다. 이 데이터 세트에는 150개의 의도 클래스를 통해 수집된 15,000개의 사용자 음성 쿼리가 포함되어 있으며 알려진 클래스에서 다루지 않는 1000개의 OOD(도메인 외) 문장도 포함되어 있습니다.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

기차를 만들고 데이터를 테스트합니다.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

OOD 평가 데이터세트를 만듭니다. 이를 위해 도메인 내 테스트 데이터 clinc_test 와 도메인 외부 데이터 clinc_test_oos 를 결합합니다. 또한 도메인 내 예에 레이블 0을 할당하고 도메인 외부 예에 레이블 1을 할당합니다.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

훈련 및 평가

먼저 기본 교육 구성을 설정합니다.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

OOD 성능 평가

모델이 익숙하지 않은 도메인 외부 쿼리를 얼마나 잘 감지할 수 있는지 평가합니다. 엄격한 평가를 위해 이전에 빌드한 OOD 평가 데이터 세트 ood_eval_dataset 을 사용합니다.

OOD 확률을 \(1 - p(x)\)로 계산합니다. 여기서 \(p(x)=softmax(logit(x))\) 은 예측 확률입니다.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

이제 모델의 불확실성 점수 ood_probs 가 도메인 외부 레이블을 얼마나 잘 예측하는지 평가합니다. 먼저 OOD 확률 대 OOD 감지 정확도에 대한 AUPRC(정밀도 재현율 곡선 아래 면적)를 계산합니다.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

이는 불확실성 기준 아래 CLINC OOS 벤치마크에서 보고된 SNGP 성능과 일치합니다.

다음으로, 불확실성 보정 에서 모델의 품질, 즉 모델의 예측 확률이 예측 정확도와 일치하는지 여부를 검사합니다. 예를 들어 예측 확률 \(p(x)=0.8\) 는 모델이 80% 정확하다는 것을 의미하기 때문에 잘 보정된 모델은 신뢰할 수 있는 것으로 간주됩니다.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

리소스 및 추가 읽을거리