Głębokie uczenie ze świadomością niepewności z SNGP

Zobacz na TensorFlow.org Uruchom w Google Colab Zobacz na GitHub Pobierz notatnik

W zastosowaniach AI, które mają kluczowe znaczenie dla bezpieczeństwa (np. podejmowanie decyzji medycznych i autonomiczna jazda) lub gdy dane są z natury zaszumione (np. rozumienie języka naturalnego), ważne jest, aby głęboki klasyfikator mógł wiarygodnie określić swoją niepewność. Głęboki klasyfikator powinien być świadomy swoich własnych ograniczeń i tego, kiedy powinien przekazać kontrolę ludzkim ekspertom. Ten samouczek pokazuje, jak poprawić zdolność głębokiego klasyfikatora do ilościowego określania niepewności przy użyciu techniki zwanej spektralnie znormalizowanym procesem neuronowo-gaussowskim ( SNGP ) .

Główną ideą SNGP jest poprawa świadomości odległości klasyfikatora poprzez zastosowanie prostych modyfikacji w sieci. Świadomość odległości modelu jest miarą tego, jak jego prawdopodobieństwo predykcyjne odzwierciedla odległość między przykładem testowym a danymi uczącymi. Jest to pożądana właściwość, która jest powszechna w modelach probablistycznych o złotym standardzie (np. proces Gaussa z jądrami RBF), ale brakuje jej w modelach z głębokimi sieciami neuronowymi. SNGP zapewnia prosty sposób na wprowadzenie tego zachowania procesu Gaussa do głębokiego klasyfikatora przy jednoczesnym zachowaniu jego dokładności predykcyjnej.

Ten samouczek implementuje model SNGP oparty na głębokiej sieci szczątkowej (ResNet) na zestawie danych dwóch księżyców i porównuje jego powierzchnię niepewności z dwoma innymi popularnymi podejściami do niepewności - zanikami Monte Carlo i Deep ensemble ).

Ten samouczek ilustruje model SNGP na zabawkowym zestawie danych 2D. Aby zapoznać się z przykładem zastosowania SNGP do rzeczywistego zadania rozumienia języka naturalnego przy użyciu bazy BERT, zobacz samouczek SNGP-BERT . Aby uzyskać wysokiej jakości implementacje modelu SNGP (i wielu innych metod niepewności) na szerokiej gamie zestawów danych porównawczych (np. CIFAR-100 , ImageNet , wykrywanie toksyczności Jigsaw itp.), sprawdź benchmark Uncertainty Baselines .

O SNGP

Znormalizowany spektralnie proces neuronowo-gaussowski (SNGP) to proste podejście do poprawy jakości niepewności głębokiego klasyfikatora przy zachowaniu podobnego poziomu dokładności i opóźnienia. Biorąc pod uwagę głęboką sieć rezydualną, SNGP wprowadza dwie proste zmiany w modelu:

  • Stosuje normalizację widmową do ukrytych warstw resztkowych.
  • Zastępuje warstwę wyjściową Dense warstwą procesu Gaussa.

SNGP

W porównaniu z innymi podejściami do niepewności (np. odpadanie Monte Carlo lub zespół Deep), SNGP ma kilka zalet:

  • Działa z szeroką gamą najnowocześniejszych architektur opartych na szczątkach (np. (Wide) ResNet, DenseNet, BERT itp.).
  • Jest to metoda jednomodelowa (tj. nie polega na uśrednianiu zbiorowym). Dlatego SNGP ma podobny poziom opóźnień jak pojedyncza sieć deterministyczna i można go łatwo skalować do dużych zestawów danych, takich jak klasyfikacja ImageNet i Jigsaw Toxic Comments .
  • Ma wysoką skuteczność wykrywania poza domeną ze względu na właściwość rozpoznawania odległości .

