Unsicherheitsbewusstes Deep Learning mit SNGP

Auf TensorFlow.org ansehen In Google Colab ausführen Auf GitHub ansehen Notizbuch herunterladen

Bei sicherheitskritischen KI-Anwendungen (z. B. medizinische Entscheidungsfindung und autonomes Fahren) oder bei denen die Daten von Natur aus verrauscht sind (z. B. Verstehen natürlicher Sprache), ist es für einen tiefen Klassifikator wichtig, seine Unsicherheit zuverlässig zu quantifizieren. Der Tiefenklassifikator sollte sich seiner eigenen Grenzen bewusst sein und wann er die Kontrolle an die menschlichen Experten abgeben sollte. Dieses Tutorial zeigt , wie ein tief Klassifikator der Fähigkeit zur Verbesserung der Unsicherheit bei der Quantifizierung unter Verwendung einer Technik Spectral-normalisierte Neural Gauß - Prozess (genannt SNGP ).

Die Kernidee von SNGP ist ein tiefes Klassifiziergeräts Abstand Bewusstsein auf das Netzwerk durch Anwendung einfache Änderungen zu verbessern. Ein Abstand Bewusstsein des Modells ist ein Maß dafür , wie seine prädiktive Wahrscheinlichkeit spiegelt den Abstand zwischen dem Testbeispiel und die Trainingsdaten. Dies ist eine wünschenswerte Eigenschaft , die für Gold-Standard probablistic Modelle verfügbar ist ( zum Beispiel der Gauß - Prozess mit RBF - Kernel) , wird aber bei den Modellen mit tiefen neuronalen Netzen fehlt. SNGP bietet eine einfache Möglichkeit, dieses Verhalten des Gauß-Prozesses in einen tiefen Klassifikator zu injizieren, während seine Vorhersagegenauigkeit beibehalten wird.

Dieses Tutorial implementiert eine tiefe Restnetz (RESNET) -basierten SNGP Modell auf die beiden Monde Dataset, und vergleicht seine Unsicherheit Oberfläche mit der von zwei anderen beliebten Unsicherheit nähert sich - Monte Carlo Dropout und Tief Ensemble ).

Dieses Tutorial veranschaulicht das SNGP-Modell auf einem Spielzeug-2D-Datensatz. Ein Beispiel SNGP zu einer realen Sprachverstehen Aufgabe der Anwendung mit BERT-Basis finden Sie in der sehen SNGP-BERT - Tutorial . Für qualitativ hochwertige Implementierungen von SNGP Modell (und vielen anderen Unsicherheit Methoden) auf einer Vielzahl von Benchmark - Datensatz (zB CIFAR-100 , IMAGEnet , Jigsaw Toxizität Erkennung , usw.) finden Sie in die Check - out Unsicherheit Baselines Benchmark.

Über SNGP

Spectral-normalisierte Neural Gauß - Prozess (SNGP) ist ein einfacher Ansatz eines tiefen Klassifiziergeräts Unsicherheit Qualität zu verbessern und gleichzeitig ein ähnliches Maß an Genauigkeit und Latenz beibehalten wird . Bei einem tiefen Restnetz nimmt SNGP zwei einfache Änderungen am Modell vor:

  • Es wendet eine Spektralnormierung auf die verborgenen Restschichten an.
  • Es ersetzt die dichte Ausgabeschicht durch eine Gaußsche Prozessschicht.

SNGP

Im Vergleich zu anderen Unsicherheitsansätzen (z. B. Monte-Carlo-Aussteiger oder Deep-Ensemble) hat SNGP mehrere Vorteile:

  • Es funktioniert für eine breite Palette von hochmodernen Rest-basierten Architekturen (z. B. (Wide) ResNet, DenseNet, BERT usw.).
  • Es handelt sich um ein Einzelmodellverfahren (dh es beruht nicht auf der Gesamtmittelung). Daher SNGP hat ein ähnliches Maß an Latenz als ein einzelnes determinis Netzwerk und kann leicht auf große Datenmengen wie skaliert wird IMAGEnet und Jigsaw Klassifizierung Toxic Kommentare .
  • Es hat starke out-of-Domain Erkennungsleistung aufgrund der Entfernung Bewusstsein Eigenschaft.

