
Uncertainty-aware Deep Learning with SNGP


In AI applications that are safety-critical (e.g., medical decision making and autonomous driving) or where the data is inherently noisy (e.g., natural language understanding), it is important for a deep classifier to reliably quantify its uncertainty. The deep classifier should be aware of its own limitations and know when it should hand control over to human experts. This tutorial shows how to improve a deep classifier's ability to quantify uncertainty using a technique called Spectral-normalized Neural Gaussian Process (SNGP).

The core idea of SNGP is to improve a deep classifier's distance awareness by applying simple modifications to the network. A model's distance awareness is a measure of how its predictive probability reflects the distance between the test example and the training data. This is a desirable property that is common for gold-standard probabilistic models (e.g., a Gaussian process with RBF kernels) but is lacking in models with deep neural networks. SNGP provides a simple way to inject this Gaussian-process behavior into a deep classifier while maintaining its predictive accuracy.

This tutorial implements a deep residual network (ResNet)-based SNGP model on the two moons dataset, and compares its uncertainty surface with that of two other popular uncertainty approaches: Monte Carlo dropout and Deep ensemble.

This tutorial illustrates the SNGP model on a toy 2D dataset. For an example of applying SNGP to a real-world natural language understanding task using BERT-base, check out the SNGP-BERT tutorial. For high-quality implementations of the SNGP model (and many other uncertainty methods) on a wide variety of benchmark datasets (e.g., CIFAR-100, ImageNet, Jigsaw toxicity detection, etc.), check out the Uncertainty Baselines benchmark.

About SNGP

Spectral-normalized Neural Gaussian Process (SNGP) is a simple approach to improve a deep classifier's uncertainty quality while maintaining a similar level of accuracy and latency. Given a deep residual network, SNGP makes two simple changes to the model:

  • It applies spectral normalization to the hidden residual layers.
  • It replaces the dense output layer with a Gaussian process layer.

SNGP

Compared to other uncertainty approaches (e.g., Monte Carlo dropout or Deep ensemble), SNGP has several advantages:

  • It works for a wide range of state-of-the-art residual-based architectures (e.g., (Wide) ResNet, DenseNet, BERT, etc.).
  • It is a single-model method (i.e., it does not rely on ensemble averaging). Therefore, SNGP has a similar level of latency as a single deterministic network, and can be scaled easily to large datasets like ImageNet and Jigsaw Toxic Comments classification.
  • It has strong out-of-domain detection performance due to the distance-awareness property.

The downsides of this method are:

  • The predictive uncertainty of an SNGP is computed using the Laplace approximation. Therefore, theoretically, the posterior uncertainty of SNGP is different from that of an exact Gaussian process.

  • SNGP training needs a covariance-reset step at the beginning of a new epoch. This can add a small amount of extra complexity to a training pipeline. This tutorial shows a simple way to implement this using Keras callbacks.

Setup

pip install --use-deprecated=legacy-resolver tf-models-official
# refresh pkg_resources so it takes the changes into account.
import pkg_resources
import importlib
importlib.reload(pkg_resources)
<module 'pkg_resources' from '/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py'>
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import sklearn.datasets

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as nlp_layers

Define visualization macros

plt.rcParams['figure.dpi'] = 140

DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100

The two moons dataset

Create the training and evaluation datasets from the two moons dataset.

def make_training_data(sample_size=500):
  """Create two moon training dataset."""
  train_examples, train_labels = sklearn.datasets.make_moons(
      n_samples=2 * sample_size, noise=0.1)

  # Adjust data position slightly.
  train_examples[train_labels == 0] += [-0.1, 0.2]
  train_examples[train_labels == 1] += [0.1, -0.2]

  return train_examples, train_labels

Evaluate the model's predictive behavior over the entire 2D input space.

def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
  """Create a mesh grid in 2D space."""
  # testing data (mesh grid over data space)
  x = np.linspace(x_range[0], x_range[1], n_grid)
  y = np.linspace(y_range[0], y_range[1], n_grid)
  xv, yv = np.meshgrid(x, y)
  return np.stack([xv.flatten(), yv.flatten()], axis=-1)

To evaluate model uncertainty, add an out-of-domain (OOD) dataset that belongs to a third class. The model never sees these OOD examples during training.

def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
  """Create OOD testing data that belongs to neither training class."""
  return np.random.multivariate_normal(
      means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
    sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)

# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]

plt.figure(figsize=(7, 5.5))

plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)

plt.legend(["Positive", "Negative", "Out-of-Domain"])

plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)

plt.show()

png

Here, blue and orange represent the positive and negative classes, and red represents the OOD data. A model that quantifies uncertainty well is expected to be confident when close to the training data (i.e., \(p(x_{test})\) close to 0 or 1), and to be uncertain when far away from the training-data regions (i.e., \(p(x_{test})\) close to 0.5).

The deterministic model

Define model

Start with the (baseline) deterministic model: a multi-layer residual network (ResNet) with dropout regularization.

This tutorial uses a six-layer ResNet with 128 hidden units.
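The DeepResNet base class is defined earlier in the notebook and is not shown in this excerpt. A minimal sketch consistent with the model summary below (a frozen input projection, num_layers residual Dense blocks with dropout, and a pluggable output layer) might look like:

```python
import tensorflow as tf

class DeepResNet(tf.keras.Model):
  """Defines a multi-layer residual network."""

  def __init__(self, num_classes, num_layers=3, num_hidden=128,
               dropout_rate=0.1, **classifier_kwargs):
    super().__init__()
    # Defines class metadata.
    self.num_hidden = num_hidden
    self.num_layers = num_layers
    self.dropout_rate = dropout_rate
    self.classifier_kwargs = classifier_kwargs

    # Defines the hidden layers.
    self.input_layer = tf.keras.layers.Dense(self.num_hidden, trainable=False)
    self.dense_layers = [self.make_dense_layer() for _ in range(num_layers)]

    # Defines the output layer.
    self.classifier = self.make_output_layer(num_classes)

  def call(self, inputs):
    # Projects the 2D input data to high dimension.
    hidden = self.input_layer(inputs)

    # Computes the ResNet hidden representations.
    for i in range(self.num_layers):
      resid = self.dense_layers[i](hidden)
      resid = tf.keras.layers.Dropout(self.dropout_rate)(resid)
      hidden += resid

    return self.classifier(hidden)

  def make_dense_layer(self):
    """Uses the Dense layer as the hidden layer."""
    return tf.keras.layers.Dense(self.num_hidden, activation="relu")

  def make_output_layer(self, num_classes):
    """Uses the Dense layer as the output layer."""
    return tf.keras.layers.Dense(num_classes, **self.classifier_kwargs)
```

Factoring the hidden and output layers into make_dense_layer() and make_output_layer() lets the SNGP variant later override just those two methods.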

resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
resnet_model.build((None, 2))
resnet_model.summary()
Model: "deep_res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               multiple                  384       
                                                                 
 dense_1 (Dense)             multiple                  16512     
                                                                 
 dense_2 (Dense)             multiple                  16512     
                                                                 
 dense_3 (Dense)             multiple                  16512     
                                                                 
 dense_4 (Dense)             multiple                  16512     
                                                                 
 dense_5 (Dense)             multiple                  16512     
                                                                 
 dense_6 (Dense)             multiple                  16512     
                                                                 
 dense_7 (Dense)             multiple                  258       
                                                                 
=================================================================
Total params: 99,714
Trainable params: 99,330
Non-trainable params: 384
_________________________________________________________________

Train model

Set up the training parameters to use SparseCategoricalCrossentropy as the loss function and the Adam optimizer.

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)

Train the model for 100 epochs with batch size 128.

fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 4ms/step - loss: 1.0062 - sparse_categorical_accuracy: 0.4330
Epoch 2/100
8/8 [==============================] - 0s 3ms/step - loss: 0.4599 - sparse_categorical_accuracy: 0.8310
Epoch 3/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2997 - sparse_categorical_accuracy: 0.8950
Epoch 4/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2249 - sparse_categorical_accuracy: 0.9180
Epoch 5/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1722 - sparse_categorical_accuracy: 0.9310
Epoch 6/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1575 - sparse_categorical_accuracy: 0.9370
Epoch 7/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1448 - sparse_categorical_accuracy: 0.9410
Epoch 8/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1347 - sparse_categorical_accuracy: 0.9400
Epoch 9/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1293 - sparse_categorical_accuracy: 0.9440
Epoch 10/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1239 - sparse_categorical_accuracy: 0.9440
Epoch 11/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1165 - sparse_categorical_accuracy: 0.9390
Epoch 12/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1176 - sparse_categorical_accuracy: 0.9460
Epoch 13/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1170 - sparse_categorical_accuracy: 0.9400
Epoch 14/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1101 - sparse_categorical_accuracy: 0.9470
Epoch 15/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1049 - sparse_categorical_accuracy: 0.9460
Epoch 16/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1035 - sparse_categorical_accuracy: 0.9480
Epoch 17/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1003 - sparse_categorical_accuracy: 0.9500
Epoch 18/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1015 - sparse_categorical_accuracy: 0.9530
Epoch 19/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0973 - sparse_categorical_accuracy: 0.9490
Epoch 20/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0951 - sparse_categorical_accuracy: 0.9520
Epoch 21/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0962 - sparse_categorical_accuracy: 0.9520
Epoch 22/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0895 - sparse_categorical_accuracy: 0.9560
Epoch 23/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0856 - sparse_categorical_accuracy: 0.9580
Epoch 24/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0865 - sparse_categorical_accuracy: 0.9550
Epoch 25/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0838 - sparse_categorical_accuracy: 0.9590
Epoch 26/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0883 - sparse_categorical_accuracy: 0.9560
Epoch 27/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0837 - sparse_categorical_accuracy: 0.9590
Epoch 28/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0884 - sparse_categorical_accuracy: 0.9590
Epoch 29/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0861 - sparse_categorical_accuracy: 0.9590
Epoch 30/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0777 - sparse_categorical_accuracy: 0.9650
Epoch 31/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0752 - sparse_categorical_accuracy: 0.9670
Epoch 32/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0767 - sparse_categorical_accuracy: 0.9660
Epoch 33/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0759 - sparse_categorical_accuracy: 0.9660
Epoch 34/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0693 - sparse_categorical_accuracy: 0.9700
Epoch 35/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9670
Epoch 36/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0824 - sparse_categorical_accuracy: 0.9660
Epoch 37/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9690
Epoch 38/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0683 - sparse_categorical_accuracy: 0.9710
Epoch 39/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0666 - sparse_categorical_accuracy: 0.9740
Epoch 40/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0709 - sparse_categorical_accuracy: 0.9700
Epoch 41/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0639 - sparse_categorical_accuracy: 0.9790
Epoch 42/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0656 - sparse_categorical_accuracy: 0.9770
Epoch 43/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0624 - sparse_categorical_accuracy: 0.9800
Epoch 44/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0660 - sparse_categorical_accuracy: 0.9740
Epoch 45/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0605 - sparse_categorical_accuracy: 0.9790
Epoch 46/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0596 - sparse_categorical_accuracy: 0.9810
Epoch 47/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0619 - sparse_categorical_accuracy: 0.9790
Epoch 48/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0627 - sparse_categorical_accuracy: 0.9780
Epoch 49/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0569 - sparse_categorical_accuracy: 0.9800
Epoch 50/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0549 - sparse_categorical_accuracy: 0.9870
Epoch 51/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0555 - sparse_categorical_accuracy: 0.9790
Epoch 52/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0583 - sparse_categorical_accuracy: 0.9860
Epoch 53/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9780
Epoch 54/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9850
Epoch 55/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0535 - sparse_categorical_accuracy: 0.9800
Epoch 56/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0483 - sparse_categorical_accuracy: 0.9850
Epoch 57/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0504 - sparse_categorical_accuracy: 0.9840
Epoch 58/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0449 - sparse_categorical_accuracy: 0.9880
Epoch 59/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0479 - sparse_categorical_accuracy: 0.9840
Epoch 60/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0491 - sparse_categorical_accuracy: 0.9840
Epoch 61/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0432 - sparse_categorical_accuracy: 0.9850
Epoch 62/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0434 - sparse_categorical_accuracy: 0.9860
Epoch 63/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0391 - sparse_categorical_accuracy: 0.9890
Epoch 64/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0391 - sparse_categorical_accuracy: 0.9880
Epoch 65/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0423 - sparse_categorical_accuracy: 0.9860
Epoch 66/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9860
Epoch 67/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9850
Epoch 68/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0440 - sparse_categorical_accuracy: 0.9860
Epoch 69/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0438 - sparse_categorical_accuracy: 0.9840
Epoch 70/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0412 - sparse_categorical_accuracy: 0.9860
Epoch 71/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9920
Epoch 72/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0397 - sparse_categorical_accuracy: 0.9820
Epoch 73/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0366 - sparse_categorical_accuracy: 0.9910
Epoch 74/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0351 - sparse_categorical_accuracy: 0.9870
Epoch 75/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9870
Epoch 76/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9890
Epoch 77/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0290 - sparse_categorical_accuracy: 0.9900
Epoch 78/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0351 - sparse_categorical_accuracy: 0.9880
Epoch 79/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0323 - sparse_categorical_accuracy: 0.9900
Epoch 80/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0304 - sparse_categorical_accuracy: 0.9900
Epoch 81/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0284 - sparse_categorical_accuracy: 0.9910
Epoch 82/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0328 - sparse_categorical_accuracy: 0.9870
Epoch 83/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0328 - sparse_categorical_accuracy: 0.9860
Epoch 84/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0269 - sparse_categorical_accuracy: 0.9900
Epoch 85/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9890
Epoch 86/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9910
Epoch 87/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0340 - sparse_categorical_accuracy: 0.9900
Epoch 88/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0320 - sparse_categorical_accuracy: 0.9890
Epoch 89/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0247 - sparse_categorical_accuracy: 0.9910
Epoch 90/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9920
Epoch 91/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9920
Epoch 92/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0224 - sparse_categorical_accuracy: 0.9920
Epoch 93/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0252 - sparse_categorical_accuracy: 0.9900
Epoch 94/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9920
Epoch 95/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0288 - sparse_categorical_accuracy: 0.9920
Epoch 96/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0242 - sparse_categorical_accuracy: 0.9930
Epoch 97/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0237 - sparse_categorical_accuracy: 0.9910
Epoch 98/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0198 - sparse_categorical_accuracy: 0.9950
Epoch 99/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9950
Epoch 100/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0175 - sparse_categorical_accuracy: 0.9940
<keras.callbacks.History at 0x7ffb20184710>

Visualize the uncertainty

Now visualize the predictions of the deterministic model. First plot the class probability:

\[p(x) = softmax(logit(x))\]
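The plot_uncertainty_surface helper used below is defined earlier in the notebook and is not shown in this excerpt. A sketch, under the assumption that it receives the flattened per-grid-point values produced by make_testing_data() and reshapes them back onto the (DEFAULT_N_GRID, DEFAULT_N_GRID) mesh, might look as follows (the plotting constants from the Setup section are repeated for self-containment; the notebook version also overlays the training and OOD scatter points, omitted here):

```python
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np

# Plotting constants, repeated from the visualization macros above.
DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1)
DEFAULT_N_GRID = 100

def plot_uncertainty_surface(test_uncertainty, ax, cmap=None):
  """Visualizes a flattened per-grid-point surface over the 2D input space."""
  # Normalize the values for visualization.
  test_uncertainty = test_uncertainty / np.max(test_uncertainty)

  # Set view limits to match the data ranges.
  ax.set_ylim(DEFAULT_Y_RANGE)
  ax.set_xlim(DEFAULT_X_RANGE)

  # Reshape the flat values back onto the mesh grid and plot as an image.
  pcm = ax.imshow(
      np.reshape(test_uncertainty, [DEFAULT_N_GRID, DEFAULT_N_GRID]),
      cmap=cmap,
      origin="lower",
      extent=DEFAULT_X_RANGE + DEFAULT_Y_RANGE,
      vmin=DEFAULT_NORM.vmin,
      vmax=DEFAULT_NORM.vmax,
      interpolation="bicubic",
      aspect="auto")
  return pcm
```

Returning the image handle (pcm) lets the caller attach a colorbar, as the plotting cells below do.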

resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0]  # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_probs, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")

plt.show()

png

In this plot, yellow and purple are the predictive probabilities for the two classes. The deterministic model did a good job of classifying the two known classes (blue and orange) with a nonlinear decision boundary. However, it is not distance-aware, and it confidently classifies the never-seen red out-of-domain (OOD) examples as the orange class.

Visualize the model uncertainty by computing the predictive variance:

\[var(x) = p(x) * (1 - p(x))\]

resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")

plt.show()

png

In this plot, yellow indicates high uncertainty and purple indicates low uncertainty. A deterministic ResNet's uncertainty depends only on the test examples' distance from the decision boundary. This leads the model to be overconfident when out of the training domain. The next section shows how SNGP behaves differently on this dataset.

The SNGP model

Define SNGP model

Now let's implement the SNGP model. Both SNGP components, SpectralNormalization and RandomFeatureGaussianProcess, are available in the tensorflow_models built-in layers.

SNGP

Let's look at these two components in more detail. (You can also jump to the full SNGP model section to see how the complete model is implemented.)

Spectral normalization wrapper

SpectralNormalization is a Keras layer wrapper. It can be applied to an existing Dense layer like this:

dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)

Spectral normalization regularizes the hidden weight \(W\) by gradually guiding its spectral norm (i.e., the largest singular value of \(W\)) toward the target value norm_multiplier.
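To build intuition, the core idea can be sketched in plain NumPy: estimate the largest singular value of \(W\) by power iteration and shrink \(W\) whenever that value exceeds the target bound. (This is only an illustration of the idea, not the wrapper's actual implementation; nlp_layers.SpectralNormalization applies the equivalent update to the wrapped layer's kernel during training.)

```python
import numpy as np

def spectral_normalize(weight, norm_multiplier=0.9, n_power_iter=500):
  """Rescales `weight` so its spectral norm is at most `norm_multiplier`."""
  # Power iteration: estimate the dominant singular vectors/value of `weight`.
  rng = np.random.default_rng(0)
  v = rng.normal(size=weight.shape[1])
  for _ in range(n_power_iter):
    u = weight @ v
    u /= np.linalg.norm(u)
    v = weight.T @ u
    v /= np.linalg.norm(v)
  sigma = u @ weight @ v  # Estimated spectral norm.

  # Only shrink the weight when its spectral norm exceeds the target bound.
  if sigma > norm_multiplier:
    weight = weight * (norm_multiplier / sigma)
  return weight

W = np.random.default_rng(1).normal(size=(128, 128))
W_sn = spectral_normalize(W, norm_multiplier=0.9)
```

Bounding the spectral norm of every hidden layer bounds the network's Lipschitz constant, which is what preserves the distance information SNGP relies on.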

The Gaussian process (GP) layer

RandomFeatureGaussianProcess implements a random-feature-based approximation to a Gaussian process model that is end-to-end trainable with a deep neural network. Under the hood, the Gaussian process layer implements a two-layer network:

\[logits(x) = \Phi(x) \beta, \quad \Phi(x)=\sqrt{\frac{2}{M} } * cos(Wx + b)\]

Here \(x\) is the input, and \(W\) and \(b\) are frozen weights initialized randomly from Gaussian and uniform distributions, respectively. (Therefore \(\Phi(x)\) are called "random features".) \(\beta\) is the learnable kernel weight, similar to that of a Dense layer.
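These are the classical random Fourier features: with \(W\) drawn from a Gaussian and \(b\) from Uniform\((0, 2\pi)\), the inner product \(\Phi(x)^\top \Phi(y)\) approximates an RBF kernel. A small NumPy sketch (with hypothetical sizes, not the layer's actual implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, num_inducing = 2, 4096  # num_inducing plays the role of M.

# Frozen random weights: W ~ Normal(0, 1), b ~ Uniform(0, 2*pi).
W = rng.normal(size=(input_dim, num_inducing))
b = rng.uniform(0, 2 * np.pi, size=num_inducing)

def random_features(x):
  """Phi(x) = sqrt(2/M) * cos(x @ W + b)."""
  return np.sqrt(2.0 / num_inducing) * np.cos(x @ W + b)

x = np.array([0.3, -0.2])
y = np.array([0.5, 0.1])

# Phi(x) . Phi(y) approximates the RBF kernel exp(-||x - y||^2 / 2).
approx = random_features(x) @ random_features(y)
exact = np.exp(-np.linalg.norm(x - y) ** 2 / 2)
```

The approximation error shrinks as \(M\) grows, which is why num_inducing defaults to a fairly large value (1024).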

batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
                                                   num_inducing=1024,
                                                   normalize_input=False,
                                                   scale_random_features=True,
                                                   gp_cov_momentum=-1)

The main parameters of the GP layer are:

  • units : The dimension of the output logits.
  • num_inducing : The dimension \(M\) of the hidden weight \(W\). Defaults to 1024.
  • normalize_input : Whether to apply layer normalization to the input \(x\).
  • scale_random_features : Whether to apply the scale \(\sqrt{2/M}\) to the hidden output.
  • gp_cov_momentum : Controls how the model covariance is computed. If set to a positive value (e.g., 0.999), the covariance matrix is computed using the momentum-based moving-average update (similar to batch normalization). If set to -1, the covariance matrix is updated without momentum.

Given a batch input with shape (batch_size, input_dim), the GP layer returns a logits tensor (shape (batch_size, num_classes)) for prediction, and also a covmat tensor (shape (batch_size, batch_size)), which is the posterior covariance matrix of the batch logits.

embedding = tf.random.normal(shape=(batch_size, input_dim))

logits, covmat = gp_layer(embedding)

Theoretically, it is possible to extend the algorithm to compute different variance values for different classes (as introduced in the original SNGP paper). However, this is difficult to scale to problems with large output spaces (e.g., ImageNet or language modeling).

The full SNGP model

Given the base class DeepResNet, the SNGP model can be implemented easily by modifying the residual network's hidden and output layers. For compatibility with the Keras model.fit() API, also modify the model's call() method so it only outputs logits during training.

class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

Use the same architecture as the deterministic model.

resnet_config
{'num_classes': 2, 'num_layers': 6, 'num_hidden': 128}
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
Model: "deep_res_net_sngp"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_9 (Dense)             multiple                  384       
                                                                 
 spectral_normalization_1 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_2 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_3 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_4 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_5 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_6 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 random_feature_gaussian_pro  multiple                 1182722   
 cess (RandomFeatureGaussian                                     
 Process)                                                        
                                                                 
=================================================================
Total params: 1,283,714
Trainable params: 101,120
Non-trainable params: 1,182,594
_________________________________________________________________

Implement a Keras callback to reset the covariance matrix at the beginning of a new epoch.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the beginning of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()

Add this callback to the DeepResNetSNGP model class.

class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks", []))
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Train model

Use tf.keras.model.fit to train the model.

sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 5ms/step - loss: 0.6318 - sparse_categorical_accuracy: 0.9045
Epoch 2/100
8/8 [==============================] - 0s 5ms/step - loss: 0.5380 - sparse_categorical_accuracy: 0.9840
Epoch 3/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4815 - sparse_categorical_accuracy: 0.9970
Epoch 4/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4410 - sparse_categorical_accuracy: 0.9960
Epoch 5/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4062 - sparse_categorical_accuracy: 0.9980
Epoch 6/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3792 - sparse_categorical_accuracy: 0.9970
Epoch 7/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3532 - sparse_categorical_accuracy: 0.9990
Epoch 8/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3322 - sparse_categorical_accuracy: 0.9970
Epoch 9/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3121 - sparse_categorical_accuracy: 0.9980
Epoch 10/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2969 - sparse_categorical_accuracy: 0.9970
Epoch 11/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2808 - sparse_categorical_accuracy: 0.9990
Epoch 12/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2671 - sparse_categorical_accuracy: 0.9980
Epoch 13/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2544 - sparse_categorical_accuracy: 0.9970
Epoch 14/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2423 - sparse_categorical_accuracy: 0.9970
Epoch 15/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2322 - sparse_categorical_accuracy: 0.9980
Epoch 16/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2230 - sparse_categorical_accuracy: 0.9980
Epoch 17/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2134 - sparse_categorical_accuracy: 0.9990
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2061 - sparse_categorical_accuracy: 0.9960
Epoch 19/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1979 - sparse_categorical_accuracy: 0.9970
Epoch 20/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1908 - sparse_categorical_accuracy: 0.9980
Epoch 21/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1850 - sparse_categorical_accuracy: 0.9960
Epoch 22/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1777 - sparse_categorical_accuracy: 0.9960
Epoch 23/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1727 - sparse_categorical_accuracy: 0.9980
Epoch 24/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1669 - sparse_categorical_accuracy: 0.9960
Epoch 25/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1606 - sparse_categorical_accuracy: 0.9980
Epoch 26/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1554 - sparse_categorical_accuracy: 0.9970
Epoch 27/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1515 - sparse_categorical_accuracy: 0.9970
Epoch 28/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1478 - sparse_categorical_accuracy: 0.9960
Epoch 29/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1416 - sparse_categorical_accuracy: 0.9980
Epoch 30/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1396 - sparse_categorical_accuracy: 0.9980
Epoch 31/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1346 - sparse_categorical_accuracy: 0.9960
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1304 - sparse_categorical_accuracy: 0.9970
Epoch 33/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1276 - sparse_categorical_accuracy: 0.9980
Epoch 34/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1234 - sparse_categorical_accuracy: 0.9990
Epoch 35/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1204 - sparse_categorical_accuracy: 0.9970
Epoch 36/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1176 - sparse_categorical_accuracy: 0.9970
Epoch 37/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1148 - sparse_categorical_accuracy: 0.9980
Epoch 38/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1122 - sparse_categorical_accuracy: 0.9980
Epoch 39/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1085 - sparse_categorical_accuracy: 0.9980
Epoch 40/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1063 - sparse_categorical_accuracy: 0.9990
Epoch 41/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1040 - sparse_categorical_accuracy: 0.9970
Epoch 42/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1005 - sparse_categorical_accuracy: 0.9990
Epoch 43/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0995 - sparse_categorical_accuracy: 0.9980
Epoch 44/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0962 - sparse_categorical_accuracy: 0.9990
Epoch 45/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0941 - sparse_categorical_accuracy: 0.9990
Epoch 46/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0918 - sparse_categorical_accuracy: 0.9980
Epoch 47/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0898 - sparse_categorical_accuracy: 0.9980
Epoch 48/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0883 - sparse_categorical_accuracy: 0.9990
Epoch 49/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0853 - sparse_categorical_accuracy: 0.9990
Epoch 50/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0840 - sparse_categorical_accuracy: 0.9990
Epoch 51/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0827 - sparse_categorical_accuracy: 1.0000
Epoch 52/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0793 - sparse_categorical_accuracy: 0.9990
Epoch 53/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0782 - sparse_categorical_accuracy: 1.0000
Epoch 54/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0767 - sparse_categorical_accuracy: 0.9990
Epoch 55/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9980
Epoch 56/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0737 - sparse_categorical_accuracy: 0.9990
Epoch 57/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0723 - sparse_categorical_accuracy: 1.0000
Epoch 58/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0714 - sparse_categorical_accuracy: 0.9990
Epoch 59/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0695 - sparse_categorical_accuracy: 1.0000
Epoch 60/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0681 - sparse_categorical_accuracy: 1.0000
Epoch 61/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0667 - sparse_categorical_accuracy: 1.0000
Epoch 62/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0648 - sparse_categorical_accuracy: 1.0000
Epoch 63/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0637 - sparse_categorical_accuracy: 1.0000
Epoch 64/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0628 - sparse_categorical_accuracy: 1.0000
Epoch 65/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0610 - sparse_categorical_accuracy: 1.0000
Epoch 66/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0595 - sparse_categorical_accuracy: 1.0000
Epoch 67/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0593 - sparse_categorical_accuracy: 1.0000
Epoch 68/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0581 - sparse_categorical_accuracy: 0.9990
Epoch 69/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0569 - sparse_categorical_accuracy: 1.0000
Epoch 70/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0559 - sparse_categorical_accuracy: 1.0000
Epoch 71/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0547 - sparse_categorical_accuracy: 1.0000
Epoch 72/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0534 - sparse_categorical_accuracy: 1.0000
Epoch 73/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0531 - sparse_categorical_accuracy: 1.0000
Epoch 74/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0518 - sparse_categorical_accuracy: 1.0000
Epoch 75/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0508 - sparse_categorical_accuracy: 1.0000
Epoch 76/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0496 - sparse_categorical_accuracy: 1.0000
Epoch 77/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0486 - sparse_categorical_accuracy: 1.0000
Epoch 78/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0482 - sparse_categorical_accuracy: 1.0000
Epoch 79/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0478 - sparse_categorical_accuracy: 1.0000
Epoch 80/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0467 - sparse_categorical_accuracy: 1.0000
Epoch 81/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0462 - sparse_categorical_accuracy: 1.0000
Epoch 82/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0453 - sparse_categorical_accuracy: 1.0000
Epoch 83/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0440 - sparse_categorical_accuracy: 1.0000
Epoch 84/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0435 - sparse_categorical_accuracy: 1.0000
Epoch 85/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0427 - sparse_categorical_accuracy: 1.0000
Epoch 86/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0427 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0411 - sparse_categorical_accuracy: 1.0000
Epoch 88/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0408 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0403 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0395 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0395 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0385 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0377 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0374 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
8/8 [==============================] - 0s 6ms/step - loss: 0.0370 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0365 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0359 - sparse_categorical_accuracy: 1.0000
Epoch 98/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0351 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0348 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0341 - sparse_categorical_accuracy: 1.0000
<keras.callbacks.History at 0x7ffb183d11d0>

Visualize the uncertainty

First, compute the predictive logits and variances.

sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]

Now compute the posterior predictive probability. The classical way to compute the predictive probability of a probabilistic model is with Monte Carlo sampling, i.e.,

\[E(p(x)) = \frac{1}{M} \sum_{m=1}^M softmax(logit_m(x)), \]

where \(M\) is the sample size, and \(logit_m(x)\) are random samples from the SNGP posterior \(MultivariateNormal\)( sngp_logits , sngp_covmat ). However, this approach can be slow for latency-sensitive applications such as autonomous driving or real-time bidding. Instead, \(E(p(x))\) can be approximated using the mean-field method:

\[E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })\]

where \(\sigma^2(x)\) is the SNGP variance, and \(\lambda\) is often chosen as \(\pi/8\) or \(3/\pi^2\).

sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]
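For intuition, the classical Monte Carlo estimate described above can be sketched in plain NumPy on hypothetical toy inputs (the names `toy_logits` and `toy_covmat` below are illustrative, not part of this tutorial):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior: 4 test examples, 2 classes.
toy_logits = rng.normal(size=(4, 2))   # posterior mean logits
toy_covmat = np.eye(4) * 0.5           # posterior covariance over examples

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mc_mean_probability(logits, covmat, num_samples=1000):
    """Averages softmax over logit samples drawn from
    MultivariateNormal(mean=logits, covariance=covmat)."""
    chol = np.linalg.cholesky(covmat)
    probs = np.zeros_like(logits)
    for _ in range(num_samples):
        # One correlated Gaussian draw per class dimension.
        noise = chol @ rng.normal(size=logits.shape)
        probs += softmax(logits + noise)
    return probs / num_samples

mc_probs = mc_mean_probability(toy_logits, toy_covmat)
```

The loop over `num_samples` forward draws is exactly what makes this estimator expensive compared with the single-pass mean-field adjustment used above.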

This mean-field method is implemented as a built-in function, layers.gaussian_process.mean_field_logits:

def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

SNGP summary

Put everything together. The full procedure (training, evaluation, and uncertainty computation) can be done in just five lines:

def train_and_test_sngp(train_examples, test_examples):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)

Visualize the class probability (left) and the predictive uncertainty (right) of the SNGP model.

plot_predictions(sngp_probs, model_name="SNGP")

png

Recall that in the class probability plot (left), yellow and purple are the class probabilities. When close to the training data domain, SNGP correctly classifies the examples with high confidence (i.e., assigns a probability close to 0 or 1). When far from the training data, SNGP gradually becomes less confident, and its predictive probability approaches 0.5 while the (normalized) model uncertainty rises to 1.
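One common way to obtain such a normalized uncertainty from a binary class probability (a sketch of the idea; the tutorial's `plot_predictions` helper may compute it differently) is the predictive variance p(1 - p), rescaled by its maximum value 0.25 so it equals 1 exactly at p = 0.5:

```python
import numpy as np

def normalized_uncertainty(probs):
    """Predictive variance p*(1-p), rescaled to [0, 1] by its maximum 0.25."""
    probs = np.asarray(probs, dtype=float)
    return probs * (1.0 - probs) / 0.25

# Confident predictions -> near-zero uncertainty; p = 0.5 -> maximal uncertainty.
u = normalized_uncertainty([0.01, 0.5, 0.99])
print(u)
```

Under this definition, the far-from-data behavior described above (probability drifting to 0.5) is precisely what drives the normalized uncertainty toward 1.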

Compare this with the uncertainty surface of the deterministic model:

plot_predictions(resnet_probs, model_name="Deterministic")

png

As mentioned earlier, a deterministic model is not distance-aware. Its uncertainty is defined by the distance between the test example and the decision boundary. This leads the model to produce overconfident predictions for out-of-domain examples (red).

Comparison with other uncertainty approaches

This section compares the uncertainty of SNGP with Monte Carlo dropout and deep ensemble.

Both methods are based on Monte Carlo averaging of multiple forward passes of deterministic models. First, set the ensemble size \(M\).

num_ensemble = 10

Monte Carlo dropout

Given a trained neural network with dropout layers, Monte Carlo dropout computes the mean predictive probability

\[E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\]

by averaging over multiple dropout-enabled forward passes \(\{logit_m(x)\}_{m=1}^M\).

def mc_dropout_sampling(test_examples):
  # Enable dropout during inference.
  return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")

png

Deep ensemble

Deep ensemble is a state-of-the-art (but expensive) method for deep learning uncertainty. To train a deep ensemble, first train \(M\) ensemble members.

# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
  resnet_model = DeepResNet(**resnet_config)
  resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)  

  resnet_ensemble.append(resnet_model)

Collect the logits and compute the mean predictive probability \(E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\).

# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")

png

Both MC Dropout and Deep ensemble improve a model's uncertainty capability by making the decision boundary less certain. However, both inherit the deterministic deep network's limitation of lacking distance awareness.

Summary

In this tutorial, you have:

  • Implemented a SNGP model on a deep classifier to improve its distance awareness.
  • Trained the SNGP model end-to-end using the Keras model.fit() API.
  • Visualized the uncertainty behavior of SNGP.
  • Compared the uncertainty behavior of the SNGP, Monte Carlo dropout, and deep ensemble models.

Resources and further reading