Wady tej metody to:

  • Niepewność predykcyjna SNGP jest obliczana przy użyciu przybliżenia Laplace'a . Dlatego teoretycznie niepewność a posteriori SNGP różni się od niepewności dokładnego procesu Gaussa.

  • Trening SNGP wymaga kroku resetu kowariancji na początku nowej epoki. Może to dodać odrobinę dodatkowej złożoności do potoku szkoleniowego. Ten samouczek pokazuje prosty sposób na zaimplementowanie tego za pomocą wywołań zwrotnych Keras.

Ustawiać

pip install --use-deprecated=legacy-resolver tf-models-official
# refresh pkg_resources so it takes the changes into account.
import pkg_resources
import importlib
importlib.reload(pkg_resources)
<module 'pkg_resources' from '/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py'>
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import sklearn.datasets

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as nlp_layers

Zdefiniuj makra wizualizacji

plt.rcParams['figure.dpi'] = 140

DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100

Zbiór danych dwóch księżyców

Utwórz zestawy danych szkoleniowych i ewaluacyjnych na podstawie zestawu danych dwóch księżyców .

def make_training_data(sample_size=500):
  """Create two moon training dataset."""
  train_examples, train_labels = sklearn.datasets.make_moons(
      n_samples=2 * sample_size, noise=0.1)

  # Adjust data position slightly.
  train_examples[train_labels == 0] += [-0.1, 0.2]
  train_examples[train_labels == 1] += [0.1, -0.2]

  return train_examples, train_labels

Oceń predykcyjne zachowanie modelu w całej przestrzeni wejściowej 2D.

def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
  """Create a mesh grid in 2D space."""
  # testing data (mesh grid over data space)
  x = np.linspace(x_range[0], x_range[1], n_grid)
  y = np.linspace(y_range[0], y_range[1], n_grid)
  xv, yv = np.meshgrid(x, y)
  return np.stack([xv.flatten(), yv.flatten()], axis=-1)

Aby ocenić niepewność modelu, dodaj zestaw danych spoza domeny (OOD), który należy do trzeciej klasy. Model nigdy nie widzi tych przykładów OOD podczas treningu.

def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
  return np.random.multivariate_normal(
      means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
    sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)

# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]

plt.figure(figsize=(7, 5.5))

plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)

plt.legend(["Postive", "Negative", "Out-of-Domain"])

plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)

plt.show()

png

Tutaj niebieski i pomarańczowy reprezentują klasy pozytywne i negatywne, a czerwony reprezentuje dane OOD. Oczekuje się, że model, który dobrze mierzy niepewność, będzie pewny, gdy znajduje się w pobliżu danych uczących (tj. \(p(x_{test})\) blisko 0 lub 1) i będzie niepewny, gdy będzie daleko od obszarów danych uczących (tj. \(p(x_{test})\) blisko 0,5 ).

Model deterministyczny

Zdefiniuj model

Zacznij od (bazowego) modelu deterministycznego: wielowarstwowej sieci resztkowej (ResNet) z regularyzacją przerywania.

Ten samouczek wykorzystuje 6-warstwowy ResNet ze 128 ukrytymi jednostkami.

resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
resnet_model.build((None, 2))
resnet_model.summary()
Model: "deep_res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               multiple                  384       
                                                                 
 dense_1 (Dense)             multiple                  16512     
                                                                 
 dense_2 (Dense)             multiple                  16512     
                                                                 
 dense_3 (Dense)             multiple                  16512     
                                                                 
 dense_4 (Dense)             multiple                  16512     
                                                                 
 dense_5 (Dense)             multiple                  16512     
                                                                 
 dense_6 (Dense)             multiple                  16512     
                                                                 
 dense_7 (Dense)             multiple                  258       
                                                                 
=================================================================
Total params: 99,714
Trainable params: 99,330
Non-trainable params: 384
_________________________________________________________________

Model pociągu

Skonfiguruj parametry uczenia, aby używać SparseCategoricalCrossentropy jako funkcji straty i optymalizatora Adama.

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)

Trenuj model przez 100 epok przy wielkości partii 128.

fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 4ms/step - loss: 1.1251 - sparse_categorical_accuracy: 0.5050
Epoch 2/100
8/8 [==============================] - 0s 3ms/step - loss: 0.5538 - sparse_categorical_accuracy: 0.6920
Epoch 3/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2881 - sparse_categorical_accuracy: 0.9160
Epoch 4/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1923 - sparse_categorical_accuracy: 0.9370
Epoch 5/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1550 - sparse_categorical_accuracy: 0.9420
Epoch 6/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1403 - sparse_categorical_accuracy: 0.9450
Epoch 7/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1269 - sparse_categorical_accuracy: 0.9430
Epoch 8/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1208 - sparse_categorical_accuracy: 0.9460
Epoch 9/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1158 - sparse_categorical_accuracy: 0.9510
Epoch 10/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1103 - sparse_categorical_accuracy: 0.9490
Epoch 11/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1051 - sparse_categorical_accuracy: 0.9510
Epoch 12/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1053 - sparse_categorical_accuracy: 0.9510
Epoch 13/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1013 - sparse_categorical_accuracy: 0.9450
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0967 - sparse_categorical_accuracy: 0.9500
Epoch 15/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0991 - sparse_categorical_accuracy: 0.9530
Epoch 16/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0984 - sparse_categorical_accuracy: 0.9500
Epoch 17/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0982 - sparse_categorical_accuracy: 0.9480
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0918 - sparse_categorical_accuracy: 0.9510
Epoch 19/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0903 - sparse_categorical_accuracy: 0.9500
Epoch 20/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0883 - sparse_categorical_accuracy: 0.9510
Epoch 21/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0870 - sparse_categorical_accuracy: 0.9530
Epoch 22/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0884 - sparse_categorical_accuracy: 0.9560
Epoch 23/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0850 - sparse_categorical_accuracy: 0.9540
Epoch 24/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0808 - sparse_categorical_accuracy: 0.9580
Epoch 25/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0773 - sparse_categorical_accuracy: 0.9560
Epoch 26/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0801 - sparse_categorical_accuracy: 0.9590
Epoch 27/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0779 - sparse_categorical_accuracy: 0.9580
Epoch 28/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0807 - sparse_categorical_accuracy: 0.9580
Epoch 29/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0820 - sparse_categorical_accuracy: 0.9570
Epoch 30/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9600
Epoch 31/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0782 - sparse_categorical_accuracy: 0.9590
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0704 - sparse_categorical_accuracy: 0.9600
Epoch 33/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0709 - sparse_categorical_accuracy: 0.9610
Epoch 34/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9580
Epoch 35/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9610
Epoch 36/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9600
Epoch 37/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0675 - sparse_categorical_accuracy: 0.9630
Epoch 38/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0636 - sparse_categorical_accuracy: 0.9690
Epoch 39/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0677 - sparse_categorical_accuracy: 0.9610
Epoch 40/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9650
Epoch 41/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0614 - sparse_categorical_accuracy: 0.9690
Epoch 42/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9680
Epoch 43/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0626 - sparse_categorical_accuracy: 0.9740
Epoch 44/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9760
Epoch 45/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0573 - sparse_categorical_accuracy: 0.9780
Epoch 46/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0568 - sparse_categorical_accuracy: 0.9770
Epoch 47/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0595 - sparse_categorical_accuracy: 0.9780
Epoch 48/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9840
Epoch 49/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0515 - sparse_categorical_accuracy: 0.9820
Epoch 50/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0525 - sparse_categorical_accuracy: 0.9830
Epoch 51/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0507 - sparse_categorical_accuracy: 0.9790
Epoch 52/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0433 - sparse_categorical_accuracy: 0.9850
Epoch 53/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0511 - sparse_categorical_accuracy: 0.9820
Epoch 54/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9820
Epoch 55/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0440 - sparse_categorical_accuracy: 0.9890
Epoch 56/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0438 - sparse_categorical_accuracy: 0.9850
Epoch 57/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0438 - sparse_categorical_accuracy: 0.9880
Epoch 58/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0416 - sparse_categorical_accuracy: 0.9860
Epoch 59/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0479 - sparse_categorical_accuracy: 0.9860
Epoch 60/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0434 - sparse_categorical_accuracy: 0.9860
Epoch 61/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0414 - sparse_categorical_accuracy: 0.9880
Epoch 62/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0402 - sparse_categorical_accuracy: 0.9870
Epoch 63/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0376 - sparse_categorical_accuracy: 0.9890
Epoch 64/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9900
Epoch 65/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0309 - sparse_categorical_accuracy: 0.9910
Epoch 66/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0336 - sparse_categorical_accuracy: 0.9910
Epoch 67/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0389 - sparse_categorical_accuracy: 0.9870
Epoch 68/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9920
Epoch 69/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0331 - sparse_categorical_accuracy: 0.9890
Epoch 70/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0346 - sparse_categorical_accuracy: 0.9900
Epoch 71/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0367 - sparse_categorical_accuracy: 0.9880
Epoch 72/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0283 - sparse_categorical_accuracy: 0.9920
Epoch 73/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9930
Epoch 74/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9900
Epoch 75/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0257 - sparse_categorical_accuracy: 0.9920
Epoch 76/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0289 - sparse_categorical_accuracy: 0.9900
Epoch 77/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0264 - sparse_categorical_accuracy: 0.9900
Epoch 78/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9910
Epoch 79/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0336 - sparse_categorical_accuracy: 0.9880
Epoch 80/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9900
Epoch 81/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9930
Epoch 82/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0279 - sparse_categorical_accuracy: 0.9890
Epoch 83/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0261 - sparse_categorical_accuracy: 0.9920
Epoch 84/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9920
Epoch 85/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9930
Epoch 86/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0219 - sparse_categorical_accuracy: 0.9920
Epoch 87/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9920
Epoch 88/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0215 - sparse_categorical_accuracy: 0.9900
Epoch 89/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9900
Epoch 90/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9950
Epoch 91/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9900
Epoch 92/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0160 - sparse_categorical_accuracy: 0.9940
Epoch 93/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0203 - sparse_categorical_accuracy: 0.9930
Epoch 94/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0203 - sparse_categorical_accuracy: 0.9930
Epoch 95/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0172 - sparse_categorical_accuracy: 0.9960
Epoch 96/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0209 - sparse_categorical_accuracy: 0.9940
Epoch 97/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0179 - sparse_categorical_accuracy: 0.9920
Epoch 98/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9940
Epoch 99/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0165 - sparse_categorical_accuracy: 0.9930
Epoch 100/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0170 - sparse_categorical_accuracy: 0.9950
<keras.callbacks.History at 0x7ff7ac5c8fd0>