Die Nachteile dieser Methode sind:

  • Die prädiktive Unsicherheit eines SNGP wird berechnet unter Verwendung der Laplace - Approximation . Daher unterscheidet sich die posteriore Unsicherheit von SNGP theoretisch von der eines exakten Gauß'schen Prozesses.

  • Das SNGP-Training benötigt zu Beginn einer neuen Epoche einen Kovarianz-Reset-Schritt. Dies kann eine Trainingspipeline um ein kleines Maß an zusätzlicher Komplexität erhöhen. Dieses Tutorial zeigt eine einfache Möglichkeit, dies mithilfe von Keras-Callbacks zu implementieren.

Einrichten

pip install tf-models-nightly
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import sklearn.datasets

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as nlp_layers

Visualisierungsmakros definieren

plt.rcParams['figure.dpi'] = 140

DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100

Der Zwei-Mond-Datensatz

Erstellen Sie die trainining und Auswertung Datensätze aus dem zwei Mond - Datensatz .

def make_training_data(sample_size=500):
  """Create two moon training dataset."""
  train_examples, train_labels = sklearn.datasets.make_moons(
      n_samples=2 * sample_size, noise=0.1)

  # Adjust data position slightly.
  train_examples[train_labels == 0] += [-0.1, 0.2]
  train_examples[train_labels == 1] += [0.1, -0.2]

  return train_examples, train_labels

Bewerten Sie das Vorhersageverhalten des Modells über den gesamten 2D-Eingaberaum.

def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
  """Create a mesh grid in 2D space."""
  # testing data (mesh grid over data space)
  x = np.linspace(x_range[0], x_range[1], n_grid)
  y = np.linspace(y_range[0], y_range[1], n_grid)
  xv, yv = np.meshgrid(x, y)
  return np.stack([xv.flatten(), yv.flatten()], axis=-1)

Um die Modellunsicherheit zu bewerten, fügen Sie einen Datensatz außerhalb der Domäne (OOD) hinzu, der zu einer dritten Klasse gehört. Das Modell sieht diese OOD-Beispiele während des Trainings nie.

def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
  return np.random.multivariate_normal(
      means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
    sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)

# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]

plt.figure(figsize=(7, 5.5))

plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)

plt.legend(["Postive", "Negative", "Out-of-Domain"])

plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)

plt.show()

png

Hier stehen Blau und Orange für die positiven und negativen Klassen und das Rot für die OOD-Daten. Von einem Modell, das die Unsicherheit gut quantifiziert, wird erwartet, dass es in der Nähe von Trainingsdaten (dh $p(x_{test})$ nahe 0 oder 1) zuverlässig ist und unsicher ist, wenn es weit von den Trainingsdatenbereichen entfernt ist (dh $p(x_{test})$ nahe 0,5).

Das deterministische Modell

Modell definieren

Beginnen Sie mit dem (Basis-)deterministischen Modell: einem mehrschichtigen Residualnetzwerk (ResNet) mit Dropout-Regularisierung.

Dieses Tutorial verwendet ein 6-Layer-ResNet mit 128 versteckten Einheiten.

resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
2021-07-01 01:22:35.584313: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:35.591002: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:35.591739: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:35.593055: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-07-01 01:22:35.593616: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:35.594337: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:35.594975: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:36.216660: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:36.217484: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:36.218273: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-01 01:22:36.218979: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14646 MB memory:  -> device: 0, name: NVIDIA Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0
resnet_model.build((None, 2))
resnet_model.summary()
Model: "deep_res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  384       
_________________________________________________________________
dense_1 (Dense)              multiple                  16512     
_________________________________________________________________
dense_2 (Dense)              multiple                  16512     
_________________________________________________________________
dense_3 (Dense)              multiple                  16512     
_________________________________________________________________
dense_4 (Dense)              multiple                  16512     
_________________________________________________________________
dense_5 (Dense)              multiple                  16512     
_________________________________________________________________
dense_6 (Dense)              multiple                  16512     
_________________________________________________________________
dense_7 (Dense)              multiple                  258       
=================================================================
Total params: 99,714
Trainable params: 99,330
Non-trainable params: 384
_________________________________________________________________

Zugmodell

Konfigurieren Sie die Trainingsparameter verwenden SparseCategoricalCrossentropy als Verlustfunktion und der Adam - Optimierer.

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)

Trainieren Sie das Modell für 100 Epochen mit Losgröße 128.

fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5102: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 1/100
2021-07-01 01:22:36.769520: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
8/8 [==============================] - 1s 3ms/step - loss: 0.5138 - sparse_categorical_accuracy: 0.7110
Epoch 2/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2955 - sparse_categorical_accuracy: 0.8810
Epoch 3/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9250
Epoch 4/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1676 - sparse_categorical_accuracy: 0.9290
Epoch 5/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1459 - sparse_categorical_accuracy: 0.9410
Epoch 6/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1401 - sparse_categorical_accuracy: 0.9350
Epoch 7/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1311 - sparse_categorical_accuracy: 0.9360
Epoch 8/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1196 - sparse_categorical_accuracy: 0.9450
Epoch 9/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1175 - sparse_categorical_accuracy: 0.9460
Epoch 10/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1109 - sparse_categorical_accuracy: 0.9460
Epoch 11/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1102 - sparse_categorical_accuracy: 0.9480
Epoch 12/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1005 - sparse_categorical_accuracy: 0.9490
Epoch 13/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9490
Epoch 14/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1006 - sparse_categorical_accuracy: 0.9480
Epoch 15/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0935 - sparse_categorical_accuracy: 0.9550
Epoch 16/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0958 - sparse_categorical_accuracy: 0.9560
Epoch 17/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0879 - sparse_categorical_accuracy: 0.9520
Epoch 18/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0866 - sparse_categorical_accuracy: 0.9540
Epoch 19/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0857 - sparse_categorical_accuracy: 0.9530
Epoch 20/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0847 - sparse_categorical_accuracy: 0.9540
Epoch 21/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0791 - sparse_categorical_accuracy: 0.9570
Epoch 22/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0796 - sparse_categorical_accuracy: 0.9600
Epoch 23/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9590
Epoch 24/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0832 - sparse_categorical_accuracy: 0.9630
Epoch 25/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0797 - sparse_categorical_accuracy: 0.9600
Epoch 26/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9640
Epoch 27/100
8/8 [==============================] - 0s 6ms/step - loss: 0.0695 - sparse_categorical_accuracy: 0.9690
Epoch 28/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0683 - sparse_categorical_accuracy: 0.9670
Epoch 29/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0724 - sparse_categorical_accuracy: 0.9680
Epoch 30/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0674 - sparse_categorical_accuracy: 0.9740
Epoch 31/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0659 - sparse_categorical_accuracy: 0.9690
Epoch 32/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0630 - sparse_categorical_accuracy: 0.9720
Epoch 33/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0610 - sparse_categorical_accuracy: 0.9760
Epoch 34/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0643 - sparse_categorical_accuracy: 0.9690
Epoch 35/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0581 - sparse_categorical_accuracy: 0.9770
Epoch 36/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0607 - sparse_categorical_accuracy: 0.9800
Epoch 37/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0592 - sparse_categorical_accuracy: 0.9740
Epoch 38/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0538 - sparse_categorical_accuracy: 0.9810
Epoch 39/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0605 - sparse_categorical_accuracy: 0.9770
Epoch 40/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0532 - sparse_categorical_accuracy: 0.9840
Epoch 41/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0497 - sparse_categorical_accuracy: 0.9820
Epoch 42/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0509 - sparse_categorical_accuracy: 0.9820
Epoch 43/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0571 - sparse_categorical_accuracy: 0.9800
Epoch 44/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0545 - sparse_categorical_accuracy: 0.9800
Epoch 45/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0459 - sparse_categorical_accuracy: 0.9830
Epoch 46/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0488 - sparse_categorical_accuracy: 0.9840
Epoch 47/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9890
Epoch 48/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0441 - sparse_categorical_accuracy: 0.9860
Epoch 49/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0376 - sparse_categorical_accuracy: 0.9890
Epoch 50/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9870
Epoch 51/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0397 - sparse_categorical_accuracy: 0.9870
Epoch 52/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0384 - sparse_categorical_accuracy: 0.9880
Epoch 53/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0354 - sparse_categorical_accuracy: 0.9890
Epoch 54/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0377 - sparse_categorical_accuracy: 0.9880
Epoch 55/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0372 - sparse_categorical_accuracy: 0.9920
Epoch 56/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0335 - sparse_categorical_accuracy: 0.9890
Epoch 57/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0338 - sparse_categorical_accuracy: 0.9870
Epoch 58/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0383 - sparse_categorical_accuracy: 0.9870
Epoch 59/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0325 - sparse_categorical_accuracy: 0.9880
Epoch 60/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0297 - sparse_categorical_accuracy: 0.9900
Epoch 61/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9900
Epoch 62/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0314 - sparse_categorical_accuracy: 0.9890
Epoch 63/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9870
Epoch 64/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0303 - sparse_categorical_accuracy: 0.9920
Epoch 65/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0292 - sparse_categorical_accuracy: 0.9920
Epoch 66/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9920
Epoch 67/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0320 - sparse_categorical_accuracy: 0.9900
Epoch 68/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0290 - sparse_categorical_accuracy: 0.9900
Epoch 69/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0273 - sparse_categorical_accuracy: 0.9910
Epoch 70/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9910
Epoch 71/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0276 - sparse_categorical_accuracy: 0.9920
Epoch 72/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0226 - sparse_categorical_accuracy: 0.9940
Epoch 73/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0232 - sparse_categorical_accuracy: 0.9900
Epoch 74/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9890
Epoch 75/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0231 - sparse_categorical_accuracy: 0.9930
Epoch 76/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0257 - sparse_categorical_accuracy: 0.9890
Epoch 77/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0244 - sparse_categorical_accuracy: 0.9910
Epoch 78/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0230 - sparse_categorical_accuracy: 0.9930
Epoch 79/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0210 - sparse_categorical_accuracy: 0.9940
Epoch 80/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0222 - sparse_categorical_accuracy: 0.9910
Epoch 81/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0187 - sparse_categorical_accuracy: 0.9940
Epoch 82/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0186 - sparse_categorical_accuracy: 0.9920
Epoch 83/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0229 - sparse_categorical_accuracy: 0.9910
Epoch 84/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0176 - sparse_categorical_accuracy: 0.9930
Epoch 85/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9920
Epoch 86/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0193 - sparse_categorical_accuracy: 0.9900
Epoch 87/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9930
Epoch 88/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9930
Epoch 89/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0170 - sparse_categorical_accuracy: 0.9920
Epoch 90/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0164 - sparse_categorical_accuracy: 0.9930
Epoch 91/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0186 - sparse_categorical_accuracy: 0.9940
Epoch 92/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9920
Epoch 93/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9940
Epoch 94/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0162 - sparse_categorical_accuracy: 0.9940
Epoch 95/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9930
Epoch 96/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9960
Epoch 97/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9970
Epoch 98/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9930
Epoch 99/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0152 - sparse_categorical_accuracy: 0.9930
Epoch 100/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0158 - sparse_categorical_accuracy: 0.9950
<keras.callbacks.History at 0x7f034042e290>

