
Uncertainty-aware Deep Learning with SNGP


In AI applications that are safety-critical (e.g., medical decision making and autonomous driving) or where the data is inherently noisy (e.g., natural language understanding), it is important for a deep classifier to reliably quantify its uncertainty. The deep classifier should be aware of its own limitations and know when to hand control over to human experts. This tutorial shows how to improve a deep classifier's ability to quantify uncertainty using a technique called Spectral-normalized Neural Gaussian Process (SNGP).

The core idea of SNGP is to improve a deep classifier's distance awareness by applying simple modifications to the network. A model's distance awareness is a measure of how well its predictive probability reflects the distance between the test example and the training data. This is a desirable property that is common for gold-standard probabilistic models (e.g., the Gaussian process with RBF kernels) but is lacking in models with deep neural networks. SNGP provides a simple way to inject this Gaussian-process behavior into a deep classifier while maintaining its predictive accuracy.
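For intuition, here is a minimal NumPy sketch of that gold-standard behavior: the exact posterior variance of a Gaussian process with an RBF kernel. Near the training inputs the variance is small; far away it returns to the prior value $k(x, x) = 1$. The toy data, the unit length scale, and the observation-noise variance of 0.1 are arbitrary choices made for illustration.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0):
  """RBF kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * length_scale^2))."""
  sq_dist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
  return np.exp(-sq_dist / (2.0 * length_scale ** 2))


def gp_posterior_variance(x_train, x_test, noise_var=0.1):
  """Exact GP posterior variance at each row of x_test."""
  k_train = rbf_kernel(x_train, x_train) + noise_var * np.eye(len(x_train))
  k_cross = rbf_kernel(x_test, x_train)            # (n_test, n_train)
  alpha = np.linalg.solve(k_train, k_cross.T)      # (n_train, n_test)
  prior_var = 1.0                                  # k(x, x) = 1 for the RBF kernel
  return prior_var - np.sum(k_cross * alpha.T, axis=1)


rng = np.random.default_rng(0)
x_train = rng.normal(size=(200, 2))                # toy training inputs around the origin
var_near = gp_posterior_variance(x_train, np.array([[0.0, 0.0]]))[0]
var_far = gp_posterior_variance(x_train, np.array([[10.0, 10.0]]))[0]
print(f"near: {var_near:.3f}, far: {var_far:.3f}")
```

The far-away point gets almost the full prior variance because its kernel values with every training input are essentially zero; this is the behavior SNGP tries to approximate inside a deep network.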

This tutorial implements a deep residual network (ResNet)-based SNGP model on the two moons dataset, and compares its uncertainty surface with that of two other popular uncertainty approaches: Monte Carlo dropout and Deep ensemble.

This tutorial illustrates the SNGP model on a toy 2D dataset. For an example of applying SNGP to a real-world natural language understanding task using BERT-base, please see the SNGP-BERT tutorial. For high-quality implementations of the SNGP model (and many other uncertainty methods) on a wide variety of benchmark datasets (e.g., CIFAR-100, ImageNet, Jigsaw toxicity detection, etc.), please check out the Uncertainty Baselines benchmark.

About SNGP

Spectral-normalized Neural Gaussian Process (SNGP) is a simple approach to improve a deep classifier's uncertainty quality while maintaining a similar level of accuracy and latency. Given a deep residual network, SNGP makes two simple changes to the model:

  • It applies spectral normalization to the hidden residual layers.
  • It replaces the Dense output layer with a Gaussian process layer.


Compared to other uncertainty approaches (e.g., Monte Carlo dropout or Deep ensemble), SNGP has several advantages:

  • It works for a wide range of state-of-the-art residual-based architectures (e.g., (Wide) ResNet, DenseNet, BERT, etc.).
  • It is a single-model method (i.e., does not rely on ensemble averaging). Therefore SNGP has a similar level of latency as a single deterministic network, and can be scaled easily to large datasets like ImageNet and Jigsaw Toxic Comments classification.
  • It has strong out-of-domain detection performance due to the distance-awareness property.

The downsides of this method are:

  • The predictive uncertainty of an SNGP model is computed using the Laplace approximation. Therefore, in theory, the posterior uncertainty of SNGP differs from that of an exact Gaussian process.

  • SNGP training needs a covariance reset step at the beginning of each new epoch. This adds a small amount of extra complexity to a training pipeline. This tutorial shows a simple way to implement this using Keras callbacks.

Setup

pip uninstall -y tf-nightly keras-nightly
pip install tensorflow
pip install --use-deprecated=legacy-resolver tf-models-official
# refresh pkg_resources so it takes the changes into account.
import pkg_resources
import importlib
importlib.reload(pkg_resources)
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import sklearn.datasets

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as nlp_layers

Define visualization macros

plt.rcParams['figure.dpi'] = 140

DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100

The two moons dataset

Create the training and evaluation datasets from the two moons dataset.

def make_training_data(sample_size=500):
  """Create two moon training dataset."""
  train_examples, train_labels = sklearn.datasets.make_moons(
      n_samples=2 * sample_size, noise=0.1)

  # Adjust data position slightly.
  train_examples[train_labels == 0] += [-0.1, 0.2]
  train_examples[train_labels == 1] += [0.1, -0.2]

  return train_examples, train_labels

Evaluate the model's predictive behavior over the entire 2D input space.

def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
  """Create a mesh grid in 2D space."""
  # testing data (mesh grid over data space)
  x = np.linspace(x_range[0], x_range[1], n_grid)
  y = np.linspace(y_range[0], y_range[1], n_grid)
  xv, yv = np.meshgrid(x, y)
  return np.stack([xv.flatten(), yv.flatten()], axis=-1)

To evaluate model uncertainty, add an out-of-domain (OOD) dataset that belongs to a third class. The model never sees these OOD examples during training.

def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
  return np.random.multivariate_normal(
      means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
    sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)

# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]

plt.figure(figsize=(7, 5.5))

plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)

plt.legend(["Positive", "Negative", "Out-of-Domain"])

plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)

plt.show()


Here the blue and orange represent the positive and negative classes, and the red represents the OOD data. A model that quantifies the uncertainty well is expected to be confident when close to training data (i.e., $p(x_{test})$ close to 0 or 1), and be uncertain when far away from the training data regions (i.e., $p(x_{test})$ close to 0.5).

The deterministic model

Define model

Start from the (baseline) deterministic model: a multi-layer residual network (ResNet) with dropout regularization.

This tutorial uses a 6-layer ResNet with 128 hidden units.
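The cells below instantiate a DeepResNet class that is not defined in this section. As a hedged sketch, the base class could look like the following. It is reconstructed only from what this section shows: the parameter counts in the model summary (a frozen 2-to-128 input projection, six 128-unit hidden layers, a 2-unit output layer), the dropout regularization mentioned above, and the make_dense_layer, make_output_layer, classifier, and classifier_kwargs hooks that DeepResNetSNGP relies on later. The dropout rate of 0.1 and the ReLU activation are assumptions, and because this sketch creates its Dropout layer once in __init__, its summary lists one extra parameter-free layer.

```python
import tensorflow as tf


class DeepResNet(tf.keras.Model):
  """Multi-layer residual network (hypothetical reconstruction of the base class)."""

  def __init__(self, num_classes, num_layers=3, num_hidden=128,
               dropout_rate=0.1, **classifier_kwargs):
    super().__init__()
    self.num_hidden = num_hidden
    self.num_layers = num_layers
    self.dropout_rate = dropout_rate
    # Extra keyword arguments forwarded to the output layer.
    self.classifier_kwargs = classifier_kwargs

    # Frozen projection of the 2D input into the hidden dimension
    # (the 384 non-trainable parameters in the summary).
    self.input_layer = tf.keras.layers.Dense(num_hidden, trainable=False)
    # Residual hidden layers; DeepResNetSNGP overrides make_dense_layer().
    self.dense_layers = [self.make_dense_layer() for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    # Output layer; DeepResNetSNGP overrides make_output_layer().
    self.classifier = self.make_output_layer(num_classes)

  def call(self, inputs, training=None):
    hidden = self.input_layer(inputs)
    for dense_layer in self.dense_layers:
      # Residual connection: hidden <- hidden + dropout(dense(hidden)).
      resid = self.dropout(dense_layer(hidden), training=training)
      hidden = hidden + resid
    return self.classifier(hidden)

  def make_dense_layer(self):
    """Uses a ReLU Dense layer as the hidden layer."""
    return tf.keras.layers.Dense(self.num_hidden, activation="relu")

  def make_output_layer(self, num_classes):
    """Uses a Dense layer producing logits as the output layer."""
    return tf.keras.layers.Dense(num_classes, **self.classifier_kwargs)
```

With num_layers=6 and num_hidden=128, this sketch has 384 + 6 x 16,512 + 258 = 99,714 parameters, matching the summary below.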

resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
resnet_model.build((None, 2))
resnet_model.summary()
Model: "deep_res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  384       
_________________________________________________________________
dense_1 (Dense)              multiple                  16512     
_________________________________________________________________
dense_2 (Dense)              multiple                  16512     
_________________________________________________________________
dense_3 (Dense)              multiple                  16512     
_________________________________________________________________
dense_4 (Dense)              multiple                  16512     
_________________________________________________________________
dense_5 (Dense)              multiple                  16512     
_________________________________________________________________
dense_6 (Dense)              multiple                  16512     
_________________________________________________________________
dense_7 (Dense)              multiple                  258       
=================================================================
Total params: 99,714
Trainable params: 99,330
Non-trainable params: 384
_________________________________________________________________

Train model

Configure the training parameters to use SparseCategoricalCrossentropy as the loss function and the Adam optimizer.

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)

Train the model for 100 epochs with batch size 128.

fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 3ms/step - loss: 0.6322 - sparse_categorical_accuracy: 0.6320
Epoch 2/100
8/8 [==============================] - 0s 3ms/step - loss: 0.3328 - sparse_categorical_accuracy: 0.9070
Epoch 3/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2253 - sparse_categorical_accuracy: 0.9310
Epoch 4/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1783 - sparse_categorical_accuracy: 0.9320
Epoch 5/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1461 - sparse_categorical_accuracy: 0.9450
Epoch 6/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1290 - sparse_categorical_accuracy: 0.9490
Epoch 7/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1252 - sparse_categorical_accuracy: 0.9460
Epoch 8/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1119 - sparse_categorical_accuracy: 0.9550
Epoch 9/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1097 - sparse_categorical_accuracy: 0.9550
Epoch 10/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1037 - sparse_categorical_accuracy: 0.9570
Epoch 11/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0959 - sparse_categorical_accuracy: 0.9550
Epoch 12/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0996 - sparse_categorical_accuracy: 0.9550
Epoch 13/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0933 - sparse_categorical_accuracy: 0.9560
Epoch 14/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0876 - sparse_categorical_accuracy: 0.9590
Epoch 15/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0867 - sparse_categorical_accuracy: 0.9580
Epoch 16/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0805 - sparse_categorical_accuracy: 0.9620
Epoch 17/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0843 - sparse_categorical_accuracy: 0.9610
Epoch 18/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0798 - sparse_categorical_accuracy: 0.9630
Epoch 19/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0784 - sparse_categorical_accuracy: 0.9630
Epoch 20/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9620
Epoch 21/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0765 - sparse_categorical_accuracy: 0.9610
Epoch 22/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0753 - sparse_categorical_accuracy: 0.9610
Epoch 23/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9610
Epoch 24/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0678 - sparse_categorical_accuracy: 0.9650
Epoch 25/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0734 - sparse_categorical_accuracy: 0.9600
Epoch 26/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9640
Epoch 27/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0652 - sparse_categorical_accuracy: 0.9620
Epoch 28/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0665 - sparse_categorical_accuracy: 0.9640
Epoch 29/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9650
Epoch 30/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0654 - sparse_categorical_accuracy: 0.9660
Epoch 31/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0616 - sparse_categorical_accuracy: 0.9680
Epoch 32/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0610 - sparse_categorical_accuracy: 0.9710
Epoch 33/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0645 - sparse_categorical_accuracy: 0.9670
Epoch 34/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0542 - sparse_categorical_accuracy: 0.9770
Epoch 35/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0599 - sparse_categorical_accuracy: 0.9740
Epoch 36/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0543 - sparse_categorical_accuracy: 0.9770
Epoch 37/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0570 - sparse_categorical_accuracy: 0.9750
Epoch 38/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0527 - sparse_categorical_accuracy: 0.9800
Epoch 39/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0534 - sparse_categorical_accuracy: 0.9760
Epoch 40/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0508 - sparse_categorical_accuracy: 0.9830
Epoch 41/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9820
Epoch 42/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0503 - sparse_categorical_accuracy: 0.9830
Epoch 43/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0430 - sparse_categorical_accuracy: 0.9850
Epoch 44/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0479 - sparse_categorical_accuracy: 0.9870
Epoch 45/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0485 - sparse_categorical_accuracy: 0.9880
Epoch 46/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9850
Epoch 47/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0428 - sparse_categorical_accuracy: 0.9880
Epoch 48/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0426 - sparse_categorical_accuracy: 0.9860
Epoch 49/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0447 - sparse_categorical_accuracy: 0.9870
Epoch 50/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0417 - sparse_categorical_accuracy: 0.9880
Epoch 51/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0406 - sparse_categorical_accuracy: 0.9850
Epoch 52/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9890
Epoch 53/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0357 - sparse_categorical_accuracy: 0.9880
Epoch 54/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0346 - sparse_categorical_accuracy: 0.9900
Epoch 55/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0364 - sparse_categorical_accuracy: 0.9910
Epoch 56/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0368 - sparse_categorical_accuracy: 0.9910
Epoch 57/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0338 - sparse_categorical_accuracy: 0.9890
Epoch 58/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0329 - sparse_categorical_accuracy: 0.9910
Epoch 59/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0320 - sparse_categorical_accuracy: 0.9920
Epoch 60/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0358 - sparse_categorical_accuracy: 0.9900
Epoch 61/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0375 - sparse_categorical_accuracy: 0.9900
Epoch 62/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0305 - sparse_categorical_accuracy: 0.9900
Epoch 63/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9940
Epoch 64/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0297 - sparse_categorical_accuracy: 0.9930
Epoch 65/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0288 - sparse_categorical_accuracy: 0.9920
Epoch 66/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0303 - sparse_categorical_accuracy: 0.9880
Epoch 67/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0262 - sparse_categorical_accuracy: 0.9960
Epoch 68/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9940
Epoch 69/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0290 - sparse_categorical_accuracy: 0.9910
Epoch 70/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0255 - sparse_categorical_accuracy: 0.9920
Epoch 71/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0262 - sparse_categorical_accuracy: 0.9920
Epoch 72/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0246 - sparse_categorical_accuracy: 0.9950
Epoch 73/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0296 - sparse_categorical_accuracy: 0.9920
Epoch 74/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0228 - sparse_categorical_accuracy: 0.9930
Epoch 75/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9940
Epoch 76/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9920
Epoch 77/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9910
Epoch 78/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9920
Epoch 79/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0207 - sparse_categorical_accuracy: 0.9970
Epoch 80/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9940
Epoch 81/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0232 - sparse_categorical_accuracy: 0.9930
Epoch 82/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0194 - sparse_categorical_accuracy: 0.9940
Epoch 83/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9930
Epoch 84/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0170 - sparse_categorical_accuracy: 0.9940
Epoch 85/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0163 - sparse_categorical_accuracy: 0.9950
Epoch 86/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9960
Epoch 87/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0161 - sparse_categorical_accuracy: 0.9960
Epoch 88/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0247 - sparse_categorical_accuracy: 0.9930
Epoch 89/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9970
Epoch 90/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9960
Epoch 91/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0157 - sparse_categorical_accuracy: 0.9960
Epoch 92/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0187 - sparse_categorical_accuracy: 0.9930
Epoch 93/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0165 - sparse_categorical_accuracy: 0.9950
Epoch 94/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0172 - sparse_categorical_accuracy: 0.9960
Epoch 95/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9950
Epoch 96/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9940
Epoch 97/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9960
Epoch 98/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0155 - sparse_categorical_accuracy: 0.9970
Epoch 99/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9940
Epoch 100/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9960
<keras.callbacks.History at 0x7efe887ab810>

Visualize uncertainty

Now visualize the predictions of the deterministic model. First plot the class probability:

$$p(x) = \mathrm{softmax}(\mathrm{logit}(x))$$

resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0]  # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_probs, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")

plt.show()


In this plot, the yellow and purple are the predictive probabilities for the two classes. The deterministic model did a good job of classifying the two known classes (blue and orange) with a nonlinear decision boundary. However, it is not distance-aware, and it confidently classified the never-seen red out-of-domain (OOD) examples as the orange class.

Visualize the model uncertainty by computing the predictive variance:

$$\mathrm{var}(x) = p(x)\,(1 - p(x))$$

resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")

plt.show()


In this plot, the yellow indicates high uncertainty, and the purple indicates low uncertainty. A deterministic ResNet's uncertainty depends only on the test examples' distance from the decision boundary. This leads the model to be over-confident when out of the training domain. The next section shows how SNGP behaves differently on this dataset.
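A toy illustration of this point, using a hypothetical two-class linear model as a stand-in for the network's final layer: its logits depend on the input only through the first coordinate, so a test point 100 units away from the (assumed) training region gets exactly the same class probability and the same variance p(x)(1 - p(x)) as a nearby point at the same distance from the decision boundary.

```python
import numpy as np


def softmax(logits):
  """Numerically stable softmax over the last axis."""
  shifted = logits - logits.max(axis=-1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical linear "last layer": logits = x @ w. Only the first input
# coordinate matters, so the decision boundary is the line x0 = 0.
w = np.array([[2.0, -2.0],
              [0.0, 0.0]])

x_near = np.array([[0.5, 0.0]])    # assumed to lie near the training data
x_far = np.array([[0.5, 100.0]])   # far from the data, same distance to the boundary

p_near = softmax(x_near @ w)[:, 0]   # probability for class 0
p_far = softmax(x_far @ w)[:, 0]
var_near = p_near * (1 - p_near)     # predictive variance p(x) * (1 - p(x))
var_far = p_far * (1 - p_far)
print(p_near, var_near, p_far, var_far)
```

Nothing in this model's output changes as the input moves parallel to the boundary, which is exactly the over-confidence seen on the OOD cluster.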

The SNGP model

Define SNGP model

Let's now implement the SNGP model. Both SNGP components, SpectralNormalization and RandomFeatureGaussianProcess, are available among the built-in layers of the TensorFlow Model Garden (official.nlp.modeling.layers, imported above as nlp_layers).


Let's look at these two components in more detail. (You can also jump to The full SNGP model section to see how the full model is implemented.)

Spectral Normalization wrapper

SpectralNormalization is a Keras layer wrapper. It can be applied to an existing Dense layer like this:

dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)

Spectral normalization regularizes the hidden weight $W$ by gradually guiding its spectral norm (i.e., the largest singular value of $W$) toward the target value norm_multiplier.
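The idea can be sketched in NumPy. Power iteration estimates the largest singular value $\sigma$ of $W$, and the weight is rescaled by norm_multiplier / $\sigma$ whenever $\sigma$ exceeds the bound. This is an illustrative sketch, not the Model Garden code: treat the iteration count and the choice to rescale only when $\sigma$ exceeds norm_multiplier as assumptions about how SpectralNormalization behaves.

```python
import numpy as np


def estimate_spectral_norm(w, num_iters=50, seed=0):
  """Estimates the largest singular value of w with power iteration."""
  rng = np.random.default_rng(seed)
  v = rng.normal(size=w.shape[1])        # right singular vector estimate
  for _ in range(num_iters):
    u = w @ v
    u /= np.linalg.norm(u)               # left singular vector estimate
    v = w.T @ u
    v /= np.linalg.norm(v)
  return float(u @ w @ v)


def spectral_normalize(w, norm_multiplier=0.9, num_iters=50):
  """Rescales w so that its spectral norm is at most norm_multiplier."""
  sigma = estimate_spectral_norm(w, num_iters)
  if sigma > norm_multiplier:
    w = w * (norm_multiplier / sigma)
  return w


w = np.diag([5.0, 2.0, 1.0])             # spectral norm is 5
w_sn = spectral_normalize(w, norm_multiplier=0.9)
print(np.linalg.svd(w_sn, compute_uv=False))
```

Bounding the spectral norm of every residual branch keeps the hidden mapping from stretching distances arbitrarily, which is what lets distances in the input space survive into the GP layer's input.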

The Gaussian Process (GP) layer

RandomFeatureGaussianProcess implements a random-feature based approximation to a Gaussian process model that is end-to-end trainable with a deep neural network. Under the hood, the Gaussian process layer implements a two-layer network:

$$\mathrm{logits}(x) = \Phi(x)\,\beta, \quad \Phi(x) = \sqrt{\frac{2}{M}}\,\cos(Wx + b)$$

Here $x$ is the input, and $W$ and $b$ are frozen weights initialized randomly from Gaussian and uniform distributions, respectively. (Therefore $\Phi(x)$ are called "random features".) $\beta$ is the learnable kernel weight similar to that of a Dense layer.
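The random-feature construction can be checked numerically. In this NumPy sketch (not the Model Garden implementation), with $W$ drawn from a standard Gaussian and $b$ uniform on $[0, 2\pi)$, the inner product $\Phi(x)^\top \Phi(x')$ approximates the RBF kernel $\exp(-\lVert x - x' \rVert^2 / 2)$ as $M$ grows. The unit length scale and $M = 20000$ are illustrative assumptions, and the code writes $xW$ with $W$ of shape (d, M), which is the row-vector form of $Wx$.

```python
import numpy as np


def random_features(x, num_features, seed=0):
  """Random Fourier features Phi(x) = sqrt(2 / M) * cos(x @ W + b).

  W ~ N(0, 1) with shape (d, M) and b ~ Uniform(0, 2*pi) with shape (M,) are
  sampled once from a fixed seed and never trained ("frozen").
  """
  rng = np.random.default_rng(seed)
  w = rng.normal(size=(x.shape[1], num_features))
  b = rng.uniform(0.0, 2.0 * np.pi, size=num_features)
  return np.sqrt(2.0 / num_features) * np.cos(x @ w + b)


x = np.array([[0.0, 0.0], [0.5, 0.5], [2.0, -1.0]])
phi = random_features(x, num_features=20000)

approx_kernel = phi @ phi.T
sq_dist = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
exact_kernel = np.exp(-0.5 * sq_dist)    # RBF kernel with unit length scale
print(np.round(approx_kernel, 2))
print(np.round(exact_kernel, 2))
```

Because only $\beta$ is learned on top of these frozen features, the output layer behaves like a linear model in a fixed kernel feature space, which is what makes a closed-form posterior covariance possible.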

batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
                                               num_inducing=1024,
                                               normalize_input=False,
                                               scale_random_features=True,
                                               gp_cov_momentum=-1)

The main parameters of the GP layer are:

  • units: The dimension of the output logits.
  • num_inducing: The dimension $M$ of the hidden weight $W$. Defaults to 1024.
  • normalize_input: Whether to apply layer normalization to the input $x$.
  • scale_random_features: Whether to apply the scale $\sqrt{2/M}$ to the hidden output.
  • gp_cov_momentum: Controls how the model covariance is computed. If set to a positive value (e.g., 0.999), the covariance matrix is computed using a momentum-based moving average update (similar to batch normalization). If set to -1, the covariance matrix is updated without momentum.
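A hedged NumPy sketch of what gp_cov_momentum changes. It assumes the layer maintains a precision matrix (the inverse of the feature covariance), initialized to the identity and updated from each batch's random features $\Phi$ as below; the exact update rule, including the division by batch size in the momentum branch, is an assumption about the layer's internals, and update_precision is a hypothetical name.

```python
import numpy as np


def update_precision(precision, phi_batch, momentum):
  """One hypothetical precision-matrix update from a batch of random features.

  momentum > 0: moving average of the per-example average Phi^T Phi.
  momentum == -1: exact running sum of Phi^T Phi over every batch seen.
  """
  batch_outer = phi_batch.T @ phi_batch
  if momentum > 0:
    batch_outer = batch_outer / phi_batch.shape[0]
    return momentum * precision + (1.0 - momentum) * batch_outer
  return precision + batch_outer


rng = np.random.default_rng(0)
num_features = 8
batches = [rng.normal(size=(16, num_features)) for _ in range(4)]

precision_exact = np.eye(num_features)
precision_ema = np.eye(num_features)
for phi_batch in batches:
  precision_exact = update_precision(precision_exact, phi_batch, momentum=-1)
  precision_ema = update_precision(precision_ema, phi_batch, momentum=0.999)

# With momentum = -1 the result equals the identity plus Phi^T Phi over all data.
all_phi = np.concatenate(batches)
print(np.allclose(precision_exact, np.eye(num_features) + all_phi.T @ all_phi))
```

The exact branch treats every batch as new data, so it is only correct if each example is counted once; the moving-average branch instead forgets old batches gradually.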

Given a batch input with shape (batch_size, input_dim), the GP layer returns a logits tensor (shape (batch_size, num_classes)) for prediction, and also a covmat tensor (shape (batch_size, batch_size)), which is the posterior covariance matrix of the batch logits.

embedding = tf.random.normal(shape=(batch_size, input_dim))

logits, covmat = gp_layer(embedding)
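To see where covmat comes from and why it is distance-aware, here is a NumPy sketch of a Laplace-style posterior over the output weights $\beta$. It assumes a Gaussian likelihood (so the precision is $\lambda I + \Phi^\top \Phi$ with a ridge penalty $\lambda = 1$) and the exact, momentum-free accumulation; the helper names are hypothetical and the Model Garden layer may differ in details such as input normalization. The posterior covariance of the batch logits is then $\Phi_{test}\,\lambda(\lambda I + \Phi^\top \Phi)^{-1}\,\Phi_{test}^\top$, a (batch_size, batch_size) matrix.

```python
import numpy as np


def random_features(x, w, b):
  """Phi(x) = sqrt(2 / M) * cos(x @ W + b) with frozen W (d, M) and b (M,)."""
  return np.sqrt(2.0 / w.shape[1]) * np.cos(x @ w + b)


rng = np.random.default_rng(0)
input_dim, num_features = 2, 512
w = rng.normal(size=(input_dim, num_features))
b = rng.uniform(0.0, 2.0 * np.pi, size=num_features)

# Toy training inputs clustered around the origin.
x_train = rng.normal(scale=0.5, size=(200, input_dim))
phi_train = random_features(x_train, w, b)

# Laplace-style posterior over the output weights, assuming a Gaussian
# likelihood and a ridge penalty of 1.0: precision = ridge * I + Phi^T Phi.
ridge = 1.0
precision = ridge * np.eye(num_features) + phi_train.T @ phi_train
feature_cov = ridge * np.linalg.inv(precision)

# Posterior covariance of the logits for a test batch (batch_size x batch_size).
x_test = np.array([[0.0, 0.0], [6.0, 6.0]])   # in-domain point, far-away point
phi_test = random_features(x_test, w, b)
covmat = phi_test @ feature_cov @ phi_test.T
print(covmat.shape, np.diag(covmat))
```

The diagonal of covmat stays small for the in-domain point and approaches the prior value of about 1 for the far-away point, because the latter's features are nearly orthogonal to everything seen in training.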

Theoretically, it is possible to extend the algorithm to compute different variance values for different classes (as introduced in the original SNGP paper). However, this is difficult to scale to problems with large output spaces (e.g., ImageNet or language modeling).

The full SNGP model

Given the base class DeepResNet, the SNGP model can be implemented easily by modifying the residual network's hidden and output layers. For compatibility with the Keras Model.fit() API, also modify the model's call() method so that it only outputs logits during training.

class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

Use the same architecture as the deterministic model.

resnet_config
{'num_classes': 2, 'num_layers': 6, 'num_hidden': 128}
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
Model: "deep_res_net_sngp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              multiple                  384       
_________________________________________________________________
spectral_normalization_1 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_2 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_3 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_4 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_5 (Sp multiple                  16768     
_________________________________________________________________
spectral_normalization_6 (Sp multiple                  16768     
_________________________________________________________________
random_feature_gaussian_proc multiple                  1182722   
=================================================================
Total params: 1,283,714
Trainable params: 101,120
Non-trainable params: 1,182,594
_________________________________________________________________

Implement a Keras callback to reset the covariance matrix at the beginning of a new epoch.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the beginning of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
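Why the reset matters can be seen with a toy accumulator. Under the exact (gp_cov_momentum=-1) update, each pass over the data adds another copy of $\Phi^\top \Phi$ to the precision matrix, so without a reset the model behaves as if it had seen the training set once per epoch and its posterior variance keeps shrinking; in real training it would also mix in features computed with stale hidden-layer weights. The PrecisionAccumulator class below is a hypothetical NumPy stand-in, not the Model Garden layer; its reset() plays the role of reset_covariance_matrix(), and, like the callback, run_epochs skips the reset at epoch 0.

```python
import numpy as np


class PrecisionAccumulator:
  """Hypothetical stand-in for a GP layer's exact (momentum = -1) precision matrix."""

  def __init__(self, num_features, ridge=1.0):
    self.num_features = num_features
    self.ridge = ridge
    self.reset()

  def reset(self):
    """Plays the role of reset_covariance_matrix(): back to the ridge prior."""
    self.precision = self.ridge * np.eye(self.num_features)

  def update(self, phi_batch):
    """Adds one batch's Phi^T Phi (exact accumulation, no momentum)."""
    self.precision = self.precision + phi_batch.T @ phi_batch

  def variance(self, phi):
    """Posterior variance ridge * phi^T P^{-1} phi for one feature vector."""
    return float(self.ridge * phi @ np.linalg.solve(self.precision, phi))


def run_epochs(accumulator, batches, num_epochs, reset_each_epoch):
  """Feeds the same batches for several epochs, optionally resetting like the callback."""
  for epoch in range(num_epochs):
    if reset_each_epoch and epoch > 0:
      accumulator.reset()
    for phi_batch in batches:
      accumulator.update(phi_batch)


rng = np.random.default_rng(0)
batches = [rng.normal(size=(32, 16)) for _ in range(8)]   # one epoch of toy features
phi_query = rng.normal(size=16)

with_reset = PrecisionAccumulator(num_features=16)
run_epochs(with_reset, batches, num_epochs=5, reset_each_epoch=True)
without_reset = PrecisionAccumulator(num_features=16)
run_epochs(without_reset, batches, num_epochs=5, reset_each_epoch=False)

print(with_reset.variance(phi_query), without_reset.variance(phi_query))
```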

Add this callback to the DeepResNetSNGP model class.

class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks") or [])
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Train model

Use tf.keras.Model.fit to train the model.

sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 5ms/step - loss: 0.6257 - sparse_categorical_accuracy: 0.9565
Epoch 2/100
8/8 [==============================] - 0s 4ms/step - loss: 0.5258 - sparse_categorical_accuracy: 0.9980
Epoch 3/100
8/8 [==============================] - 0s 4ms/step - loss: 0.4710 - sparse_categorical_accuracy: 0.9980
Epoch 4/100
8/8 [==============================] - 0s 4ms/step - loss: 0.4303 - sparse_categorical_accuracy: 0.9980
Epoch 5/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3978 - sparse_categorical_accuracy: 0.9990
Epoch 6/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3695 - sparse_categorical_accuracy: 0.9990
Epoch 7/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3449 - sparse_categorical_accuracy: 0.9990
Epoch 8/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3238 - sparse_categorical_accuracy: 0.9990
Epoch 9/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3064 - sparse_categorical_accuracy: 0.9990
Epoch 10/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2884 - sparse_categorical_accuracy: 0.9990
Epoch 11/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2725 - sparse_categorical_accuracy: 0.9980
Epoch 12/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2582 - sparse_categorical_accuracy: 0.9980
Epoch 13/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2467 - sparse_categorical_accuracy: 0.9980
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2347 - sparse_categorical_accuracy: 0.9970
Epoch 15/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2241 - sparse_categorical_accuracy: 0.9980
Epoch 16/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2137 - sparse_categorical_accuracy: 0.9990
Epoch 17/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2048 - sparse_categorical_accuracy: 0.9980
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1973 - sparse_categorical_accuracy: 0.9980
Epoch 19/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1882 - sparse_categorical_accuracy: 0.9990
Epoch 20/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1824 - sparse_categorical_accuracy: 0.9970
Epoch 21/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1753 - sparse_categorical_accuracy: 0.9990
Epoch 22/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1694 - sparse_categorical_accuracy: 0.9980
Epoch 23/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1639 - sparse_categorical_accuracy: 0.9980
Epoch 24/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1580 - sparse_categorical_accuracy: 0.9990
Epoch 25/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1530 - sparse_categorical_accuracy: 0.9990
Epoch 26/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1479 - sparse_categorical_accuracy: 0.9980
Epoch 27/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1429 - sparse_categorical_accuracy: 0.9990
Epoch 28/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1383 - sparse_categorical_accuracy: 0.9990
Epoch 29/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1350 - sparse_categorical_accuracy: 0.9990
Epoch 30/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1311 - sparse_categorical_accuracy: 0.9980
Epoch 31/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1274 - sparse_categorical_accuracy: 0.9990
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1233 - sparse_categorical_accuracy: 0.9990
Epoch 33/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1196 - sparse_categorical_accuracy: 0.9990
Epoch 34/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1172 - sparse_categorical_accuracy: 0.9990
Epoch 35/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1135 - sparse_categorical_accuracy: 0.9990
Epoch 36/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1103 - sparse_categorical_accuracy: 0.9990
Epoch 37/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1073 - sparse_categorical_accuracy: 0.9990
Epoch 38/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1044 - sparse_categorical_accuracy: 0.9990
Epoch 39/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1022 - sparse_categorical_accuracy: 0.9990
Epoch 40/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0987 - sparse_categorical_accuracy: 0.9990
Epoch 41/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0970 - sparse_categorical_accuracy: 0.9990
Epoch 42/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0946 - sparse_categorical_accuracy: 0.9990
Epoch 43/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0925 - sparse_categorical_accuracy: 0.9990
Epoch 44/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0897 - sparse_categorical_accuracy: 0.9990
Epoch 45/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9990
Epoch 46/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0860 - sparse_categorical_accuracy: 0.9990
Epoch 47/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0834 - sparse_categorical_accuracy: 0.9990
Epoch 48/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0824 - sparse_categorical_accuracy: 0.9990
Epoch 49/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0805 - sparse_categorical_accuracy: 0.9990
Epoch 50/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0796 - sparse_categorical_accuracy: 0.9990
Epoch 51/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0775 - sparse_categorical_accuracy: 0.9990
Epoch 52/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0752 - sparse_categorical_accuracy: 0.9990
Epoch 53/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0738 - sparse_categorical_accuracy: 0.9990
Epoch 54/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0726 - sparse_categorical_accuracy: 0.9990
Epoch 55/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0706 - sparse_categorical_accuracy: 0.9990
Epoch 56/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9990
Epoch 57/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0677 - sparse_categorical_accuracy: 1.0000
Epoch 58/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0664 - sparse_categorical_accuracy: 0.9990
Epoch 59/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0655 - sparse_categorical_accuracy: 1.0000
Epoch 60/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0642 - sparse_categorical_accuracy: 0.9990
Epoch 61/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0628 - sparse_categorical_accuracy: 1.0000
Epoch 62/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0616 - sparse_categorical_accuracy: 0.9990
Epoch 63/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0607 - sparse_categorical_accuracy: 0.9990
Epoch 64/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0594 - sparse_categorical_accuracy: 0.9990
Epoch 65/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0582 - sparse_categorical_accuracy: 1.0000
Epoch 66/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0568 - sparse_categorical_accuracy: 1.0000
Epoch 67/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0558 - sparse_categorical_accuracy: 1.0000
Epoch 68/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0548 - sparse_categorical_accuracy: 1.0000
Epoch 69/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0541 - sparse_categorical_accuracy: 0.9990
Epoch 70/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0529 - sparse_categorical_accuracy: 1.0000
Epoch 71/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0519 - sparse_categorical_accuracy: 1.0000
Epoch 72/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0511 - sparse_categorical_accuracy: 1.0000
Epoch 73/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0505 - sparse_categorical_accuracy: 0.9990
Epoch 74/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0492 - sparse_categorical_accuracy: 1.0000
Epoch 75/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0491 - sparse_categorical_accuracy: 1.0000
Epoch 76/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0477 - sparse_categorical_accuracy: 1.0000
Epoch 77/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0471 - sparse_categorical_accuracy: 1.0000
Epoch 78/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0467 - sparse_categorical_accuracy: 0.9990
Epoch 79/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0458 - sparse_categorical_accuracy: 1.0000
Epoch 80/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0446 - sparse_categorical_accuracy: 1.0000
Epoch 81/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0441 - sparse_categorical_accuracy: 1.0000
Epoch 82/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0436 - sparse_categorical_accuracy: 1.0000
Epoch 83/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0431 - sparse_categorical_accuracy: 1.0000
Epoch 84/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0421 - sparse_categorical_accuracy: 1.0000
Epoch 85/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0419 - sparse_categorical_accuracy: 1.0000
Epoch 86/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0408 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0402 - sparse_categorical_accuracy: 1.0000
Epoch 88/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0396 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0392 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0388 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0381 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0375 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0370 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0365 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0359 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0349 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0350 - sparse_categorical_accuracy: 0.9990
Epoch 98/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0343 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0342 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0337 - sparse_categorical_accuracy: 1.0000
<keras.callbacks.History at 0x7efe8808dc90>

Visualize uncertainty

First compute the predictive logits and variances.

sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]

Now compute the posterior predictive probability. The classic method for computing the predictive probability of a probabilistic model is to use Monte Carlo sampling, i.e.,

$$E(p(x)) = \frac{1}{M} \sum_{m=1}^M softmax(logit_m(x)), $$

where $M$ is the sample size, and $logit_m(x)$ are random samples from the SNGP posterior $MultivariateNormal$(sngp_logits, sngp_covmat). However, this approach can be slow for latency-sensitive applications such as autonomous driving or real-time bidding. Instead, you can approximate $E(p(x))$ using the mean-field method:

$$E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })$$

where $\sigma^2(x)$ is the SNGP variance, and $\lambda$ is often chosen as $\pi/8$ or $3/\pi^2$.

sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]
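For comparison, the Monte Carlo estimator that the mean-field formula replaces can be sketched in NumPy as below. This is an illustrative sketch, not the tutorial's code: it assumes, as the mean-field computation above does, that each example has a single posterior variance $\sigma^2(x)$ (the diagonal of sngp_covmat) shared by all class logits. The helper names softmax, mc_predictive_probs and mean_field_probs and the example logits and variances are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax.
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def mc_predictive_probs(logits, variances, num_samples=10000, seed=0):
    """Monte Carlo estimate of E[softmax(logit(x))] (hypothetical helper).

    logits:    (batch, num_classes) posterior mean logits.
    variances: (batch, 1) posterior variance per example, assumed to be
               shared by all classes.
    """
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal((num_samples,) + logits.shape)
    # Draw logit_m(x) ~ Normal(logit(x), sigma^2(x)) for m = 1..M.
    samples = logits[None] + np.sqrt(variances)[None] * noise
    # Average the per-sample class probabilities over the M samples.
    return softmax(samples, axis=-1).mean(axis=0)


def mean_field_probs(logits, variances, lambda_param=np.pi / 8.):
    # Mean-field approximation: a single softmax of variance-rescaled logits.
    return softmax(logits / np.sqrt(1. + lambda_param * variances), axis=-1)


logits = np.array([[3., -3.], [3., -3.]])
variances = np.array([[0.], [25.]])  # a confident and an uncertain example
mc = mc_predictive_probs(logits, variances)
mf = mean_field_probs(logits, variances)
print("Monte Carlo:", mc[:, 0], "mean-field:", mf[:, 0])
```

Both estimates pull the class probability of the high-variance example toward 0.5. The Monte Carlo version draws $M$ samples per example, while the mean-field version needs only one softmax call, which is the latency advantage described above.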

This mean-field method is implemented as the built-in function nlp_layers.gaussian_process.mean_field_logits:

def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)
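To make the built-in's behavior concrete, here is a NumPy sketch of the adjustment it performs when given a full covariance matrix. It is written under the assumption, consistent with the manual computation earlier in this section, that the function reads the per-example variances off the diagonal of the covariance matrix and divides every class logit by $\sqrt{1 + \lambda \sigma^2(x)}$. The name mean_field_logits_np and the example values are hypothetical; this is not the library source.

```python
import numpy as np


def mean_field_logits_np(logits, covmat, mean_field_factor=np.pi / 8.):
    """Hypothetical NumPy mirror of the mean-field logit adjustment.

    logits: (batch, num_classes) posterior mean logits.
    covmat: (batch, batch) posterior covariance; only its diagonal is used.
    """
    variances = np.diag(covmat)                          # sigma^2(x), shape (batch,)
    scale = np.sqrt(1. + mean_field_factor * variances)  # one factor per example
    return logits / scale[:, None]                       # broadcast over classes


logits = np.array([[2., -1.], [0.5, 0.5], [-4., 4.]])
covmat = np.array([[0.1, 0.0, 0.0],
                   [0.0, 1.0, 0.2],
                   [0.0, 0.2, 9.0]])
adjusted = mean_field_logits_np(logits, covmat)
# Same as the manual computation: logits / sqrt(1 + lambda * diag(covmat)).
manual = logits / np.sqrt(1. + (np.pi / 8.) * np.diag(covmat)[:, None])
```

Because the adjustment divides each example's logits by a positive factor, it never changes the predicted class; it only moves the probabilities toward uniform as the variance grows.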

SNGP Summary

Put everything together. The entire procedure (training, evaluation and uncertainty computation) can be done in just five lines:

def train_and_test_sngp(train_examples, test_examples):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)

Visualize the class probability (left) and the predictive uncertainty (right) of the SNGP model.

plot_predictions(sngp_probs, model_name="SNGP")


Remember that in the class probability plot (left), yellow and purple indicate the probabilities of the two classes. Close to the training data domain, SNGP correctly classifies the examples with high confidence (i.e., assigning near 0 or 1 probability). Far away from the training data, SNGP gradually becomes less confident: its predictive probability approaches 0.5 while the (normalized) model uncertainty rises to 1.
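The plotting helper plot_predictions is defined earlier in the tutorial, so its exact uncertainty formula is not shown in this section. As an assumption for illustration only, one simple two-class score that behaves as described, equal to 0 at probabilities of 0 or 1 and reaching 1 at a probability of 0.5, is the Bernoulli variance $p(1-p)$ divided by its maximum value $0.25$. The function name normalized_uncertainty is hypothetical.

```python
import numpy as np


def normalized_uncertainty(probs):
    """Hypothetical two-class uncertainty score in [0, 1].

    probs: probability of one class, any shape. p * (1 - p) peaks at 0.25
    when p = 0.5, so dividing by 0.25 rescales the score to [0, 1].
    """
    probs = np.asarray(probs, dtype=float)
    return probs * (1. - probs) / 0.25


scores = normalized_uncertainty([0.0, 0.02, 0.5, 0.98, 1.0])
```

Confident predictions near 0 or 1 score close to 0, and a probability of 0.5 scores exactly 1.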

Compare this to the uncertainty surface of the deterministic model:

plot_predictions(resnet_probs, model_name="Deterministic")


As mentioned earlier, a deterministic model is not distance-aware. Its uncertainty is determined by the distance of the test example from the decision boundary, which leads the model to produce overconfident predictions for the out-of-domain examples (red).
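The difference between the two behaviors can be seen in a toy two-class example, sketched below in NumPy. It is a deliberately simplified illustration, not SNGP or the tutorial's ResNet: a linear logit stands in for a model whose confidence depends only on the distance from the decision boundary, and a logit built from RBF kernel terms centered on two training points stands in for a distance-aware model. The training points, weights and length scale are made-up values.

```python
import numpy as np


def sigmoid(z):
    return 1. / (1. + np.exp(-z))


# Two hypothetical training points: class 0 at (-1, 0), class 1 at (1, 0).
train_points = np.array([[-1., 0.], [1., 0.]])
kernel_weights = np.array([-3., 3.])  # pull toward class 0 / class 1
linear_weights = np.array([2., 0.])   # decision boundary at x = 0


def linear_prob(x):
    # Confidence grows with distance from the boundary, not from the data.
    return sigmoid(x @ linear_weights)


def rbf_prob(x, length_scale=1.):
    # Kernel terms vanish far from the training points, so logit -> 0, p -> 0.5.
    sq_dist = np.sum((x[:, None, :] - train_points[None]) ** 2, axis=-1)
    k = np.exp(-sq_dist / (2. * length_scale ** 2))
    return sigmoid(k @ kernel_weights)


test_points = np.array([[1., 0.],     # on a training point
                        [100., 0.]])  # far from all training data
p_linear = linear_prob(test_points)
p_rbf = rbf_prob(test_points)
```

Far from the data, the linear model is almost certain of class 1, while the kernel-based model falls back to 0.5, which is the qualitative behavior the SNGP output layer is meant to provide.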

Comparison with other uncertainty approaches

This section compares the uncertainty of SNGP with Monte Carlo dropout and Deep ensemble.

Both of these methods are based on Monte Carlo averaging of multiple forward passes of deterministic models. First set the ensemble size $M$.

num_ensemble = 10

Monte Carlo dropout

Given a trained neural network with Dropout layers, Monte Carlo dropout computes the mean predictive probability

$$E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))$$

by averaging over multiple Dropout-enabled forward passes $\{logit_m(x)\}_{m=1}^M$.

def mc_dropout_sampling(test_examples):
  # Enable dropout during inference.
  return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")
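To show the mechanics without Keras, the sketch below applies Monte Carlo dropout to a tiny two-layer NumPy network with random, untrained weights (hypothetical values, chosen only to keep the example self-contained). Each forward pass samples a fresh dropout mask on the hidden layer, using the inverted-dropout scaling that Keras Dropout layers also apply in training mode, and the class probabilities are averaged over the $M$ passes as in the formula above.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


rng = np.random.default_rng(0)
# Hypothetical untrained weights: 2 inputs -> 16 hidden units -> 2 classes.
w1, b1 = rng.normal(size=(2, 16)), np.zeros(16)
w2, b2 = rng.normal(size=(16, 2)), np.zeros(2)


def forward(x, dropout_rate=0.1, training=True):
    h = np.maximum(x @ w1 + b1, 0.)  # ReLU hidden layer
    if training:
        # Inverted dropout: drop units with prob `dropout_rate`, rescale survivors.
        mask = rng.random(h.shape) >= dropout_rate
        h = h * mask / (1. - dropout_rate)
    return h @ w2 + b2


def mc_dropout_probs(x, num_samples=10):
    # E(p(x)) = 1/M * sum_m softmax(logit_m(x)), with dropout active in every pass.
    prob_samples = [softmax(forward(x, training=True))[:, 0]
                    for _ in range(num_samples)]
    return np.mean(prob_samples, axis=0)


x = rng.normal(size=(5, 2))
probs = mc_dropout_probs(x)
```

With training=False the network is deterministic, so repeated passes would all agree; the spread across dropout masks is what produces the averaged, less certain probabilities.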


Deep ensemble

Deep ensemble is a state-of-the-art (but expensive) method for deep learning uncertainty. To train a Deep ensemble, first train $M$ ensemble members.

# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
  resnet_model = DeepResNet(**resnet_config)
  resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)  

  resnet_ensemble.append(resnet_model)

Collect logits and compute the mean predictive probability $E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))$.

# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")
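Note that the ensemble averages probabilities, not logits. The small NumPy sketch below, with made-up logits from three hypothetical ensemble members on a single example, shows why the order matters: in this example, where one member disagrees with the other two, averaging the members' softmax outputs gives a noticeably less confident prediction than applying softmax to the averaged logits.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


# Hypothetical logits from M = 3 ensemble members for a single example.
member_logits = np.array([[4., -4.],
                          [4., -4.],
                          [-4., 4.]])  # the third member disagrees

# Deep ensemble rule used above: average the members' class probabilities.
avg_of_probs = softmax(member_logits, axis=-1).mean(axis=0)
# Alternative, not used by the tutorial: softmax of the averaged logits.
probs_of_avg = softmax(member_logits.mean(axis=0))
print(avg_of_probs[0], probs_of_avg[0])
```

Averaging probabilities keeps the members' disagreement visible (a class 0 probability near 2/3), whereas averaging logits first washes much of it out (about 0.93 here).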


Both MC Dropout and Deep ensemble improve a model's uncertainty quantification by making the decision boundary less certain. However, they both inherit the deterministic deep network's lack of distance awareness.

Summary

In this tutorial, you have:

  • Implemented an SNGP model on a deep classifier to improve its distance awareness.
  • Trained the SNGP model end-to-end using the Keras model.fit() API.
  • Visualized the uncertainty behavior of SNGP.
  • Compared the uncertainty behavior of SNGP, Monte Carlo dropout and deep ensemble models.

Resources and further reading