Wizualizuj niepewność

Teraz zwizualizuj przewidywania modelu deterministycznego. Najpierw wykreśl prawdopodobieństwo klasy:

\[p(x) = softmax(logit(x))\]

resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0]  # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_probs, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")

plt.show()

png

Na tym wykresie żółty i fioletowy są prawdopodobieństwami predykcyjnymi dla dwóch klas. Model deterministyczny wykonał dobrą robotę przy klasyfikowaniu dwóch znanych klas (niebieskiej i pomarańczowej) z nieliniową granicą decyzyjną. Jednak nie jest świadomy odległości i z pewnością zaklasyfikował nigdy nie widziane przykłady czerwonej domeny poza domeną (OOD) jako klasę pomarańczową.

Wizualizuj niepewność modelu, obliczając wariancję predykcyjną :

\[var(x) = p(x) * (1 - p(x))\]

resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")

plt.show()

png

Na tym wykresie kolor żółty oznacza wysoką niepewność, a kolor fioletowy oznacza niską niepewność. Deterministyczna niepewność ResNet zależy tylko od odległości przykładów testowych od granicy decyzji. Prowadzi to do tego, że model jest zbyt pewny siebie, gdy znajduje się poza domeną szkoleniową. W następnej sekcji pokazano, jak protokół SNGP zachowuje się inaczej w tym zestawie danych.

Model SNGP

Zdefiniuj model SNGP

Zaimplementujmy teraz model SNGP. Oba komponenty SNGP, SpectralNormalization i RandomFeatureGaussianProcess , są dostępne na wbudowanych warstwach tensorflow_model .