Unsicherheit visualisieren

Visualisieren Sie nun die Vorhersagen des deterministischen Modells. Zeichnen Sie zuerst die Klassenwahrscheinlichkeit:

$$p(x) = softmax(logit(x))$$
resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0]  # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_probs, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")

plt.show()

png

In diesem Diagramm sind Gelb und Violett die Vorhersagewahrscheinlichkeiten für die beiden Klassen. Das deterministische Modell hat gute Arbeit geleistet, um die beiden bekannten Klassen (blau und orange) mit einer nichtlinearen Entscheidungsgrenze zu klassifizieren. Es ist jedoch nicht entfernungs bewusst, und die nie gesehen rot out-of-Domain (OOD) Beispiele getrost als die orangefarbenen Klasse eingestuft.

Visualisieren Sie die Modellunsicherheit durch die Berechnung von prädiktiven Varianz :

$$var(x) = p(x) * (1 - p(x))$$
resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")

plt.show()

png

In diesem Diagramm bedeutet Gelb eine hohe Unsicherheit und Violett eine geringe Unsicherheit. Die Unsicherheit eines deterministischen ResNet hängt nur von der Entfernung der Testbeispiele von der Entscheidungsgrenze ab. Dies führt dazu, dass das Modell zu selbstsicher ist, wenn es sich außerhalb des Trainingsbereichs befindet. Der nächste Abschnitt zeigt, wie sich SNGP bei diesem Datensatz anders verhält.

Das SNGP-Modell

SNGP-Modell definieren

