
Classification on imbalanced data


This tutorial shows how to classify a highly imbalanced dataset in which the number of examples in one class vastly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Try common techniques for dealing with imbalanced data, such as:
    • Class weighting
    • Oversampling

Setup

 import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
 
 mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
 

Data processing and exploration

Download the Kaggle Credit Card Fraud data set

Pandas is a Python library with many helpful utilities for loading and working with structured data, and can be used to download CSVs into a DataFrame.

 raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
 
 raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
 

Examine the class label imbalance

Let's look at the dataset imbalance:

 neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
 
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


This shows the small fraction of positive samples.

Clean, split, and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.

 cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
 

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics; however, the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.

 # Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
 

Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and the standard deviation to 1. Note that the scaler is fit only on the training features, so the model never peeks at the validation or test sets.

 scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)

 
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input, and these are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes, the positive examples contain a much higher rate of extreme values.
 pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5, 5), ylim=(-5, 5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5, 5), ylim=(-5, 5))
_ = plt.suptitle("Negative distribution")
 

[Figure: hexbin joint distribution of V5 vs. V6 for positive examples]

[Figure: hexbin joint distribution of V5 vs. V6 for negative examples]

Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:

 METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
 

Understanding useful metrics

Notice that there are a few metrics defined above that can be computed by the model and that will be helpful when evaluating its performance.

  • False negatives and false positives are samples that were incorrectly classified.
  • True negatives and true positives are samples that were correctly classified.
  • Accuracy is the percentage of examples correctly classified: $\frac{\text{true samples}}{\text{total samples}}$
  • Precision is the percentage of predicted positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false positives}}$
  • Recall is the percentage of actual positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false negatives}}$
  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
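
As a quick worked example (an addition; the counts are illustrative and happen to match the baseline test results shown later in this tutorial), these definitions can be computed directly:

 # Illustrative confusion-matrix counts.
tp, fp, tn, fn = 69, 5, 56866, 22

accuracy = (tp + tn) / (tp + fp + tn + fn)  # correctly classified / total
precision = tp / (tp + fp)                  # predicted positives that are real
recall = tp / (tp + fn)                     # actual positives that are found

print('accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(
    accuracy, precision, recall))
 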


Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048; this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, they would likely have no fraudulent transactions to learn from.
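
As a rough sanity check (simple arithmetic, added here; not part of the original text), a batch of 2048 drawn from this dataset should contain a few positives on average:

 # Expected positive samples per batch at the ~0.17% positive rate.
print('Expected positives per batch: {:.1f}'.format(2048 * pos / total))
 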

 EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
 
 model = make_model()
model.summary()
 
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Test run the model:

 model.predict(train_features[:10])
 
array([[0.5788107 ],
       [0.44979692],
       [0.5427961 ],
       [0.5985188 ],
       [0.7758075 ],
       [0.3417888 ],
       [0.39359283],
       [0.5399953 ],
       [0.3551327 ],
       [0.47230086]], dtype=float32)

Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

With the default bias initialization the loss should be about math.log(2) = 0.69314

 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.7817

The correct bias to set can be derived from:

$$ p_0 = \frac{pos}{pos + neg} = \frac{1}{1 + e^{-b_0}} $$
$$ b_0 = -\log_e\left(\frac{1}{p_0} - 1\right) $$
$$ b_0 = \log_e\left(\frac{pos}{neg}\right) $$
 initial_bias = np.log([pos/neg])
initial_bias
 
array([-6.35935934])

Set that as the initial bias, and the model will give much more reasonable initial guesses.

It should be near: pos/total = 0.0018

 model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
 
array([[0.00093563],
       [0.00187903],
       [0.00109238],
       [0.00117128],
       [0.00134988],
       [0.00090826],
       [0.00099455],
       [0.00154405],
       [0.00100204],
       [0.0004291 ]], dtype=float32)

With this initialization the initial loss should be approximately:

$$ -p_0 \log(p_0) - (1 - p_0) \log(1 - p_0) = 0.01317 $$
 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.0146
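
As a quick check (added here, reusing pos and total from above), you can compute the analytic value directly and compare it with the measured loss:

 # Analytic binary cross-entropy at initialization.
p_0 = pos / total
analytic_loss = -p_0 * np.log(p_0) - (1 - p_0) * np.log(1 - p_0)
print('Analytic initial loss: {:.4f}'.format(analytic_loss))
 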

This initial loss is about 50 times smaller than it would have been with naive initialization (roughly 0.69 / 0.013 ≈ 50).

This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. This also makes it easier to read plots of the loss during training.

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training.

 initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)
 

Confirm that the bias fix helps

Before moving on, confirm quickly that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses:

 model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  
  plt.legend()
 
 plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
 

[Figure: training and validation loss for zero bias vs. careful bias initialization]

The above figure makes it clear: in terms of validation loss, on this problem, this careful initialization gives a clear advantage.

Train the model

 model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
 
Epoch 1/100
90/90 [==============================] - 1s 13ms/step - loss: 0.0112 - tp: 100.0000 - fp: 25.0000 - tn: 227419.0000 - fn: 301.0000 - accuracy: 0.9986 - precision: 0.8000 - recall: 0.2494 - auc: 0.7615 - val_loss: 0.0067 - val_tp: 15.0000 - val_fp: 2.0000 - val_tn: 45480.0000 - val_fn: 72.0000 - val_accuracy: 0.9984 - val_precision: 0.8824 - val_recall: 0.1724 - val_auc: 0.9077
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0075 - tp: 108.0000 - fp: 24.0000 - tn: 181938.0000 - fn: 206.0000 - accuracy: 0.9987 - precision: 0.8182 - recall: 0.3439 - auc: 0.8491 - val_loss: 0.0046 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8824 - val_recall: 0.5172 - val_auc: 0.9308
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0065 - tp: 138.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 176.0000 - accuracy: 0.9989 - precision: 0.8364 - recall: 0.4395 - auc: 0.8567 - val_loss: 0.0040 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 33.0000 - val_accuracy: 0.9991 - val_precision: 0.8852 - val_recall: 0.6207 - val_auc: 0.9365
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0060 - tp: 154.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 160.0000 - accuracy: 0.9989 - precision: 0.8235 - recall: 0.4904 - auc: 0.8848 - val_loss: 0.0037 - val_tp: 61.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8841 - val_recall: 0.7011 - val_auc: 0.9422
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 157.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 157.0000 - accuracy: 0.9989 - precision: 0.8135 - recall: 0.5000 - auc: 0.8982 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8857 - val_recall: 0.7126 - val_auc: 0.9422
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 152.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8261 - recall: 0.4841 - auc: 0.8934 - val_loss: 0.0033 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0052 - tp: 174.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 140.0000 - accuracy: 0.9991 - precision: 0.8529 - recall: 0.5541 - auc: 0.8983 - val_loss: 0.0032 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 161.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8342 - recall: 0.5127 - auc: 0.8983 - val_loss: 0.0031 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 167.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8186 - recall: 0.5318 - auc: 0.9064 - val_loss: 0.0030 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 156.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 158.0000 - accuracy: 0.9989 - precision: 0.8211 - recall: 0.4968 - auc: 0.9046 - val_loss: 0.0029 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9479
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 165.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8376 - recall: 0.5255 - auc: 0.9063 - val_loss: 0.0029 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 165.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8250 - recall: 0.5255 - auc: 0.9110 - val_loss: 0.0028 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9480
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 157.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5000 - auc: 0.9031 - val_loss: 0.0028 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 160.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 154.0000 - accuracy: 0.9990 - precision: 0.8205 - recall: 0.5096 - auc: 0.8934 - val_loss: 0.0027 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 168.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8235 - recall: 0.5350 - auc: 0.9031 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 169.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8492 - recall: 0.5382 - auc: 0.9143 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8498 - recall: 0.5764 - auc: 0.9144 - val_loss: 0.0027 - val_tp: 70.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8974 - val_recall: 0.8046 - val_auc: 0.9537
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8619 - recall: 0.5764 - auc: 0.9112 - val_loss: 0.0026 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 177.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 137.0000 - accuracy: 0.9991 - precision: 0.8349 - recall: 0.5637 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 176.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8462 - recall: 0.5605 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9167 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 163.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5191 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 183.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8281 - recall: 0.5828 - auc: 0.9113 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9041 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8400 - recall: 0.5350 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 179.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 135.0000 - accuracy: 0.9991 - precision: 0.8483 - recall: 0.5701 - auc: 0.9161 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 173.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 141.0000 - accuracy: 0.9990 - precision: 0.8199 - recall: 0.5510 - auc: 0.9208 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9081 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 181.0000 - fp: 39.0000 - tn: 181923.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8227 - recall: 0.5764 - auc: 0.9193 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 177.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8233 - recall: 0.5637 - auc: 0.9305 - val_loss: 0.0025 - val_tp: 67.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.9054 - val_recall: 0.7701 - val_auc: 0.9538
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 31.0000 - tn: 181931.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8442 - recall: 0.5350 - auc: 0.9161 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8309 - recall: 0.5478 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5987 - auc: 0.9162 - val_loss: 0.0025 - val_tp: 70.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9091 - val_recall: 0.8046 - val_auc: 0.9538
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 184.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8720 - recall: 0.5860 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8486 - recall: 0.5892 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 178.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8318 - recall: 0.5669 - auc: 0.9160 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 171.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 143.0000 - accuracy: 0.9990 - precision: 0.8382 - recall: 0.5446 - auc: 0.9192 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 189.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 125.0000 - accuracy: 0.9991 - precision: 0.8438 - recall: 0.6019 - auc: 0.9242 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 25.0000 - tn: 181937.0000 - fn: 129.0000 - accuracy: 0.9992 - precision: 0.8810 - recall: 0.5892 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 181.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8380 - recall: 0.5764 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8537 - recall: 0.5573 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 180.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8491 - recall: 0.5732 - auc: 0.9320 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8468 - recall: 0.5987 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 172.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8230 - recall: 0.5478 - auc: 0.9129 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8294 - recall: 0.5573 - auc: 0.9368 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9240 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 178.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8683 - recall: 0.5669 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 198.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8534 - recall: 0.6306 - auc: 0.9256 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 5.0000 - val_tn: 45477.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9315 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 49/100
85/90 [===========================>..] - ETA: 0s - loss: 0.0043 - tp: 162.0000 - fp: 29.0000 - tn: 173750.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8482 - recall: 0.5382 - auc: 0.9157Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 171.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5446 - auc: 0.9191 - val_loss: 0.0024 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 00049: early stopping

Check training history

In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.

In addition, you can produce these plots for any of the metrics you created above; false negatives are included as an example (a sketch follows the figure below).

 def plot_metrics(history):
  metrics =  ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()

 
 plot_metrics(baseline_history)
 

[Figure: training history plots (loss, AUC, precision, recall) for the baseline model]
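
For instance, a minimal sketch of charting false negatives, using the 'fn' metric the model already tracks (this block is an addition):

 plt.plot(baseline_history.epoch, baseline_history.history['fn'],
         color=colors[0], label='Train')
plt.plot(baseline_history.epoch, baseline_history.history['val_fn'],
         color=colors[0], linestyle='--', label='Val')
plt.xlabel('Epoch')
plt.ylabel('False negatives')
plt.legend()
 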

Evaluate metrics

You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label.

 train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
 
 def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
 

Evaluate your model on the test dataset and display the results for the metrics you created above.

 baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
 
loss :  0.002310588490217924
tp :  69.0
fp :  5.0
tn :  56866.0
fn :  22.0
accuracy :  0.9995260238647461
precision :  0.9324324131011963
recall :  0.7582417726516724
auc :  0.9557874202728271

Legitimate Transactions Detected (True Negatives):  56866
Legitimate Transactions Incorrectly Detected (False Positives):  5
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  69
Total Fraudulent Transactions:  91

[Figure: confusion matrix for the baseline model at threshold 0.50]

If the model had predicted everything perfectly, this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer asking them to verify their card activity. One way to explore the trade-off is sketched below.
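
Since plot_cm accepts a threshold parameter p, you can re-plot the matrix at a lower threshold, trading false negatives for false positives (a sketch added here, not part of the original page):

 # A lower threshold misses fewer frauds (FN) at the cost of
# flagging more legitimate transactions (FP).
plot_cm(test_labels, test_predictions_baseline, p=0.1)
 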

Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.

 def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
 
 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa50c5adef0>

[Figure: ROC curves for the baseline model, train and test]

It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.
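
Because precision and recall pull against each other as the threshold moves, it can also help to plot one against the other. A minimal sketch using sklearn's precision_recall_curve (an addition to this page):

 from sklearn.metrics import precision_recall_curve

# Precision and recall across all decision thresholds.
precision_b, recall_b, _ = precision_recall_curve(
    test_labels, test_predictions_baseline)

plt.plot(recall_b, precision_b, color=colors[0], label='Test Baseline')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend(loc='lower left')
 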

Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class.

 # Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
 
Weight for class 0: 0.50
Weight for class 1: 289.44
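
As a cross-check (an aside, assuming scikit-learn's class_weight utility; not part of the original page), the same 'balanced' weighting can be computed with sklearn. Note that it is computed on the training split rather than the full dataset, so the values will differ slightly:

 from sklearn.utils.class_weight import compute_class_weight

# 'balanced' uses total / (n_classes * count_per_class), the same
# math as above, applied to the training labels.
weights = compute_class_weight('balanced',
                               classes=np.array([0, 1]), y=train_labels)
print('sklearn weight for class 0: {:.2f}'.format(weights[0]))
print('sklearn weight for class 1: {:.2f}'.format(weights[1]))
 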

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

 weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
 
Epoch 1/100
90/90 [==============================] - 1s 15ms/step - loss: 2.5149 - tp: 105.0000 - fp: 66.0000 - tn: 238767.0000 - fn: 300.0000 - accuracy: 0.9985 - precision: 0.6140 - recall: 0.2593 - auc: 0.7803 - val_loss: 0.0067 - val_tp: 25.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 62.0000 - val_accuracy: 0.9985 - val_precision: 0.8065 - val_recall: 0.2874 - val_auc: 0.9211
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 1.2482 - tp: 145.0000 - fp: 124.0000 - tn: 181838.0000 - fn: 169.0000 - accuracy: 0.9984 - precision: 0.5390 - recall: 0.4618 - auc: 0.8560 - val_loss: 0.0062 - val_tp: 68.0000 - val_fp: 12.0000 - val_tn: 45470.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8500 - val_recall: 0.7816 - val_auc: 0.9408
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.8972 - tp: 177.0000 - fp: 237.0000 - tn: 181725.0000 - fn: 137.0000 - accuracy: 0.9979 - precision: 0.4275 - recall: 0.5637 - auc: 0.8876 - val_loss: 0.0079 - val_tp: 73.0000 - val_fp: 16.0000 - val_tn: 45466.0000 - val_fn: 14.0000 - val_accuracy: 0.9993 - val_precision: 0.8202 - val_recall: 0.8391 - val_auc: 0.9518
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6983 - tp: 210.0000 - fp: 387.0000 - tn: 181575.0000 - fn: 104.0000 - accuracy: 0.9973 - precision: 0.3518 - recall: 0.6688 - auc: 0.9028 - val_loss: 0.0098 - val_tp: 74.0000 - val_fp: 19.0000 - val_tn: 45463.0000 - val_fn: 13.0000 - val_accuracy: 0.9993 - val_precision: 0.7957 - val_recall: 0.8506 - val_auc: 0.9600
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6417 - tp: 220.0000 - fp: 583.0000 - tn: 181379.0000 - fn: 94.0000 - accuracy: 0.9963 - precision: 0.2740 - recall: 0.7006 - auc: 0.9084 - val_loss: 0.0119 - val_tp: 74.0000 - val_fp: 25.0000 - val_tn: 45457.0000 - val_fn: 13.0000 - val_accuracy: 0.9992 - val_precision: 0.7475 - val_recall: 0.8506 - val_auc: 0.9777
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5846 - tp: 232.0000 - fp: 977.0000 - tn: 180985.0000 - fn: 82.0000 - accuracy: 0.9942 - precision: 0.1919 - recall: 0.7389 - auc: 0.9048 - val_loss: 0.0148 - val_tp: 74.0000 - val_fp: 34.0000 - val_tn: 45448.0000 - val_fn: 13.0000 - val_accuracy: 0.9990 - val_precision: 0.6852 - val_recall: 0.8506 - val_auc: 0.9802
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5404 - tp: 234.0000 - fp: 1464.0000 - tn: 180498.0000 - fn: 80.0000 - accuracy: 0.9915 - precision: 0.1378 - recall: 0.7452 - auc: 0.9190 - val_loss: 0.0183 - val_tp: 74.0000 - val_fp: 50.0000 - val_tn: 45432.0000 - val_fn: 13.0000 - val_accuracy: 0.9986 - val_precision: 0.5968 - val_recall: 0.8506 - val_auc: 0.9823
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4714 - tp: 241.0000 - fp: 1862.0000 - tn: 180100.0000 - fn: 73.0000 - accuracy: 0.9894 - precision: 0.1146 - recall: 0.7675 - auc: 0.9252 - val_loss: 0.0225 - val_tp: 76.0000 - val_fp: 84.0000 - val_tn: 45398.0000 - val_fn: 11.0000 - val_accuracy: 0.9979 - val_precision: 0.4750 - val_recall: 0.8736 - val_auc: 0.9851
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4329 - tp: 247.0000 - fp: 2508.0000 - tn: 179454.0000 - fn: 67.0000 - accuracy: 0.9859 - precision: 0.0897 - recall: 0.7866 - auc: 0.9345 - val_loss: 0.0282 - val_tp: 76.0000 - val_fp: 170.0000 - val_tn: 45312.0000 - val_fn: 11.0000 - val_accuracy: 0.9960 - val_precision: 0.3089 - val_recall: 0.8736 - val_auc: 0.9873
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4467 - tp: 249.0000 - fp: 3175.0000 - tn: 178787.0000 - fn: 65.0000 - accuracy: 0.9822 - precision: 0.0727 - recall: 0.7930 - auc: 0.9210 - val_loss: 0.0341 - val_tp: 78.0000 - val_fp: 282.0000 - val_tn: 45200.0000 - val_fn: 9.0000 - val_accuracy: 0.9936 - val_precision: 0.2167 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3947 - tp: 260.0000 - fp: 3569.0000 - tn: 178393.0000 - fn: 54.0000 - accuracy: 0.9801 - precision: 0.0679 - recall: 0.8280 - auc: 0.9290 - val_loss: 0.0394 - val_tp: 78.0000 - val_fp: 346.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9922 - val_precision: 0.1840 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3694 - tp: 257.0000 - fp: 4294.0000 - tn: 177668.0000 - fn: 57.0000 - accuracy: 0.9761 - precision: 0.0565 - recall: 0.8185 - auc: 0.9418 - val_loss: 0.0473 - val_tp: 78.0000 - val_fp: 504.0000 - val_tn: 44978.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1340 - val_recall: 0.8966 - val_auc: 0.9879
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3479 - tp: 262.0000 - fp: 4886.0000 - tn: 177076.0000 - fn: 52.0000 - accuracy: 0.9729 - precision: 0.0509 - recall: 0.8344 - auc: 0.9403 - val_loss: 0.0539 - val_tp: 78.0000 - val_fp: 586.0000 - val_tn: 44896.0000 - val_fn: 9.0000 - val_accuracy: 0.9869 - val_precision: 0.1175 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3653 - tp: 263.0000 - fp: 5360.0000 - tn: 176602.0000 - fn: 51.0000 - accuracy: 0.9703 - precision: 0.0468 - recall: 0.8376 - auc: 0.9370 - val_loss: 0.0610 - val_tp: 78.0000 - val_fp: 664.0000 - val_tn: 44818.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.1051 - val_recall: 0.8966 - val_auc: 0.9876
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3673 - tp: 262.0000 - fp: 5820.0000 - tn: 176142.0000 - fn: 52.0000 - accuracy: 0.9678 - precision: 0.0431 - recall: 0.8344 - auc: 0.9316 - val_loss: 0.0658 - val_tp: 78.0000 - val_fp: 715.0000 - val_tn: 44767.0000 - val_fn: 9.0000 - val_accuracy: 0.9841 - val_precision: 0.0984 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3228 - tp: 262.0000 - fp: 6230.0000 - tn: 175732.0000 - fn: 52.0000 - accuracy: 0.9655 - precision: 0.0404 - recall: 0.8344 - auc: 0.9445 - val_loss: 0.0716 - val_tp: 79.0000 - val_fp: 805.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0894 - val_recall: 0.9080 - val_auc: 0.9877
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3299 - tp: 268.0000 - fp: 6572.0000 - tn: 175390.0000 - fn: 46.0000 - accuracy: 0.9637 - precision: 0.0392 - recall: 0.8535 - auc: 0.9423 - val_loss: 0.0757 - val_tp: 81.0000 - val_fp: 846.0000 - val_tn: 44636.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0874 - val_recall: 0.9310 - val_auc: 0.9878
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2522 - tp: 276.0000 - fp: 6934.0000 - tn: 175028.0000 - fn: 38.0000 - accuracy: 0.9618 - precision: 0.0383 - recall: 0.8790 - auc: 0.9610 - val_loss: 0.0779 - val_tp: 81.0000 - val_fp: 874.0000 - val_tn: 44608.0000 - val_fn: 6.0000 - val_accuracy: 0.9807 - val_precision: 0.0848 - val_recall: 0.9310 - val_auc: 0.9877
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3607 - tp: 264.0000 - fp: 6790.0000 - tn: 175172.0000 - fn: 50.0000 - accuracy: 0.9625 - precision: 0.0374 - recall: 0.8408 - auc: 0.9303 - val_loss: 0.0781 - val_tp: 81.0000 - val_fp: 865.0000 - val_tn: 44617.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0856 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 20/100
89/90 [============================>.] - ETA: 0s - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175189.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175193.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488 - val_loss: 0.0780 - val_tp: 81.0000 - val_fp: 853.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9811 - val_precision: 0.0867 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 00020: early stopping

Check training history

 plot_metrics(weighted_history)
 

[Figure: training history plots for the class-weighted model]

Evaluate metrics

 train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
 
 weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
 
loss :  0.03226418048143387
tp :  82.0
fp :  352.0
tn :  56519.0
fn :  9.0
accuracy :  0.993662416934967
precision :  0.18894009292125702
recall :  0.901098906993866
auc :  0.9671803712844849

Legitimate Transactions Detected (True Negatives):  56519
Legitimate Transactions Incorrectly Detected (False Positives):  352
Fraudulent Transactions Missed (False Negatives):  9
Fraudulent Transactions Detected (True Positives):  82
Total Fraudulent Transactions:  91

[Figure: confusion matrix for the class-weighted model at threshold 0.50]

Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower precision, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.

Plot the ROC

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa54c0729e8>

[Figure: ROC curves for the baseline and class-weighted models]

Oversampling

Oversample the minority class

A related approach would be to resample the dataset by oversampling the minority class.

 pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
 

Using NumPy

You can balance the dataset manually by choosing the right number of random indices from the positive examples:

 ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
 
(181962, 29)
 resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
 
(363924, 29)
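
A quick check (added here) confirms that the resampled labels are now balanced, with len(neg_features) examples of each class:

 print(np.bincount(resampled_labels.astype(np.int64)))
 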

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.

 BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
 

Each dataset provides (feature, label) pairs:

 for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
 
Features:
 [ 0.23104754  0.83661044 -0.31875356  1.9796369   1.28403692  0.07389102
  1.03350673 -0.11568355 -1.54396817  0.88004244 -1.66944551 -0.24324391
  0.45900013  0.14583622 -2.06637388  0.42470592 -0.94489216 -0.83112221
 -1.83416278 -0.34138858  0.14130878  0.51019975  0.08224586  0.6642136
 -1.39031637 -0.42194185  0.22525572  0.28277796 -4.86369823]

Label:  1

Merge the two together using experimental.sample_from_datasets:

 resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
 
 for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
 
0.49609375

To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:

 resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
 
278.0

Train on the oversampled data

Now try training the model with the resampled dataset instead of using class weights to see how these methods compare.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
 
Epoch 1/100
278/278 [==============================] - 6s 23ms/step - loss: 0.4356 - tp: 223484.0000 - fp: 51288.0000 - tn: 290777.0000 - fn: 60757.0000 - accuracy: 0.8211 - precision: 0.8133 - recall: 0.7862 - auc: 0.8933 - val_loss: 0.2172 - val_tp: 79.0000 - val_fp: 1076.0000 - val_tn: 44406.0000 - val_fn: 8.0000 - val_accuracy: 0.9762 - val_precision: 0.0684 - val_recall: 0.9080 - val_auc: 0.9792
Epoch 2/100
278/278 [==============================] - 6s 20ms/step - loss: 0.2177 - tp: 246785.0000 - fp: 12557.0000 - tn: 271871.0000 - fn: 38131.0000 - accuracy: 0.9110 - precision: 0.9516 - recall: 0.8662 - auc: 0.9686 - val_loss: 0.1226 - val_tp: 80.0000 - val_fp: 951.0000 - val_tn: 44531.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0776 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1751 - tp: 250631.0000 - fp: 9797.0000 - tn: 275174.0000 - fn: 33742.0000 - accuracy: 0.9235 - precision: 0.9624 - recall: 0.8813 - auc: 0.9810 - val_loss: 0.0940 - val_tp: 82.0000 - val_fp: 966.0000 - val_tn: 44516.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0782 - val_recall: 0.9425 - val_auc: 0.9836
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1532 - tp: 254169.0000 - fp: 9171.0000 - tn: 275694.0000 - fn: 30310.0000 - accuracy: 0.9307 - precision: 0.9652 - recall: 0.8935 - auc: 0.9861 - val_loss: 0.0802 - val_tp: 82.0000 - val_fp: 918.0000 - val_tn: 44564.0000 - val_fn: 5.0000 - val_accuracy: 0.9797 - val_precision: 0.0820 - val_recall: 0.9425 - val_auc: 0.9847
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1372 - tp: 257034.0000 - fp: 9061.0000 - tn: 275758.0000 - fn: 27491.0000 - accuracy: 0.9358 - precision: 0.9659 - recall: 0.9034 - auc: 0.9892 - val_loss: 0.0720 - val_tp: 82.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0827 - val_recall: 0.9425 - val_auc: 0.9854
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1260 - tp: 258997.0000 - fp: 9079.0000 - tn: 275819.0000 - fn: 25449.0000 - accuracy: 0.9394 - precision: 0.9661 - recall: 0.9105 - auc: 0.9911 - val_loss: 0.0666 - val_tp: 81.0000 - val_fp: 915.0000 - val_tn: 44567.0000 - val_fn: 6.0000 - val_accuracy: 0.9798 - val_precision: 0.0813 - val_recall: 0.9310 - val_auc: 0.9856
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1167 - tp: 261100.0000 - fp: 9112.0000 - tn: 276180.0000 - fn: 22952.0000 - accuracy: 0.9437 - precision: 0.9663 - recall: 0.9192 - auc: 0.9925 - val_loss: 0.0623 - val_tp: 81.0000 - val_fp: 911.0000 - val_tn: 44571.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9858
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1082 - tp: 263945.0000 - fp: 9428.0000 - tn: 275276.0000 - fn: 20695.0000 - accuracy: 0.9471 - precision: 0.9655 - recall: 0.9273 - auc: 0.9937 - val_loss: 0.0587 - val_tp: 81.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9857
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1014 - tp: 268108.0000 - fp: 10376.0000 - tn: 274312.0000 - fn: 16548.0000 - accuracy: 0.9527 - precision: 0.9627 - recall: 0.9419 - auc: 0.9944 - val_loss: 0.0543 - val_tp: 80.0000 - val_fp: 873.0000 - val_tn: 44609.0000 - val_fn: 7.0000 - val_accuracy: 0.9807 - val_precision: 0.0839 - val_recall: 0.9195 - val_auc: 0.9857
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0951 - tp: 277520.0000 - fp: 12692.0000 - tn: 271795.0000 - fn: 7337.0000 - accuracy: 0.9648 - precision: 0.9563 - recall: 0.9742 - auc: 0.9950 - val_loss: 0.0495 - val_tp: 79.0000 - val_fp: 829.0000 - val_tn: 44653.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0870 - val_recall: 0.9080 - val_auc: 0.9855
Epoch 11/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0895 - tp: 278865.0000 - fp: 12938.0000 - tn: 271719.0000 - fn: 5822.0000 - accuracy: 0.9670 - precision: 0.9557 - recall: 0.9795 - auc: 0.9955 - val_loss: 0.0450 - val_tp: 79.0000 - val_fp: 789.0000 - val_tn: 44693.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0910 - val_recall: 0.9080 - val_auc: 0.9859
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0842 - tp: 279845.0000 - fp: 13187.0000 - tn: 272121.0000 - fn: 4191.0000 - accuracy: 0.9695 - precision: 0.9550 - recall: 0.9852 - auc: 0.9960 - val_loss: 0.0410 - val_tp: 79.0000 - val_fp: 733.0000 - val_tn: 44749.0000 - val_fn: 8.0000 - val_accuracy: 0.9837 - val_precision: 0.0973 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 13/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0792 - tp: 281765.0000 - fp: 12977.0000 - tn: 271393.0000 - fn: 3209.0000 - accuracy: 0.9716 - precision: 0.9560 - recall: 0.9887 - auc: 0.9963 - val_loss: 0.0389 - val_tp: 79.0000 - val_fp: 721.0000 - val_tn: 44761.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0988 - val_recall: 0.9080 - val_auc: 0.9814
Epoch 14/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0754 - tp: 281962.0000 - fp: 13026.0000 - tn: 272154.0000 - fn: 2202.0000 - accuracy: 0.9733 - precision: 0.9558 - recall: 0.9923 - auc: 0.9966 - val_loss: 0.0348 - val_tp: 79.0000 - val_fp: 646.0000 - val_tn: 44836.0000 - val_fn: 8.0000 - val_accuracy: 0.9856 - val_precision: 0.1090 - val_recall: 0.9080 - val_auc: 0.9763
Epoch 15/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0722 - tp: 283858.0000 - fp: 12932.0000 - tn: 271419.0000 - fn: 1135.0000 - accuracy: 0.9753 - precision: 0.9564 - recall: 0.9960 - auc: 0.9967 - val_loss: 0.0331 - val_tp: 79.0000 - val_fp: 640.0000 - val_tn: 44842.0000 - val_fn: 8.0000 - val_accuracy: 0.9858 - val_precision: 0.1099 - val_recall: 0.9080 - val_auc: 0.9714
Epoch 16/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0689 - tp: 283059.0000 - fp: 12757.0000 - tn: 273004.0000 - fn: 524.0000 - accuracy: 0.9767 - precision: 0.9569 - recall: 0.9982 - auc: 0.9970 - val_loss: 0.0308 - val_tp: 79.0000 - val_fp: 583.0000 - val_tn: 44899.0000 - val_fn: 8.0000 - val_accuracy: 0.9870 - val_precision: 0.1193 - val_recall: 0.9080 - val_auc: 0.9667
Epoch 17/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0661 - tp: 283879.0000 - fp: 12340.0000 - tn: 272779.0000 - fn: 346.0000 - accuracy: 0.9777 - precision: 0.9583 - recall: 0.9988 - auc: 0.9971 - val_loss: 0.0289 - val_tp: 79.0000 - val_fp: 542.0000 - val_tn: 44940.0000 - val_fn: 8.0000 - val_accuracy: 0.9879 - val_precision: 0.1272 - val_recall: 0.9080 - val_auc: 0.9618
Epoch 18/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0635 - tp: 284858.0000 - fp: 12157.0000 - tn: 272120.0000 - fn: 209.0000 - accuracy: 0.9783 - precision: 0.9591 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 79.0000 - val_fp: 511.0000 - val_tn: 44971.0000 - val_fn: 8.0000 - val_accuracy: 0.9886 - val_precision: 0.1339 - val_recall: 0.9080 - val_auc: 0.9621
Epoch 19/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0620 - tp: 284459.0000 - fp: 11978.0000 - tn: 272718.0000 - fn: 189.0000 - accuracy: 0.9786 - precision: 0.9596 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0261 - val_tp: 79.0000 - val_fp: 478.0000 - val_tn: 45004.0000 - val_fn: 8.0000 - val_accuracy: 0.9893 - val_precision: 0.1418 - val_recall: 0.9080 - val_auc: 0.9624
Epoch 20/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0600 - tp: 284950.0000 - fp: 11793.0000 - tn: 272572.0000 - fn: 29.0000 - accuracy: 0.9792 - precision: 0.9603 - recall: 0.9999 - auc: 0.9974 - val_loss: 0.0252 - val_tp: 79.0000 - val_fp: 463.0000 - val_tn: 45019.0000 - val_fn: 8.0000 - val_accuracy: 0.9897 - val_precision: 0.1458 - val_recall: 0.9080 - val_auc: 0.9626
Epoch 21/100
276/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 282210.0000 - fp: 11270.0000 - tn: 271768.0000 - fn: 0.0000e+00 - accuracy: 0.9801 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975Restoring model weights from the end of the best epoch.
278/278 [==============================] - 6s 22ms/step - loss: 0.0581 - tp: 284274.0000 - fp: 11360.0000 - tn: 273710.0000 - fn: 0.0000e+00 - accuracy: 0.9800 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975 - val_loss: 0.0241 - val_tp: 79.0000 - val_fp: 444.0000 - val_tn: 45038.0000 - val_fn: 8.0000 - val_accuracy: 0.9901 - val_precision: 0.1511 - val_recall: 0.9080 - val_auc: 0.9628
Epoch 00021: early stopping

If the training process considered the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.

But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they are shown in many different batches, each time with a small weight.

This smoother gradient signal makes it easier to train the model.
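
For reference, a balanced stream like resampled_ds can be built with tf.data by sampling from a positive-only and a negative-only dataset with equal probability. The following is a minimal sketch, assuming the pos_features, pos_labels, neg_features and neg_labels arrays and the BATCH_SIZE constant from earlier in the tutorial:

 BUFFER_SIZE = 100000

def make_ds(features, labels):
  # An endlessly repeating, shuffled stream of examples from one class.
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

# Draw from each class with probability 0.5, so each batch is roughly
# balanced and every positive example recurs across many batches.
resampled_ds = tf.data.experimental.sample_from_datasets(
    [pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
 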

Check the training history

Note that the metric distributions will be different here, because the training data has a totally different distribution from the validation and test data.

 plot_metrics(resampled_history)
 

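The plot_metrics helper was defined earlier in the tutorial; a minimal sketch of such a helper, assuming the plt and colors setup from the top of the notebook, could look like this:

 def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_", " ").capitalize()
    plt.subplot(2, 2, n + 1)
    # Solid line: training metric; dashed line: validation metric.
    plt.plot(history.epoch, history.history[metric],
             color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_' + metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.legend()
 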

Re-train

Because training is easier on the balanced data, the above training procedure may overfit quickly.

So break up the epochs to give the callbacks.EarlyStopping finer control over when to stop training.
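
The early_stopping callback passed to fit was defined earlier in the tutorial. As a point of reference, a configuration consistent with the logs in this section (best weights restored on stopping) might look like the sketch below; the exact monitor and patience values are assumptions:

 # A sketch of the `early_stopping` callback used in the `fit` calls.
# Monitoring 'val_auc' in 'max' mode and patience=10 are assumptions
# consistent with, but not confirmed by, the logs in this section.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
 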

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs: each "epoch" here is only 20 resampled
    # batches, so early stopping gets many more checkpoints to act on.
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=val_ds)
 
Epoch 1/1000
20/20 [==============================] - 1s 60ms/step - loss: 1.0656 - tp: 9507.0000 - fp: 7370.0000 - tn: 58667.0000 - fn: 10985.0000 - accuracy: 0.7879 - precision: 0.5633 - recall: 0.4639 - auc: 0.8255 - val_loss: 0.5792 - val_tp: 66.0000 - val_fp: 13452.0000 - val_tn: 32030.0000 - val_fn: 21.0000 - val_accuracy: 0.7043 - val_precision: 0.0049 - val_recall: 0.7586 - val_auc: 0.7866
Epoch 2/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.6996 - tp: 13383.0000 - fp: 7208.0000 - tn: 13397.0000 - fn: 6972.0000 - accuracy: 0.6538 - precision: 0.6499 - recall: 0.6575 - auc: 0.7027 - val_loss: 0.5702 - val_tp: 76.0000 - val_fp: 12408.0000 - val_tn: 33074.0000 - val_fn: 11.0000 - val_accuracy: 0.7275 - val_precision: 0.0061 - val_recall: 0.8736 - val_auc: 0.9076
Epoch 3/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.5532 - tp: 15127.0000 - fp: 6665.0000 - tn: 14055.0000 - fn: 5113.0000 - accuracy: 0.7125 - precision: 0.6942 - recall: 0.7474 - auc: 0.7952 - val_loss: 0.5335 - val_tp: 79.0000 - val_fp: 9006.0000 - val_tn: 36476.0000 - val_fn: 8.0000 - val_accuracy: 0.8022 - val_precision: 0.0087 - val_recall: 0.9080 - val_auc: 0.9408
Epoch 4/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4738 - tp: 16061.0000 - fp: 5669.0000 - tn: 14890.0000 - fn: 4340.0000 - accuracy: 0.7556 - precision: 0.7391 - recall: 0.7873 - auc: 0.8495 - val_loss: 0.4883 - val_tp: 78.0000 - val_fp: 5756.0000 - val_tn: 39726.0000 - val_fn: 9.0000 - val_accuracy: 0.8735 - val_precision: 0.0134 - val_recall: 0.8966 - val_auc: 0.9489
Epoch 5/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.4266 - tp: 16612.0000 - fp: 4719.0000 - tn: 15715.0000 - fn: 3914.0000 - accuracy: 0.7892 - precision: 0.7788 - recall: 0.8093 - auc: 0.8786 - val_loss: 0.4435 - val_tp: 78.0000 - val_fp: 3758.0000 - val_tn: 41724.0000 - val_fn: 9.0000 - val_accuracy: 0.9173 - val_precision: 0.0203 - val_recall: 0.8966 - val_auc: 0.9539
Epoch 6/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.3908 - tp: 16911.0000 - fp: 3861.0000 - tn: 16514.0000 - fn: 3674.0000 - accuracy: 0.8160 - precision: 0.8141 - recall: 0.8215 - auc: 0.8976 - val_loss: 0.4032 - val_tp: 79.0000 - val_fp: 2770.0000 - val_tn: 42712.0000 - val_fn: 8.0000 - val_accuracy: 0.9390 - val_precision: 0.0277 - val_recall: 0.9080 - val_auc: 0.9590
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3664 - tp: 17049.0000 - fp: 3209.0000 - tn: 17179.0000 - fn: 3523.0000 - accuracy: 0.8356 - precision: 0.8416 - recall: 0.8287 - auc: 0.9108 - val_loss: 0.3682 - val_tp: 79.0000 - val_fp: 2119.0000 - val_tn: 43363.0000 - val_fn: 8.0000 - val_accuracy: 0.9533 - val_precision: 0.0359 - val_recall: 0.9080 - val_auc: 0.9634
Epoch 8/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3467 - tp: 17100.0000 - fp: 2699.0000 - tn: 17686.0000 - fn: 3475.0000 - accuracy: 0.8493 - precision: 0.8637 - recall: 0.8311 - auc: 0.9193 - val_loss: 0.3373 - val_tp: 79.0000 - val_fp: 1753.0000 - val_tn: 43729.0000 - val_fn: 8.0000 - val_accuracy: 0.9614 - val_precision: 0.0431 - val_recall: 0.9080 - val_auc: 0.9675
Epoch 9/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.3285 - tp: 17043.0000 - fp: 2345.0000 - tn: 18228.0000 - fn: 3344.0000 - accuracy: 0.8611 - precision: 0.8790 - recall: 0.8360 - auc: 0.9271 - val_loss: 0.3104 - val_tp: 79.0000 - val_fp: 1495.0000 - val_tn: 43987.0000 - val_fn: 8.0000 - val_accuracy: 0.9670 - val_precision: 0.0502 - val_recall: 0.9080 - val_auc: 0.9702
Epoch 10/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3094 - tp: 17322.0000 - fp: 2012.0000 - tn: 18405.0000 - fn: 3221.0000 - accuracy: 0.8722 - precision: 0.8959 - recall: 0.8432 - auc: 0.9361 - val_loss: 0.2865 - val_tp: 79.0000 - val_fp: 1332.0000 - val_tn: 44150.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0560 - val_recall: 0.9080 - val_auc: 0.9721
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2962 - tp: 17184.0000 - fp: 1757.0000 - tn: 18853.0000 - fn: 3166.0000 - accuracy: 0.8798 - precision: 0.9072 - recall: 0.8444 - auc: 0.9406 - val_loss: 0.2654 - val_tp: 79.0000 - val_fp: 1228.0000 - val_tn: 44254.0000 - val_fn: 8.0000 - val_accuracy: 0.9729 - val_precision: 0.0604 - val_recall: 0.9080 - val_auc: 0.9739
Epoch 12/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2835 - tp: 17373.0000 - fp: 1543.0000 - tn: 18909.0000 - fn: 3135.0000 - accuracy: 0.8858 - precision: 0.9184 - recall: 0.8471 - auc: 0.9458 - val_loss: 0.2469 - val_tp: 79.0000 - val_fp: 1155.0000 - val_tn: 44327.0000 - val_fn: 8.0000 - val_accuracy: 0.9745 - val_precision: 0.0640 - val_recall: 0.9080 - val_auc: 0.9759
Epoch 13/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2710 - tp: 17386.0000 - fp: 1395.0000 - tn: 19124.0000 - fn: 3055.0000 - accuracy: 0.8914 - precision: 0.9257 - recall: 0.8505 - auc: 0.9502 - val_loss: 0.2302 - val_tp: 79.0000 - val_fp: 1092.0000 - val_tn: 44390.0000 - val_fn: 8.0000 - val_accuracy: 0.9759 - val_precision: 0.0675 - val_recall: 0.9080 - val_auc: 0.9782
Epoch 14/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2618 - tp: 17336.0000 - fp: 1343.0000 - tn: 19296.0000 - fn: 2985.0000 - accuracy: 0.8943 - precision: 0.9281 - recall: 0.8531 - auc: 0.9541 - val_loss: 0.2156 - val_tp: 79.0000 - val_fp: 1053.0000 - val_tn: 44429.0000 - val_fn: 8.0000 - val_accuracy: 0.9767 - val_precision: 0.0698 - val_recall: 0.9080 - val_auc: 0.9797
Epoch 15/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2529 - tp: 17466.0000 - fp: 1154.0000 - tn: 19366.0000 - fn: 2974.0000 - accuracy: 0.8992 - precision: 0.9380 - recall: 0.8545 - auc: 0.9574 - val_loss: 0.2026 - val_tp: 79.0000 - val_fp: 1029.0000 - val_tn: 44453.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0713 - val_recall: 0.9080 - val_auc: 0.9806
Epoch 16/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2456 - tp: 17579.0000 - fp: 1075.0000 - tn: 19322.0000 - fn: 2984.0000 - accuracy: 0.9009 - precision: 0.9424 - recall: 0.8549 - auc: 0.9590 - val_loss: 0.1923 - val_tp: 79.0000 - val_fp: 1017.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0721 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 17/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2382 - tp: 17573.0000 - fp: 982.0000 - tn: 19540.0000 - fn: 2865.0000 - accuracy: 0.9061 - precision: 0.9471 - recall: 0.8598 - auc: 0.9620 - val_loss: 0.1828 - val_tp: 79.0000 - val_fp: 1005.0000 - val_tn: 44477.0000 - val_fn: 8.0000 - val_accuracy: 0.9778 - val_precision: 0.0729 - val_recall: 0.9080 - val_auc: 0.9819
Epoch 18/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2307 - tp: 17711.0000 - fp: 966.0000 - tn: 19448.0000 - fn: 2835.0000 - accuracy: 0.9072 - precision: 0.9483 - recall: 0.8620 - auc: 0.9644 - val_loss: 0.1736 - val_tp: 80.0000 - val_fp: 990.0000 - val_tn: 44492.0000 - val_fn: 7.0000 - val_accuracy: 0.9781 - val_precision: 0.0748 - val_recall: 0.9195 - val_auc: 0.9825
Epoch 19/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2280 - tp: 17732.0000 - fp: 952.0000 - tn: 19442.0000 - fn: 2834.0000 - accuracy: 0.9076 - precision: 0.9490 - recall: 0.8622 - auc: 0.9653 - val_loss: 0.1660 - val_tp: 80.0000 - val_fp: 974.0000 - val_tn: 44508.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0759 - val_recall: 0.9195 - val_auc: 0.9826
Epoch 20/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2224 - tp: 17725.0000 - fp: 939.0000 - tn: 19538.0000 - fn: 2758.0000 - accuracy: 0.9097 - precision: 0.9497 - recall: 0.8654 - auc: 0.9667 - val_loss: 0.1591 - val_tp: 80.0000 - val_fp: 962.0000 - val_tn: 44520.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0768 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 21/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2168 - tp: 17757.0000 - fp: 826.0000 - tn: 19618.0000 - fn: 2759.0000 - accuracy: 0.9125 - precision: 0.9556 - recall: 0.8655 - auc: 0.9689 - val_loss: 0.1531 - val_tp: 80.0000 - val_fp: 967.0000 - val_tn: 44515.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0764 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 22/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2112 - tp: 17833.0000 - fp: 883.0000 - tn: 19522.0000 - fn: 2722.0000 - accuracy: 0.9120 - precision: 0.9528 - recall: 0.8676 - auc: 0.9703 - val_loss: 0.1479 - val_tp: 80.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9832
Epoch 23/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2058 - tp: 17865.0000 - fp: 835.0000 - tn: 19580.0000 - fn: 2680.0000 - accuracy: 0.9142 - precision: 0.9553 - recall: 0.8696 - auc: 0.9723 - val_loss: 0.1427 - val_tp: 80.0000 - val_fp: 977.0000 - val_tn: 44505.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0757 - val_recall: 0.9195 - val_auc: 0.9834
Epoch 24/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2053 - tp: 17856.0000 - fp: 802.0000 - tn: 19599.0000 - fn: 2703.0000 - accuracy: 0.9144 - precision: 0.9570 - recall: 0.8685 - auc: 0.9727 - val_loss: 0.1375 - val_tp: 80.0000 - val_fp: 969.0000 - val_tn: 44513.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9833
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2004 - tp: 17854.0000 - fp: 809.0000 - tn: 19690.0000 - fn: 2607.0000 - accuracy: 0.9166 - precision: 0.9567 - recall: 0.8726 - auc: 0.9740 - val_loss: 0.1331 - val_tp: 80.0000 - val_fp: 976.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9837
Epoch 26/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1991 - tp: 17857.0000 - fp: 793.0000 - tn: 19690.0000 - fn: 2620.0000 - accuracy: 0.9167 - precision: 0.9575 - recall: 0.8721 - auc: 0.9747 - val_loss: 0.1291 - val_tp: 80.0000 - val_fp: 968.0000 - val_tn: 44514.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 27/1000
20/20 [==============================] - 1s 40ms/step - loss: 0.1929 - tp: 17836.0000 - fp: 750.0000 - tn: 19833.0000 - fn: 2541.0000 - accuracy: 0.9197 - precision: 0.9596 - recall: 0.8753 - auc: 0.9760 - val_loss: 0.1252 - val_tp: 80.0000 - val_fp: 960.0000 - val_tn: 44522.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0769 - val_recall: 0.9195 - val_auc: 0.9839
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1935 - tp: 17776.0000 - fp: 753.0000 - tn: 19827.0000 - fn: 2604.0000 - accuracy: 0.9180 - precision: 0.9594 - recall: 0.8722 - auc: 0.9763 - val_loss: 0.1215 - val_tp: 80.0000 - val_fp: 946.0000 - val_tn: 44536.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0780 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 29/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1892 - tp: 17877.0000 - fp: 746.0000 - tn: 19791.0000 - fn: 2546.0000 - accuracy: 0.9196 - precision: 0.9599 - recall: 0.8753 - auc: 0.9773 - val_loss: 0.1183 - val_tp: 80.0000 - val_fp: 944.0000 - val_tn: 44538.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0781 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 30/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1855 - tp: 18053.0000 - fp: 746.0000 - tn: 19673.0000 - fn: 2488.0000 - accuracy: 0.9210 - precision: 0.9603 - recall: 0.8789 - auc: 0.9779 - val_loss: 0.1157 - val_tp: 80.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0777 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 31/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1843 - tp: 18042.0000 - fp: 723.0000 - tn: 19656.0000 - fn: 2539.0000 - accuracy: 0.9204 - precision: 0.9615 - recall: 0.8766 - auc: 0.9783 - val_loss: 0.1137 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 32/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1831 - tp: 17974.0000 - fp: 743.0000 - tn: 19741.0000 - fn: 2502.0000 - accuracy: 0.9208 - precision: 0.9603 - recall: 0.8778 - auc: 0.9789 - val_loss: 0.1112 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 33/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1805 - tp: 18172.0000 - fp: 775.0000 - tn: 19591.0000 - fn: 2422.0000 - accuracy: 0.9219 - precision: 0.9591 - recall: 0.8824 - auc: 0.9796 - val_loss: 0.1088 - val_tp: 81.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 6.0000 - val_accuracy: 0.9789 - val_precision: 0.0781 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 34/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1749 - tp: 18125.0000 - fp: 715.0000 - tn: 19698.0000 - fn: 2422.0000 - accuracy: 0.9234 - precision: 0.9620 - recall: 0.8821 - auc: 0.9812 - val_loss: 0.1068 - val_tp: 81.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0775 - val_recall: 0.9310 - val_auc: 0.9836
Epoch 35/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1769 - tp: 18135.0000 - fp: 715.0000 - tn: 19694.0000 - fn: 2416.0000 - accuracy: 0.9236 - precision: 0.9621 - recall: 0.8824 - auc: 0.9809 - val_loss: 0.1048 - val_tp: 81.0000 - val_fp: 978.0000 - val_tn: 44504.0000 - val_fn: 6.0000 - val_accuracy: 0.9784 - val_precision: 0.0765 - val_recall: 0.9310 - val_auc: 0.9838
Epoch 36/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1739 - tp: 18006.0000 - fp: 704.0000 - tn: 19827.0000 - fn: 2423.0000 - accuracy: 0.9237 - precision: 0.9624 - recall: 0.8814 - auc: 0.9814 - val_loss: 0.1029 - val_tp: 81.0000 - val_fp: 986.0000 - val_tn: 44496.0000 - val_fn: 6.0000 - val_accuracy: 0.9782 - val_precision: 0.0759 - val_recall: 0.9310 - val_auc: 0.9839
Epoch 37/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1687 - tp: 18002.0000 - fp: 660.0000 - tn: 19879.0000 - fn: 2419.0000 - accuracy: 0.9248 - precision: 0.9646 - recall: 0.8815 - auc: 0.9826 - val_loss: 0.1011 - val_tp: 81.0000 - val_fp: 984.0000 - val_tn: 44498.0000 - val_fn: 6.0000 - val_accuracy: 0.9783 - val_precision: 0.0761 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 38/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1699 - tp: 17932.0000 - fp: 677.0000 - tn: 19986.0000 - fn: 2365.0000 - accuracy: 0.9257 - precision: 0.9636 - recall: 0.8835 - auc: 0.9825 - val_loss: 0.0995 - val_tp: 82.0000 - val_fp: 979.0000 - val_tn: 44503.0000 - val_fn: 5.0000 - val_accuracy: 0.9784 - val_precision: 0.0773 - val_recall: 0.9425 - val_auc: 0.9842
Epoch 39/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1676 - tp: 18086.0000 - fp: 736.0000 - tn: 19780.0000 - fn: 2358.0000 - accuracy: 0.9245 - precision: 0.9609 - recall: 0.8847 - auc: 0.9826 - val_loss: 0.0980 - val_tp: 82.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 5.0000 - val_accuracy: 0.9785 - val_precision: 0.0776 - val_recall: 0.9425 - val_auc: 0.9844
Epoch 40/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1670 - tp: 18066.0000 - fp: 685.0000 - tn: 19868.0000 - fn: 2341.0000 - accuracy: 0.9261 - precision: 0.9635 - recall: 0.8853 - auc: 0.9832 - val_loss: 0.0964 - val_tp: 82.0000 - val_fp: 965.0000 - val_tn: 44517.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0783 - val_recall: 0.9425 - val_auc: 0.9845
Epoch 41/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1640 - tp: 17950.0000 - fp: 645.0000 - tn: 19995.0000 - fn: 2370.0000 - accuracy: 0.9264 - precision: 0.9653 - recall: 0.8834 - auc: 0.9839 - val_loss: 0.0950 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9835
Epoch 42/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1641 - tp: 18083.0000 - fp: 665.0000 - tn: 19842.0000 - fn: 2370.0000 - accuracy: 0.9259 - precision: 0.9645 - recall: 0.8841 - auc: 0.9839 - val_loss: 0.0938 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 43/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1600 - tp: 18012.0000 - fp: 684.0000 - tn: 19970.0000 - fn: 2294.0000 - accuracy: 0.9273 - precision: 0.9634 - recall: 0.8870 - auc: 0.9845 - val_loss: 0.0925 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 44/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1597 - tp: 18346.0000 - fp: 657.0000 - tn: 19657.0000 - fn: 2300.0000 - accuracy: 0.9278 - precision: 0.9654 - recall: 0.8886 - auc: 0.9847 - val_loss: 0.0919 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9838
Epoch 45/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1607 - tp: 18109.0000 - fp: 726.0000 - tn: 19836.0000 - fn: 2289.0000 - accuracy: 0.9264 - precision: 0.9615 - recall: 0.8878 - auc: 0.9846 - val_loss: 0.0908 - val_tp: 82.0000 - val_fp: 948.0000 - val_tn: 44534.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0796 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 46/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1581 - tp: 18192.0000 - fp: 650.0000 - tn: 19833.0000 - fn: 2285.0000 - accuracy: 0.9283 - precision: 0.9655 - recall: 0.8884 - auc: 0.9849 - val_loss: 0.0902 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 47/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1579 - tp: 18301.0000 - fp: 676.0000 - tn: 19760.0000 - fn: 2223.0000 - accuracy: 0.9292 - precision: 0.9644 - recall: 0.8917 - auc: 0.9853 - val_loss: 0.0892 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 48/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1503 - tp: 18172.0000 - fp: 593.0000 - tn: 19959.0000 - fn: 2236.0000 - accuracy: 0.9309 - precision: 0.9684 - recall: 0.8904 - auc: 0.9867 - val_loss: 0.0887 - val_tp: 82.0000 - val_fp: 970.0000 - val_tn: 44512.0000 - val_fn: 5.0000 - val_accuracy: 0.9786 - val_precision: 0.0779 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 49/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1572 - tp: 18217.0000 - fp: 750.0000 - tn: 19709.0000 - fn: 2284.0000 - accuracy: 0.9259 - precision: 0.9605 - recall: 0.8886 - auc: 0.9852 - val_loss: 0.0876 - val_tp: 82.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0784 - val_recall: 0.9425 - val_auc: 0.9841
Epoch 50/1000
20/20 [==============================] - ETA: 0s - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860Restoring model weights from the end of the best epoch.
20/20 [==============================] - 0s 23ms/step - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860 - val_loss: 0.0860 - val_tp: 82.0000 - val_fp: 941.0000 - val_tn: 44541.0000 - val_fn: 5.0000 - val_accuracy: 0.9792 - val_precision: 0.0802 - val_recall: 0.9425 - val_auc: 0.9843
Epoch 00050: early stopping

Re-check the training history

 plot_metrics(resampled_history)
 


Evaluate metrics

 train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
 
 resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
 
loss :  0.09607589244842529
tp :  84.0
fp :  1195.0
tn :  55676.0
fn :  7.0
accuracy :  0.9788982272148132
precision :  0.06567630916833878
recall :  0.9230769276618958
auc :  0.9697299599647522

Legitimate Transactions Detected (True Negatives):  55676
Legitimate Transactions Incorrectly Detected (False Positives):  1195
Fraudulent Transactions Missed (False Negatives):  7
Fraudulent Transactions Detected (True Positives):  84
Total Fraudulent Transactions:  91

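The plot_cm helper that produced the confusion-matrix summary above was also defined earlier in the tutorial. A minimal sketch, where the decision threshold p=0.5 is an illustrative default rather than the tutorial's exact code, might be:

 def plot_cm(labels, predictions, p=0.5):
  # Binarize the model's scores at threshold `p` before counting.
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
 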

Plot the ROC

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa4bc66c9b0>

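The plot_roc helper follows the same pattern; a minimal sketch built on sklearn.metrics.roc_curve, with illustrative axis limits that zoom in on the top-left corner, could be:

 from sklearn.metrics import roc_curve

def plot_roc(name, labels, predictions, **kwargs):
  # False-positive and true-positive rates at every score threshold.
  fp, tp, _ = roc_curve(labels, predictions)
  plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5, 20])
  plt.ylim([80, 100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
 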

Applying this tutorial to your problem

Imbalanced data classification is an inherently difficult task, since there are so few samples to learn from. You should always start with the data first and do your best to collect as many samples as possible, and give substantial thought to which features may be relevant, so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.