SNGP

Przyjrzyjmy się tym dwóm komponentom bardziej szczegółowo. (Możesz również przejść do sekcji Model SNGP, aby zobaczyć, jak zaimplementowany jest pełny model).

Opakowanie normalizacji widmowej

SpectralNormalization to opakowanie warstwy Keras. Można go zastosować do istniejącej warstwy gęstej w następujący sposób:

dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)

Normalizacja widmowa reguluje ukrytą wagę \(W\) , stopniowo kierując jej normę widmową (tj. największą wartość własną \(W\)) w kierunku wartości docelowej norm_multiplier .

Warstwa procesu Gaussa (GP)

RandomFeatureGaussianProcess implementuje aproksymację opartą na funkcjach losowych do modelu procesu Gaussa, który można trenować od końca do końca za pomocą głębokiej sieci neuronowej. Pod maską warstwa procesu Gaussa implementuje sieć dwuwarstwową:

\[logits(x) = \Phi(x) \beta, \quad \Phi(x)=\sqrt{\frac{2}{M} } * cos(Wx + b)\]

Tutaj \(x\) jest danymi wejściowymi, a \(W\) i \(b\) są zamrożonymi wagami zainicjowanymi losowo odpowiednio z rozkładu Gaussa i rozkładu jednolitego. (Dlatego \(\Phi(x)\) są nazywane „funkcjami losowymi”.) \(\beta\) to waga jądra, której można się nauczyć, podobnie jak w przypadku warstwy Dense.

batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
                                               num_inducing=1024,
                                               normalize_input=False,
                                               scale_random_features=True,
                                               gp_cov_momentum=-1)

Główne parametry warstw GP to:

  • units : Wymiar logów wyjściowych.
  • num_inducing : wymiar \(M\) ukrytej wagi \(W\). Domyślnie 1024.
  • normalize_input : Czy zastosować normalizację warstw do danych wejściowych \(x\).
  • scale_random_features : czy zastosować skalę \(\sqrt{2/M}\) do ukrytego wyjścia.
  • gp_cov_momentum kontroluje sposób obliczania kowariancji modelu. Jeśli jest ustawiona na wartość dodatnią (np. 0,999), macierz kowariancji jest obliczana przy użyciu aktualizacji średniej ruchomej opartej na pędzie (podobnie jak normalizacja wsadowa). Jeśli ustawiono na -1, macierz kowariancji jest aktualizowana bez pędu.

Biorąc pod uwagę dane wejściowe wsadowe o kształcie (batch_size, input_dim) , warstwa GP zwraca tensor logits (shape (batch_size, num_classes) ) do przewidywania, a także tensor covmat (shape (batch_size, batch_size) ), który jest macierzą kowariancji a posteriori logi wsadowe.

embedding = tf.random.normal(shape=(batch_size, input_dim))

logits, covmat = gp_layer(embedding)

Teoretycznie możliwe jest rozszerzenie algorytmu o obliczanie różnych wartości wariancji dla różnych klas (jak przedstawiono w oryginalnym artykule SNGP ). Jest to jednak trudne do skalowania do problemów z dużymi przestrzeniami wyjściowymi (np. ImageNet lub modelowanie językowe).

Pełny model SNGP

Biorąc pod uwagę klasę bazową DeepResNet , model SNGP można łatwo zaimplementować, modyfikując warstwy ukryte i wyjściowe sieci szczątkowej. Aby zapewnić zgodność z interfejsem API Keras model.fit() , zmodyfikuj również metodę call() modelu, aby podczas uczenia wyświetlała ona tylko logits .

class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

Użyj tej samej architektury, co model deterministyczny.

resnet_config
{'num_classes': 2, 'num_layers': 6, 'num_hidden': 128}
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
Model: "deep_res_net_sngp"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_9 (Dense)             multiple                  384       
                                                                 
 spectral_normalization_1 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_2 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_3 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_4 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_5 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_6 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 random_feature_gaussian_pro  multiple                 1182722   
 cess (RandomFeatureGaussian                                     
 Process)                                                        
                                                                 