Lassen Sie uns nun das SNGP-Modell implementieren. Sowohl die SNGP Komponenten, SpectralNormalization und RandomFeatureGaussianProcess , sind erhältlich bei der tensorflow_model hat eingebaute in Schichten .

SNGP

Schauen wir uns diese beiden Komponenten genauer an. (Sie können auch auf den springen Die SNGP Modell Abschnitt zu sehen , wie das vollständige Modell implementiert ist.)

Spektralnormalisierungs-Wrapper

SpectralNormalization ist ein Keras Schicht - Wrapper. Es kann wie folgt auf eine vorhandene Dichte Ebene angewendet werden:

dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)

Spectral Normalisierungs reguliert das versteckte Gewicht $ W $ durch allmähliche seine spektrale norm Führung (dh des größten Eigenwerts von $ $ W) in Richtung auf dem Zielwert norm_multiplier .

Die Schicht des Gaußschen Prozesses (GP)

RandomFeatureGaussianProcess implementiert eine zufallsmerkmalbasierte Annäherung an ein Gaußschen Prozessmodell, das Ende-zu-Ende - trainierbar mit einem tiefen neuronales Netz. Unter der Haube implementiert die Gaußsche Prozessschicht ein zweischichtiges Netzwerk:

$$logits(x) = \Phi(x) \beta, \quad \Phi(x)=\sqrt{\frac{2}{M} } * cos(Wx + b)$$

Hier ist $x$ die Eingabe, und $W$ und $b$ sind eingefrorene Gewichte, die zufällig aus Gauß- bzw. Gleichverteilungen initialisiert werden. (Daher werden $\Phi(x)$ als "zufällige Merkmale" bezeichnet.) $\beta$ ist das erlernbare Kernel-Gewicht ähnlich dem einer Dichteschicht.

batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
                                               num_inducing=1024,
                                               normalize_input=False,
                                               scale_random_features=True,
                                               gp_cov_momentum=-1)

Die Hauptparameter der GP-Schichten sind:

  • units : Die Dimension des Ausgang Logits.
  • num_inducing : Die Dimension $ M $ der versteckten Gewicht $ W $. Standardwert auf 1024.
  • normalize_input : Ob Schicht Normalisierung an den Eingang $ x $ anzuwenden.
  • scale_random_features : Ob die Waage $ \ sqrt {2 / M} $ , um den versteckten Ausgang zu übernehmen.
  • gp_cov_momentum steuert , wie das Modell Kovarianz berechnet. Bei einem positiven Wert (z. B. 0,999) wird die Kovarianzmatrix unter Verwendung der impulsbasierten gleitenden Durchschnittsaktualisierung berechnet (ähnlich der Batch-Normalisierung). Wenn auf -1 gesetzt, wird die Kovarianzmatrix ohne Impuls aktualisiert.

Bei einem Batch - Input mit Form (batch_size, input_dim) , kehrt die GP Schicht eine logits Tensor (Form (batch_size, num_classes) ) für die Vorhersage und auch covmat Tensor (Form (batch_size, batch_size) ) , welches das hintere Kovarianzmatrix der Batch-Logs.

embedding = tf.random.normal(shape=(batch_size, input_dim))

logits, covmat = gp_layer(embedding)
2021-07-01 01:22:41.531478: I tensorflow/core/util/cuda_solvers.cc:180] Creating CudaSolver handles for stream 0x759dc90

Theoretisch ist es möglich , den Algorithmus zu erstrecken , unterschiedlichen Abweichungswerte für die verschiedenen Klassen zu berechnen (wie in dem eingeführten Original SNGP Papier ). Dies ist jedoch schwierig auf Probleme mit großen Ausgaberäumen (zB ImageNet oder Sprachmodellierung) zu skalieren.

Das vollständige SNGP-Modell

Angesichts der Basisklasse DeepResNet kann das SNGP Modell leicht implementiert werden , indem die Rest Netzwerks versteckt und Ausgabeschichten zu modifizieren. Für die Kompatibilität mit Keras model.fit() API, ändert auch das Modell call() Methode , so dass es nur Ausgänge logits während des Trainings.

class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

Verwenden Sie dieselbe Architektur wie das deterministische Modell.