=================================================================
Total params: 1,283,714
Trainable params: 101,120
Non-trainable params: 1,182,594
_________________________________________________________________

Zaimplementuj wywołanie zwrotne Keras, aby zresetować macierz kowariancji na początku nowej epoki.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()

Dodaj to wywołanie zwrotne do klasy modelu DeepResNetSNGP .

class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks", []))
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Model pociągu

Użyj tf.keras.model.fit do trenowania modelu.

sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 2s 5ms/step - loss: 0.6223 - sparse_categorical_accuracy: 0.9570
Epoch 2/100
8/8 [==============================] - 0s 4ms/step - loss: 0.5310 - sparse_categorical_accuracy: 0.9980
Epoch 3/100
8/8 [==============================] - 0s 4ms/step - loss: 0.4766 - sparse_categorical_accuracy: 0.9990
Epoch 4/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4346 - sparse_categorical_accuracy: 0.9980
Epoch 5/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4015 - sparse_categorical_accuracy: 0.9980
Epoch 6/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3757 - sparse_categorical_accuracy: 0.9990
Epoch 7/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3525 - sparse_categorical_accuracy: 0.9990
Epoch 8/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3305 - sparse_categorical_accuracy: 0.9990
Epoch 9/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3144 - sparse_categorical_accuracy: 0.9980
Epoch 10/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2975 - sparse_categorical_accuracy: 0.9990
Epoch 11/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2832 - sparse_categorical_accuracy: 0.9990
Epoch 12/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2707 - sparse_categorical_accuracy: 0.9990
Epoch 13/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2568 - sparse_categorical_accuracy: 0.9990
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2470 - sparse_categorical_accuracy: 0.9970
Epoch 15/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2361 - sparse_categorical_accuracy: 0.9990
Epoch 16/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2271 - sparse_categorical_accuracy: 0.9990
Epoch 17/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2182 - sparse_categorical_accuracy: 0.9990
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2097 - sparse_categorical_accuracy: 0.9990
Epoch 19/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2018 - sparse_categorical_accuracy: 0.9990
Epoch 20/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1940 - sparse_categorical_accuracy: 0.9980
Epoch 21/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1892 - sparse_categorical_accuracy: 0.9990
Epoch 22/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1821 - sparse_categorical_accuracy: 0.9980
Epoch 23/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1768 - sparse_categorical_accuracy: 0.9990
Epoch 24/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1702 - sparse_categorical_accuracy: 0.9980
Epoch 25/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1664 - sparse_categorical_accuracy: 0.9990
Epoch 26/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1604 - sparse_categorical_accuracy: 0.9990
Epoch 27/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1565 - sparse_categorical_accuracy: 0.9990
Epoch 28/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1517 - sparse_categorical_accuracy: 0.9990
Epoch 29/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1469 - sparse_categorical_accuracy: 0.9990
Epoch 30/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1431 - sparse_categorical_accuracy: 0.9980
Epoch 31/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1385 - sparse_categorical_accuracy: 0.9980
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1351 - sparse_categorical_accuracy: 0.9990
Epoch 33/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1312 - sparse_categorical_accuracy: 0.9980
Epoch 34/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1289 - sparse_categorical_accuracy: 0.9990
Epoch 35/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1254 - sparse_categorical_accuracy: 0.9980
Epoch 36/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1223 - sparse_categorical_accuracy: 0.9980
Epoch 37/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1180 - sparse_categorical_accuracy: 0.9990
Epoch 38/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1167 - sparse_categorical_accuracy: 0.9990
Epoch 39/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1132 - sparse_categorical_accuracy: 0.9980
Epoch 40/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1110 - sparse_categorical_accuracy: 0.9990
Epoch 41/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1075 - sparse_categorical_accuracy: 0.9990
Epoch 42/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1067 - sparse_categorical_accuracy: 0.9990
Epoch 43/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1034 - sparse_categorical_accuracy: 0.9990
Epoch 44/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1006 - sparse_categorical_accuracy: 0.9990
Epoch 45/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0991 - sparse_categorical_accuracy: 0.9990
Epoch 46/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0963 - sparse_categorical_accuracy: 0.9990
Epoch 47/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0943 - sparse_categorical_accuracy: 0.9980
Epoch 48/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0925 - sparse_categorical_accuracy: 0.9990
Epoch 49/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0905 - sparse_categorical_accuracy: 0.9990
Epoch 50/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0889 - sparse_categorical_accuracy: 0.9990
Epoch 51/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0863 - sparse_categorical_accuracy: 0.9980
Epoch 52/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0847 - sparse_categorical_accuracy: 0.9990
Epoch 53/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0831 - sparse_categorical_accuracy: 0.9980
Epoch 54/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0818 - sparse_categorical_accuracy: 0.9990
Epoch 55/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0799 - sparse_categorical_accuracy: 0.9990
Epoch 56/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0780 - sparse_categorical_accuracy: 0.9990
Epoch 57/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0768 - sparse_categorical_accuracy: 0.9990
Epoch 58/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0751 - sparse_categorical_accuracy: 0.9990
Epoch 59/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9990
Epoch 60/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0723 - sparse_categorical_accuracy: 0.9990
Epoch 61/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0712 - sparse_categorical_accuracy: 0.9990
Epoch 62/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0701 - sparse_categorical_accuracy: 0.9990
Epoch 63/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0701 - sparse_categorical_accuracy: 0.9990
Epoch 64/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0683 - sparse_categorical_accuracy: 0.9990
Epoch 65/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0665 - sparse_categorical_accuracy: 0.9990
Epoch 66/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0661 - sparse_categorical_accuracy: 0.9990
Epoch 67/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0636 - sparse_categorical_accuracy: 0.9990
Epoch 68/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9990
Epoch 69/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0620 - sparse_categorical_accuracy: 0.9990
Epoch 70/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9990
Epoch 71/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0601 - sparse_categorical_accuracy: 0.9980
Epoch 72/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9990
Epoch 73/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0586 - sparse_categorical_accuracy: 0.9990
Epoch 74/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0574 - sparse_categorical_accuracy: 0.9990
Epoch 75/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0565 - sparse_categorical_accuracy: 1.0000
Epoch 76/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0559 - sparse_categorical_accuracy: 0.9990
Epoch 77/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0549 - sparse_categorical_accuracy: 0.9990
Epoch 78/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0534 - sparse_categorical_accuracy: 1.0000
Epoch 79/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0532 - sparse_categorical_accuracy: 0.9990
Epoch 80/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0519 - sparse_categorical_accuracy: 1.0000
Epoch 81/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0511 - sparse_categorical_accuracy: 1.0000
Epoch 82/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0508 - sparse_categorical_accuracy: 0.9990
Epoch 83/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0499 - sparse_categorical_accuracy: 1.0000
Epoch 84/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0490 - sparse_categorical_accuracy: 1.0000
Epoch 85/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0490 - sparse_categorical_accuracy: 0.9990
Epoch 86/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0470 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 1.0000
Epoch 88/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0453 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0448 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0441 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0434 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0431 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0424 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0420 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0415 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0409 - sparse_categorical_accuracy: 1.0000
Epoch 98/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0401 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0396 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0392 - sparse_categorical_accuracy: 1.0000
<keras.callbacks.History at 0x7ff7ac0f83d0>

Wizualizuj niepewność

Najpierw oblicz predykcyjne logity i wariancje.

sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]

Teraz oblicz prawdopodobieństwo predykcyjne a posteriori. Klasyczną metodą obliczania prawdopodobieństwa predykcyjnego modelu probabilistycznego jest użycie próbkowania Monte Carlo, tj.

\[E(p(x)) = \frac{1}{M} \sum_{m=1}^M logit_m(x), \]

gdzie \(M\) to wielkość próbki, a \(logit_m(x)\) są losowymi próbkami z SNGP a posteriori \(MultivariateNormal\)( sngp_logits , sngp_covmat ). Jednak takie podejście może być powolne w przypadku aplikacji wrażliwych na opóźnienia, takich jak jazda autonomiczna lub licytowanie w czasie rzeczywistym. Zamiast tego można przybliżyć \(E(p(x))\) przy użyciu metody pola średniego :

\[E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })\]

gdzie \(\sigma^2(x)\) jest wariancją SNGP, a \(\lambda\) jest często wybierany jako \(\pi/8\) lub \(3/\pi^2\).

sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]

Ta metoda średniego pola jest zaimplementowana jako wbudowana funkcja layers.gaussian_process.mean_field_logits :

def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

Podsumowanie SNGP

Połącz wszystko razem. Całą procedurę (uczenie, ocenę i obliczanie niepewności) można wykonać w zaledwie pięciu liniach:

def train_and_test_sngp(train_examples, test_examples):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)

Wizualizuj prawdopodobieństwo klasy (po lewej) i niepewność predykcyjną (po prawej) modelu SNGP.

plot_predictions(sngp_probs, model_name="SNGP")

png

Pamiętaj, że na wykresie prawdopodobieństwa klas (po lewej) żółty i fioletowy to prawdopodobieństwa klas. Gdy znajduje się w pobliżu domeny danych uczących, SNGP poprawnie klasyfikuje przykłady z dużą pewnością (tj. przypisując prawdopodobieństwo bliskie 0 lub 1). Gdy znajduje się daleko od danych uczących, SNGP stopniowo staje się mniej pewny, a jego prawdopodobieństwo predykcyjne zbliża się do 0,5, podczas gdy (znormalizowana) niepewność modelu wzrasta do 1.

Porównaj to z powierzchnią niepewności modelu deterministycznego:

plot_predictions(resnet_probs, model_name="Deterministic")

png

Jak wspomniano wcześniej, model deterministyczny nie uwzględnia odległości . Jego niepewność jest określona przez odległość przykładu testowego od granicy decyzyjnej. Prowadzi to do tego, że model generuje zbyt pewne prognozy dla przykładów spoza domeny (kolor czerwony).

Porównanie z innymi podejściami do niepewności

Ta sekcja porównuje niepewność SNGP z dropoutem Monte Carlo i zespołem Deep .

Obie te metody opierają się na uśrednieniu metodą Monte Carlo wielu przejść do przodu modeli deterministycznych. Najpierw ustaw rozmiar zespołu \(M\).

num_ensemble = 10

Rezygnacja z Monte Carlo

Mając wytrenowaną sieć neuronową z warstwami porzucania, porzucenie Monte Carlo oblicza średnie prawdopodobieństwo predykcyjne

\[E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\]

przez uśrednienie z wielu przejść do przodu z włączonym \(\{logit_m(x)\}_{m=1}^M\).

def mc_dropout_sampling(test_examples):
  # Enable dropout during inference.
  return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")

png

Głęboki zespół

Deep Ensemble to najnowocześniejsza (ale droga) metoda głębokiego uczenia się niepewności. Aby wyszkolić zespół Deep, najpierw przeszkol członków zespołu \(M\) .

# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
  resnet_model = DeepResNet(**resnet_config)
  resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)  

  resnet_ensemble.append(resnet_model)

Zbierz logity i oblicz średnie przewidywane prawdopodobieństwo \(E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\).

# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")

png

Zarówno MC Dropout, jak i Deep Ensemble poprawiają zdolność niepewności modelu, czyniąc granicę decyzyjną mniej pewną. Jednak obaj dziedziczą ograniczenie deterministycznej głębokiej sieci w postaci braku świadomości na odległość.

Streszczenie

W tym samouczku masz:

  • Zaimplementowano model SNGP na głębokim klasyfikatorze, aby poprawić jego świadomość odległości.
  • Kompleksowe szkolenie modelu SNGP przy użyciu interfejsu API Keras model.fit() .
  • Wizualizuje zachowanie niepewności protokołu SNGP.
  • Porównano zachowanie niepewności między modelami SNGP, Monte Carlo i głębokimi zestawami.

Zasoby i dalsze czytanie