resnet_config
{'num_classes': 2, 'num_layers': 6, 'num_hidden': 128}
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
Model: "deep_res_net_sngp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              multiple                  384       
_________________________________________________________________
spectral_normalization_1 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_2 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_3 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_4 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_5 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_6 (Sp multiple                  16768     
_________________________________________________________________
random_feature_gaussian_proc multiple                  1182722   
=================================================================
Total params: 1,283,714
Trainable params: 101,120
Non-trainable params: 1,182,594
_________________________________________________________________

Implementieren Sie einen Keras-Callback, um die Kovarianzmatrix zu Beginn einer neuen Epoche zurückzusetzen.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()

Fügen Sie diesen Rückruf an die DeepResNetSNGP Modellklasse.

class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks", []))
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Zugmodell

Verwenden Sie tf.keras.model.fit das Modell zu trainieren.

sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 2s 5ms/step - loss: 0.6460 - sparse_categorical_accuracy: 0.9015
Epoch 2/100
8/8 [==============================] - 0s 5ms/step - loss: 0.5550 - sparse_categorical_accuracy: 0.9960
Epoch 3/100
8/8 [==============================] - 0s 5ms/step - loss: 0.5056 - sparse_categorical_accuracy: 0.9980
Epoch 4/100
8/8 [==============================] - 0s 4ms/step - loss: 0.4636 - sparse_categorical_accuracy: 0.9980
Epoch 5/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4306 - sparse_categorical_accuracy: 0.9980
Epoch 6/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4043 - sparse_categorical_accuracy: 0.9970
Epoch 7/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3770 - sparse_categorical_accuracy: 0.9960
Epoch 8/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3558 - sparse_categorical_accuracy: 0.9970
Epoch 9/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3391 - sparse_categorical_accuracy: 0.9980
Epoch 10/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3197 - sparse_categorical_accuracy: 0.9970
Epoch 11/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3055 - sparse_categorical_accuracy: 0.9970
Epoch 12/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9960
Epoch 13/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2781 - sparse_categorical_accuracy: 0.9980
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2675 - sparse_categorical_accuracy: 0.9980
Epoch 15/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2570 - sparse_categorical_accuracy: 0.9970
Epoch 16/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2461 - sparse_categorical_accuracy: 0.9970
Epoch 17/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2368 - sparse_categorical_accuracy: 0.9970
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2278 - sparse_categorical_accuracy: 0.9980
Epoch 19/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2198 - sparse_categorical_accuracy: 0.9980
Epoch 20/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2118 - sparse_categorical_accuracy: 0.9980
Epoch 21/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2047 - sparse_categorical_accuracy: 0.9980
Epoch 22/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1979 - sparse_categorical_accuracy: 0.9980
Epoch 23/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1926 - sparse_categorical_accuracy: 0.9980
Epoch 24/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1864 - sparse_categorical_accuracy: 0.9980
Epoch 25/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1807 - sparse_categorical_accuracy: 0.9980
Epoch 26/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1754 - sparse_categorical_accuracy: 0.9970
Epoch 27/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1703 - sparse_categorical_accuracy: 0.9980
Epoch 28/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1648 - sparse_categorical_accuracy: 0.9980
Epoch 29/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1610 - sparse_categorical_accuracy: 0.9980
Epoch 30/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1552 - sparse_categorical_accuracy: 0.9980
Epoch 31/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1526 - sparse_categorical_accuracy: 0.9980
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1492 - sparse_categorical_accuracy: 0.9970
Epoch 33/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1436 - sparse_categorical_accuracy: 0.9980
Epoch 34/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1400 - sparse_categorical_accuracy: 0.9980
Epoch 35/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1362 - sparse_categorical_accuracy: 0.9980
Epoch 36/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1330 - sparse_categorical_accuracy: 0.9980
Epoch 37/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1296 - sparse_categorical_accuracy: 0.9980
Epoch 38/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1271 - sparse_categorical_accuracy: 0.9980
Epoch 39/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1222 - sparse_categorical_accuracy: 0.9980
Epoch 40/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1207 - sparse_categorical_accuracy: 0.9980
Epoch 41/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1175 - sparse_categorical_accuracy: 0.9990
Epoch 42/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1155 - sparse_categorical_accuracy: 0.9980
Epoch 43/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1128 - sparse_categorical_accuracy: 0.9980
Epoch 44/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1093 - sparse_categorical_accuracy: 0.9990
Epoch 45/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1064 - sparse_categorical_accuracy: 0.9980
Epoch 46/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1052 - sparse_categorical_accuracy: 0.9980
Epoch 47/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1025 - sparse_categorical_accuracy: 0.9980
Epoch 48/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1009 - sparse_categorical_accuracy: 0.9980
Epoch 49/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0981 - sparse_categorical_accuracy: 0.9980
Epoch 50/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0967 - sparse_categorical_accuracy: 0.9990
Epoch 51/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0940 - sparse_categorical_accuracy: 0.9990
Epoch 52/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0921 - sparse_categorical_accuracy: 0.9980
Epoch 53/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0899 - sparse_categorical_accuracy: 0.9980
Epoch 54/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0892 - sparse_categorical_accuracy: 0.9990
Epoch 55/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0872 - sparse_categorical_accuracy: 0.9980
Epoch 56/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0836 - sparse_categorical_accuracy: 0.9990
Epoch 57/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0832 - sparse_categorical_accuracy: 0.9990
Epoch 58/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0821 - sparse_categorical_accuracy: 0.9990
Epoch 59/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0802 - sparse_categorical_accuracy: 0.9990
Epoch 60/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0783 - sparse_categorical_accuracy: 0.9990
Epoch 61/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0777 - sparse_categorical_accuracy: 1.0000
Epoch 62/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0768 - sparse_categorical_accuracy: 0.9990
Epoch 63/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0747 - sparse_categorical_accuracy: 0.9990
Epoch 64/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0737 - sparse_categorical_accuracy: 1.0000
Epoch 65/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0717 - sparse_categorical_accuracy: 0.9990
Epoch 66/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0706 - sparse_categorical_accuracy: 0.9990
Epoch 67/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9990
Epoch 68/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0685 - sparse_categorical_accuracy: 1.0000
Epoch 69/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0667 - sparse_categorical_accuracy: 1.0000
Epoch 70/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0666 - sparse_categorical_accuracy: 0.9990
Epoch 71/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0646 - sparse_categorical_accuracy: 1.0000
Epoch 72/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0634 - sparse_categorical_accuracy: 0.9980
Epoch 73/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0625 - sparse_categorical_accuracy: 1.0000
Epoch 74/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0614 - sparse_categorical_accuracy: 1.0000
Epoch 75/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0600 - sparse_categorical_accuracy: 1.0000
Epoch 76/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0597 - sparse_categorical_accuracy: 1.0000
Epoch 77/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0582 - sparse_categorical_accuracy: 1.0000
Epoch 78/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0579 - sparse_categorical_accuracy: 1.0000
Epoch 79/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0559 - sparse_categorical_accuracy: 1.0000
Epoch 80/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0556 - sparse_categorical_accuracy: 1.0000
Epoch 81/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0552 - sparse_categorical_accuracy: 1.0000
Epoch 82/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0543 - sparse_categorical_accuracy: 1.0000
Epoch 83/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0536 - sparse_categorical_accuracy: 1.0000
Epoch 84/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0525 - sparse_categorical_accuracy: 0.9990
Epoch 85/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0518 - sparse_categorical_accuracy: 1.0000
Epoch 86/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0506 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0498 - sparse_categorical_accuracy: 0.9990
Epoch 88/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0497 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0489 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0480 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0476 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0460 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0458 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0453 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0450 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0446 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0430 - sparse_categorical_accuracy: 1.0000
Epoch 98/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0428 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0423 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0416 - sparse_categorical_accuracy: 1.0000
<keras.callbacks.History at 0x7f02dc37ab90>

Unsicherheit visualisieren

Berechnen Sie zuerst die prädiktiven Logits und Varianzen.

sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]

Berechnen Sie nun die posteriore Vorhersagewahrscheinlichkeit. Die klassische Methode zur Berechnung der Vorhersagewahrscheinlichkeit eines probabilistischen Modells ist die Verwendung von Monte-Carlo-Sampling, d.

$$E(p(x)) = \frac{1}{M} \sum_{m=1}^M logit_m(x), $$

wobei $ M $ ist die Probengröße und $ logit_m (x) sind $ Stichproben aus dem SNGP posterior MultivariateNormal $ $ ( sngp_logits , sngp_covmat ). Dieser Ansatz kann jedoch für latenzempfindliche Anwendungen wie autonomes Fahren oder Echtzeitgebote langsam sein. Stattdessen kann $ annähert E (p (x)) $ , um die Verwendung von Mean-Field - Methode :

$$E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })$$

wobei $\sigma^2(x)$ die SNGP-Varianz ist und $\lambda$ oft als $\pi/8$ oder $3/\pi^2$ gewählt wird.

sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]

Dieser Mittelwert Feldmethode wird als integrierte Funktion implementiert layers.gaussian_process.mean_field_logits :

def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

SNGP-Zusammenfassung

Setzen Sie alles zusammen. Das gesamte Verfahren (Training, Bewertung und Unsicherheitsberechnung) kann in nur fünf Zeilen durchgeführt werden:

def train_and_test_sngp(train_examples, test_examples):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)

Visualisieren Sie die Klassenwahrscheinlichkeit (links) und die Vorhersageunsicherheit (rechts) des SNGP-Modells.

plot_predictions(sngp_probs, model_name="SNGP")

png

Denken Sie daran, dass im Klassenwahrscheinlichkeitsnetz (links) das Gelb und das Violett Klassenwahrscheinlichkeiten sind. In der Nähe der Trainingsdatendomäne klassifiziert SNGP die Beispiele korrekt mit hoher Zuverlässigkeit (dh Zuweisung einer Wahrscheinlichkeit von nahe 0 oder 1). Weit weg von den Trainingsdaten wird SNGP allmählich weniger sicher und seine Vorhersagewahrscheinlichkeit nähert sich 0,5, während die (normalisierte) Modellunsicherheit auf 1 ansteigt.

Vergleichen Sie dies mit der Unsicherheitsfläche des deterministischen Modells:

plot_predictions(resnet_probs, model_name="Deterministic")

png

Wie bereits erwähnt, ein deterministisches Modell ist nicht entfernungs bewusst. Seine Unsicherheit wird durch den Abstand des Testbeispiels von der Entscheidungsgrenze definiert. Dies führt dazu, dass das Modell zu selbstsichere Vorhersagen für die Beispiele außerhalb der Domäne (rot) erstellt.

Vergleich mit anderen Unsicherheitsansätzen

In diesem Abschnitt wird die Unsicherheit von SNGP mit Monte - Carlo - Dropout und Tief Ensemble .

Beide Verfahren basieren auf der Monte-Carlo-Mittelung mehrerer Vorwärtsdurchläufe deterministischer Modelle. Legen Sie zunächst die Ensemblegröße $M$ fest.

num_ensemble = 10

Monte-Carlo-Aussteiger

Bei einem trainierten neuronalen Netzwerk mit Dropout Schichten, Dropout Monte Carlo berechnet die Durchschnittsvorhersagewahrscheinlichkeit

$$E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))$$

durch Mittelung über mehrere Dropout-aktivierte Vorwärtsdurchgänge ${logit_m(x)}_{m=1}^M$.

def mc_dropout_sampling(test_examples):
  # Enable dropout during inference.
  return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")

png

Tiefes Ensemble

Tief Ensemble ist eine state-of-the-art (aber teure) Methode für tiefe Lernen Unsicherheit. Um ein Deep-Ensemble zu trainieren, trainieren Sie zuerst $M$-Ensemble-Mitglieder.

# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
  resnet_model = DeepResNet(**resnet_config)
  resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)  

  resnet_ensemble.append(resnet_model)

Sammle Logits und berechne die mittlere Vorhersagewahrscheinlichkeit $E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))$.

# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")

png

Sowohl MC Dropout als auch Deep Ensemble verbessern die Unsicherheitsfähigkeit eines Modells, indem sie die Entscheidungsgrenze weniger sicher machen. Beide erben jedoch die Einschränkung des deterministischen Tiefennetzwerks in Bezug auf das fehlende Distanzbewusstsein.

Zusammenfassung

In diesem Tutorial haben Sie:

  • Implementierung eines SNGP-Modells für einen tiefen Klassifikator, um sein Distanzbewusstsein zu verbessern.
  • Ausgebildet das SNGP Modell Ende-zu-Ende unter Verwendung von Keras model.fit() API.
  • Visualisierung des Unsicherheitsverhaltens von SNGP.
  • Verglichen das Unsicherheitsverhalten zwischen SNGP, Monte Carlo Dropout und Deep Ensemble Modellen.

Ressourcen und weiterführende Literatur