
Classification on imbalanced data


This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class vastly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions from 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Try common techniques for dealing with imbalanced data like:
    • Class weighting
    • Oversampling

Setup

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Data processing and exploration

Download the Kaggle Credit Card Fraud dataset

Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download CSVs into a Pandas DataFrame.

raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

This shows the small fraction of positive samples.

Clean, split and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics; however, the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.

# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and the standard deviation to 1.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input, and these are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes, the positive examples contain a much higher rate of extreme values.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")

[Figure: hexbin joint plot of V5 vs V6 for the positive examples ("Positive distribution")]

[Figure: hexbin joint plot of V5 vs V6 for the negative examples ("Negative distribution")]

Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

Understanding useful metrics

Notice that there are a few metrics defined above that can be computed by the model and that will be helpful when evaluating the performance.

  • False negatives and false positives are samples that were incorrectly classified
  • True negatives and true positives are samples that were correctly classified
  • Accuracy is the percentage of examples correctly classified: \(\frac{\text{true samples} }{\text{total samples} }\)
  • Precision is the percentage of predicted positives that were correctly classified: \(\frac{\text{true positives} }{\text{true positives + false positives} }\)
  • Recall is the percentage of actual positives that were correctly classified: \(\frac{\text{true positives} }{\text{true positives + false negatives} }\)
  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
  • AUPRC refers to the Area Under the Curve of the Precision-Recall curve. This metric computes precision-recall pairs for different probability thresholds.
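
As a quick sanity check, here is a minimal sketch of how these quantities relate, using the confusion-matrix counts that the baseline model produces on the test set later in this tutorial:

# Confusion-matrix counts copied from the baseline test results below.
tp, fp, tn, fn = 85, 5, 56845, 27

accuracy = (tp + tn) / (tp + fp + tn + fn)  # fraction of all examples classified correctly
precision = tp / (tp + fp)                  # fraction of predicted positives that are real
recall = tp / (tp + fn)                     # fraction of actual positives that are found

print('accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(
    accuracy, precision, recall))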


Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, they would likely have no fraudulent transactions to learn from.
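
As a rough back-of-the-envelope check of that claim, at the dataset's roughly 0.17% positive rate a batch of 2048 contains only a handful of positive examples on average:

# Expected number of positive samples per batch of 2048.
batch_size = 2048
pos_fraction = 492 / 284807       # positive examples / total examples, from above
print(batch_size * pos_fraction)  # ~3.5 positives per batch, on average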

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Test run the model:

model.predict(train_features[:10])
array([[0.7679975 ],
       [0.814587  ],
       [0.8609543 ],
       [0.7803842 ],
       [0.78393793],
       [0.8893498 ],
       [0.8417414 ],
       [0.82623696],
       [0.5085947 ],
       [0.79425156]], dtype=float32)

Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced, so set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

With the default bias initialization the loss should be about math.log(2) = 0.69314:

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.7370

The correct bias to set can be derived from:

\[ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) \]

\[ b_0 = -log_e(1/p_0 - 1) \]

\[ b_0 = log_e(pos/neg)\]

initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

Set that as the initial bias, and the model will give much more reasonable initial guesses.

It should be near: pos/total = 0.0018

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00629375],
       [0.00208666],
       [0.00152029],
       [0.00087535],
       [0.0018215 ],
       [0.00622139],
       [0.0009076 ],
       [0.00113686],
       [0.00362762],
       [0.00282635]], dtype=float32)

With this initialization the initial loss should be approximately:

\[ -p_0 log(p_0) - (1-p_0) log(1-p_0) = 0.01317 \]
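
As a quick check, a minimal sketch of that computation, reusing the pos and total counts from earlier:

# Binary cross-entropy of always predicting p_0 = pos/total.
p0 = pos / total
initial_loss = -p0 * np.log(p0) - (1 - p0) * np.log(1 - p0)
print(initial_loss)  # ~0.013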

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0122

This initial loss is about 50 times less than it would have been with naive initialization.

This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. This also makes it easier to read plots of the loss during training.

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training:

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

Confirm that the bias fix helps

Before moving on, confirm quickly that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

[Figure: training and validation loss, zero bias vs. careful bias initialization]

The above figure makes it clear: in terms of validation loss, on this problem, this careful initialization gives a clear advantage.

Train the model

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 14ms/step - loss: 0.0106 - tp: 91.0000 - fp: 47.0000 - tn: 227418.0000 - fn: 289.0000 - accuracy: 0.9985 - precision: 0.6594 - recall: 0.2395 - auc: 0.7959 - prc: 0.2832 - val_loss: 0.0062 - val_tp: 7.0000 - val_fp: 8.0000 - val_tn: 45488.0000 - val_fn: 66.0000 - val_accuracy: 0.9984 - val_precision: 0.4667 - val_recall: 0.0959 - val_auc: 0.9033 - val_prc: 0.5552
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0072 - tp: 92.0000 - fp: 23.0000 - tn: 181946.0000 - fn: 215.0000 - accuracy: 0.9987 - precision: 0.8000 - recall: 0.2997 - auc: 0.8687 - prc: 0.4823 - val_loss: 0.0048 - val_tp: 33.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 40.0000 - val_accuracy: 0.9989 - val_precision: 0.7674 - val_recall: 0.4521 - val_auc: 0.9380 - val_prc: 0.6383
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0059 - tp: 150.0000 - fp: 26.0000 - tn: 181943.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8523 - recall: 0.4886 - auc: 0.8907 - prc: 0.6127 - val_loss: 0.0044 - val_tp: 37.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 36.0000 - val_accuracy: 0.9990 - val_precision: 0.7872 - val_recall: 0.5068 - val_auc: 0.9381 - val_prc: 0.6596
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0055 - tp: 150.0000 - fp: 24.0000 - tn: 181945.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8621 - recall: 0.4886 - auc: 0.9035 - prc: 0.6322 - val_loss: 0.0041 - val_tp: 44.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8000 - val_recall: 0.6027 - val_auc: 0.9381 - val_prc: 0.6903
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 165.0000 - fp: 30.0000 - tn: 181939.0000 - fn: 142.0000 - accuracy: 0.9991 - precision: 0.8462 - recall: 0.5375 - auc: 0.9138 - prc: 0.6374 - val_loss: 0.0039 - val_tp: 44.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8148 - val_recall: 0.6027 - val_auc: 0.9381 - val_prc: 0.7000
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 167.0000 - fp: 27.0000 - tn: 181942.0000 - fn: 140.0000 - accuracy: 0.9991 - precision: 0.8608 - recall: 0.5440 - auc: 0.9106 - prc: 0.6583 - val_loss: 0.0037 - val_tp: 47.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 26.0000 - val_accuracy: 0.9992 - val_precision: 0.8103 - val_recall: 0.6438 - val_auc: 0.9449 - val_prc: 0.7127
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 168.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8571 - recall: 0.5472 - auc: 0.9107 - prc: 0.6644 - val_loss: 0.0036 - val_tp: 48.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 25.0000 - val_accuracy: 0.9992 - val_precision: 0.8136 - val_recall: 0.6575 - val_auc: 0.9449 - val_prc: 0.7384
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 164.0000 - fp: 26.0000 - tn: 181943.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8632 - recall: 0.5342 - auc: 0.9143 - prc: 0.6922 - val_loss: 0.0035 - val_tp: 50.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8197 - val_recall: 0.6849 - val_auc: 0.9449 - val_prc: 0.7377
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 168.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8571 - recall: 0.5472 - auc: 0.9159 - prc: 0.6942 - val_loss: 0.0035 - val_tp: 56.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 17.0000 - val_accuracy: 0.9994 - val_precision: 0.8358 - val_recall: 0.7671 - val_auc: 0.9449 - val_prc: 0.7251
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 183.0000 - fp: 36.0000 - tn: 181933.0000 - fn: 124.0000 - accuracy: 0.9991 - precision: 0.8356 - recall: 0.5961 - auc: 0.9144 - prc: 0.6821 - val_loss: 0.0033 - val_tp: 51.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8226 - val_recall: 0.6986 - val_auc: 0.9449 - val_prc: 0.7553
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 193.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 114.0000 - accuracy: 0.9992 - precision: 0.8616 - recall: 0.6287 - auc: 0.9110 - prc: 0.6922 - val_loss: 0.0032 - val_tp: 49.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_accuracy: 0.9992 - val_precision: 0.8167 - val_recall: 0.6712 - val_auc: 0.9449 - val_prc: 0.7665
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 169.0000 - fp: 37.0000 - tn: 181932.0000 - fn: 138.0000 - accuracy: 0.9990 - precision: 0.8204 - recall: 0.5505 - auc: 0.9079 - prc: 0.6502 - val_loss: 0.0031 - val_tp: 47.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 26.0000 - val_accuracy: 0.9992 - val_precision: 0.8103 - val_recall: 0.6438 - val_auc: 0.9449 - val_prc: 0.7827
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 179.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 128.0000 - accuracy: 0.9991 - precision: 0.8606 - recall: 0.5831 - auc: 0.9159 - prc: 0.6820 - val_loss: 0.0030 - val_tp: 51.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8226 - val_recall: 0.6986 - val_auc: 0.9449 - val_prc: 0.7799
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 175.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 132.0000 - accuracy: 0.9991 - precision: 0.8495 - recall: 0.5700 - auc: 0.9079 - prc: 0.6707 - val_loss: 0.0030 - val_tp: 53.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8281 - val_recall: 0.7260 - val_auc: 0.9450 - val_prc: 0.7857
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 179.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 128.0000 - accuracy: 0.9991 - precision: 0.8647 - recall: 0.5831 - auc: 0.9161 - prc: 0.7002 - val_loss: 0.0030 - val_tp: 55.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8333 - val_recall: 0.7534 - val_auc: 0.9449 - val_prc: 0.7878
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 178.0000 - fp: 30.0000 - tn: 181939.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8558 - recall: 0.5798 - auc: 0.9160 - prc: 0.6741 - val_loss: 0.0029 - val_tp: 54.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8308 - val_recall: 0.7397 - val_auc: 0.9450 - val_prc: 0.7935
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 181.0000 - fp: 32.0000 - tn: 181937.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8498 - recall: 0.5896 - auc: 0.9112 - prc: 0.6992 - val_loss: 0.0029 - val_tp: 54.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8308 - val_recall: 0.7397 - val_auc: 0.9450 - val_prc: 0.7950
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 178.0000 - fp: 23.0000 - tn: 181946.0000 - fn: 129.0000 - accuracy: 0.9992 - precision: 0.8856 - recall: 0.5798 - auc: 0.9177 - prc: 0.6889 - val_loss: 0.0028 - val_tp: 54.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8438 - val_recall: 0.7397 - val_auc: 0.9450 - val_prc: 0.8022
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 182.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 125.0000 - accuracy: 0.9992 - precision: 0.8626 - recall: 0.5928 - auc: 0.9111 - prc: 0.6901 - val_loss: 0.0028 - val_tp: 54.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8308 - val_recall: 0.7397 - val_auc: 0.9450 - val_prc: 0.8031
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 177.0000 - fp: 33.0000 - tn: 181936.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8429 - recall: 0.5765 - auc: 0.9209 - prc: 0.6837 - val_loss: 0.0027 - val_tp: 54.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8438 - val_recall: 0.7397 - val_auc: 0.9518 - val_prc: 0.8137
Epoch 21/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 184.0000 - fp: 33.0000 - tn: 181936.0000 - fn: 123.0000 - accuracy: 0.9991 - precision: 0.8479 - recall: 0.5993 - auc: 0.9241 - prc: 0.6907 - val_loss: 0.0027 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8462 - val_recall: 0.7534 - val_auc: 0.9518 - val_prc: 0.8144
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 177.0000 - fp: 32.0000 - tn: 181937.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8469 - recall: 0.5765 - auc: 0.9176 - prc: 0.6630 - val_loss: 0.0027 - val_tp: 54.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8438 - val_recall: 0.7397 - val_auc: 0.9518 - val_prc: 0.8202
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 187.0000 - fp: 23.0000 - tn: 181946.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8905 - recall: 0.6091 - auc: 0.9339 - prc: 0.7379 - val_loss: 0.0026 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8462 - val_recall: 0.7534 - val_auc: 0.9518 - val_prc: 0.8255
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 182.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 125.0000 - accuracy: 0.9992 - precision: 0.8626 - recall: 0.5928 - auc: 0.9209 - prc: 0.7163 - val_loss: 0.0026 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8529 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8223
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 189.0000 - fp: 32.0000 - tn: 181937.0000 - fn: 118.0000 - accuracy: 0.9992 - precision: 0.8552 - recall: 0.6156 - auc: 0.9225 - prc: 0.7105 - val_loss: 0.0026 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 15.0000 - val_accuracy: 0.9994 - val_precision: 0.8406 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8235
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 189.0000 - fp: 32.0000 - tn: 181937.0000 - fn: 118.0000 - accuracy: 0.9992 - precision: 0.8552 - recall: 0.6156 - auc: 0.9226 - prc: 0.7259 - val_loss: 0.0026 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8462 - val_recall: 0.7534 - val_auc: 0.9518 - val_prc: 0.8276
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 183.0000 - fp: 25.0000 - tn: 181944.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8798 - recall: 0.5961 - auc: 0.9193 - prc: 0.7203 - val_loss: 0.0026 - val_tp: 57.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 16.0000 - val_accuracy: 0.9994 - val_precision: 0.8507 - val_recall: 0.7808 - val_auc: 0.9518 - val_prc: 0.8291
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 185.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 122.0000 - accuracy: 0.9992 - precision: 0.8565 - recall: 0.6026 - auc: 0.9290 - prc: 0.7219 - val_loss: 0.0025 - val_tp: 57.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 16.0000 - val_accuracy: 0.9994 - val_precision: 0.8507 - val_recall: 0.7808 - val_auc: 0.9518 - val_prc: 0.8318
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 184.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 123.0000 - accuracy: 0.9992 - precision: 0.8679 - recall: 0.5993 - auc: 0.9356 - prc: 0.7277 - val_loss: 0.0027 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8285
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 183.0000 - fp: 27.0000 - tn: 181942.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8714 - recall: 0.5961 - auc: 0.9258 - prc: 0.7126 - val_loss: 0.0026 - val_tp: 59.0000 - val_fp: 11.0000 - val_tn: 45485.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8429 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8263
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 191.0000 - fp: 27.0000 - tn: 181942.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8761 - recall: 0.6221 - auc: 0.9242 - prc: 0.7246 - val_loss: 0.0026 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8300
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 180.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8654 - recall: 0.5863 - auc: 0.9226 - prc: 0.7279 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8324
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 183.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8632 - recall: 0.5961 - auc: 0.9191 - prc: 0.7035 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8357
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 191.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8604 - recall: 0.6221 - auc: 0.9306 - prc: 0.7164 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8351
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 187.0000 - fp: 30.0000 - tn: 181939.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8618 - recall: 0.6091 - auc: 0.9307 - prc: 0.7257 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45489.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8939 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8380
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 184.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 123.0000 - accuracy: 0.9992 - precision: 0.8679 - recall: 0.5993 - auc: 0.9159 - prc: 0.7001 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8378
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 195.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8744 - recall: 0.6352 - auc: 0.9388 - prc: 0.7372 - val_loss: 0.0025 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9016 - val_recall: 0.7534 - val_auc: 0.9518 - val_prc: 0.8402
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 182.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 125.0000 - accuracy: 0.9991 - precision: 0.8545 - recall: 0.5928 - auc: 0.9192 - prc: 0.7087 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8375
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 175.0000 - fp: 26.0000 - tn: 181943.0000 - fn: 132.0000 - accuracy: 0.9991 - precision: 0.8706 - recall: 0.5700 - auc: 0.9274 - prc: 0.7251 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8369
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 187.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8657 - recall: 0.6091 - auc: 0.9372 - prc: 0.7372 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8364
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 195.0000 - fp: 27.0000 - tn: 181942.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8784 - recall: 0.6352 - auc: 0.9291 - prc: 0.7347 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8406
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 183.0000 - fp: 30.0000 - tn: 181939.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8592 - recall: 0.5961 - auc: 0.9355 - prc: 0.7198 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8421
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 186.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 121.0000 - accuracy: 0.9992 - precision: 0.8651 - recall: 0.6059 - auc: 0.9193 - prc: 0.7033 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8402
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 200.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 107.0000 - accuracy: 0.9993 - precision: 0.8734 - recall: 0.6515 - auc: 0.9339 - prc: 0.7457 - val_loss: 0.0025 - val_tp: 57.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9048 - val_recall: 0.7808 - val_auc: 0.9518 - val_prc: 0.8418
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 180.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8612 - recall: 0.5863 - auc: 0.9241 - prc: 0.7130 - val_loss: 0.0025 - val_tp: 56.0000 - val_fp: 5.0000 - val_tn: 45491.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9180 - val_recall: 0.7671 - val_auc: 0.9518 - val_prc: 0.8411
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 174.0000 - fp: 33.0000 - tn: 181936.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8406 - recall: 0.5668 - auc: 0.9373 - prc: 0.7416 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8397
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 196.0000 - fp: 31.0000 - tn: 181938.0000 - fn: 111.0000 - accuracy: 0.9992 - precision: 0.8634 - recall: 0.6384 - auc: 0.9308 - prc: 0.7446 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8397
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 187.0000 - fp: 32.0000 - tn: 181937.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8539 - recall: 0.6091 - auc: 0.9357 - prc: 0.7426 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45489.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8939 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8350
Epoch 49/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 201.0000 - fp: 28.0000 - tn: 181941.0000 - fn: 106.0000 - accuracy: 0.9993 - precision: 0.8777 - recall: 0.6547 - auc: 0.9292 - prc: 0.7458 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8342
Epoch 50/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 198.0000 - fp: 27.0000 - tn: 181942.0000 - fn: 109.0000 - accuracy: 0.9993 - precision: 0.8800 - recall: 0.6450 - auc: 0.9177 - prc: 0.7286 - val_loss: 0.0025 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45490.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9062 - val_recall: 0.7945 - val_auc: 0.9518 - val_prc: 0.8370
Epoch 51/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 176.0000 - fp: 29.0000 - tn: 181940.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8585 - recall: 0.5733 - auc: 0.9307 - prc: 0.7339 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 10.0000 - val_tn: 45486.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8551 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8337
Epoch 52/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0035 - tp: 198.0000 - fp: 26.0000 - tn: 181943.0000 - fn: 109.0000 - accuracy: 0.9993 - precision: 0.8839 - recall: 0.6450 - auc: 0.9406 - prc: 0.7548 - val_loss: 0.0025 - val_tp: 59.0000 - val_fp: 8.0000 - val_tn: 45488.0000 - val_fn: 14.0000 - val_accuracy: 0.9995 - val_precision: 0.8806 - val_recall: 0.8082 - val_auc: 0.9518 - val_prc: 0.8402
Restoring model weights from the end of the best epoch.
Epoch 00052: early stopping

Check training history

In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.

Additionally, you can produce these plots for any of the metrics you created above; as an example, a sketch plotting false negatives follows the figure below.

def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

[Figure: loss, AUPRC, precision and recall over the training epochs of the baseline model]
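
As mentioned above, the same pattern works for any tracked metric. For example, a minimal sketch that plots the raw false-negative counts recorded by the 'fn' metric (counts are not bounded by [0, 1], so no y-limits are set):

plt.plot(baseline_history.epoch, baseline_history.history['fn'],
         color=colors[0], label='Train')
plt.plot(baseline_history.epoch, baseline_history.history['val_fn'],
         color=colors[0], linestyle="--", label='Val')
plt.xlabel('Epoch')
plt.ylabel('False negatives')
plt.legend()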

Evaluate metrics

You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label:

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evaluate your model on the test dataset and display the results for the metrics you created above:

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0028306366875767708
tp :  85.0
fp :  5.0
tn :  56845.0
fn :  27.0
accuracy :  0.9994382262229919
precision :  0.9444444179534912
recall :  0.7589285969734192
auc :  0.941765546798706
prc :  0.8580502271652222

Legitimate Transactions Detected (True Negatives):  56845
Legitimate Transactions Incorrectly Detected (False Positives):  5
Fraudulent Transactions Missed (False Negatives):  27
Fraudulent Transactions Detected (True Positives):  85
Total Fraudulent Transactions:  112

[Figure: confusion matrix for the baseline model at threshold 0.50]

If the model had predicted everything perfectly, this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer asking them to verify their card activity.
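
Since plot_cm takes the decision threshold p as a parameter, you can explore that trade-off directly. For example, re-plotting at a lower threshold of 0.1 (an arbitrary illustrative value) flags more transactions as fraudulent, trading extra false positives for fewer false negatives:

# A lower threshold catches more frauds (fewer false negatives)
# at the cost of more false alarms (more false positives).
plot_cm(test_labels, test_predictions_baseline, p=0.1)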

Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8fa43c9fd0>

[Figure: ROC curves for the baseline model on the training and test sets]

Plot the AUPRC

Now plot the AUPRC. AUPRC is the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the average precision of the model.

def plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    # Plot recall on the x-axis and precision on the y-axis, matching the labels below.
    plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8fa4286fd0>

[Figure: precision-recall curves for the baseline model on the training and test sets]

It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) are not as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.

Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
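
As the comment in the code above notes, this scaling keeps the total weight across all examples unchanged; a one-line check:

# Each class contributes total/2, so the weights sum back to the example count.
print(weight_for_0 * neg + weight_for_1 * pos)  # == total == 284807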

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 3s 14ms/step - loss: 1.6596 - tp: 157.0000 - fp: 193.0000 - tn: 238626.0000 - fn: 262.0000 - accuracy: 0.9981 - precision: 0.4486 - recall: 0.3747 - auc: 0.8475 - prc: 0.3181 - val_loss: 0.0078 - val_tp: 35.0000 - val_fp: 14.0000 - val_tn: 45482.0000 - val_fn: 38.0000 - val_accuracy: 0.9989 - val_precision: 0.7143 - val_recall: 0.4795 - val_auc: 0.9451 - val_prc: 0.4722
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.9620 - tp: 158.0000 - fp: 515.0000 - tn: 181454.0000 - fn: 149.0000 - accuracy: 0.9964 - precision: 0.2348 - recall: 0.5147 - auc: 0.8769 - prc: 0.3328 - val_loss: 0.0106 - val_tp: 54.0000 - val_fp: 28.0000 - val_tn: 45468.0000 - val_fn: 19.0000 - val_accuracy: 0.9990 - val_precision: 0.6585 - val_recall: 0.7397 - val_auc: 0.9476 - val_prc: 0.6104
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6280 - tp: 206.0000 - fp: 927.0000 - tn: 181042.0000 - fn: 101.0000 - accuracy: 0.9944 - precision: 0.1818 - recall: 0.6710 - auc: 0.9217 - prc: 0.3892 - val_loss: 0.0152 - val_tp: 61.0000 - val_fp: 54.0000 - val_tn: 45442.0000 - val_fn: 12.0000 - val_accuracy: 0.9986 - val_precision: 0.5304 - val_recall: 0.8356 - val_auc: 0.9569 - val_prc: 0.6720
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5217 - tp: 221.0000 - fp: 1443.0000 - tn: 180526.0000 - fn: 86.0000 - accuracy: 0.9916 - precision: 0.1328 - recall: 0.7199 - auc: 0.9268 - prc: 0.4364 - val_loss: 0.0206 - val_tp: 63.0000 - val_fp: 98.0000 - val_tn: 45398.0000 - val_fn: 10.0000 - val_accuracy: 0.9976 - val_precision: 0.3913 - val_recall: 0.8630 - val_auc: 0.9654 - val_prc: 0.6968
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4969 - tp: 228.0000 - fp: 2176.0000 - tn: 179793.0000 - fn: 79.0000 - accuracy: 0.9876 - precision: 0.0948 - recall: 0.7427 - auc: 0.9222 - prc: 0.3776 - val_loss: 0.0279 - val_tp: 64.0000 - val_fp: 202.0000 - val_tn: 45294.0000 - val_fn: 9.0000 - val_accuracy: 0.9954 - val_precision: 0.2406 - val_recall: 0.8767 - val_auc: 0.9621 - val_prc: 0.6914
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4541 - tp: 240.0000 - fp: 2956.0000 - tn: 179013.0000 - fn: 67.0000 - accuracy: 0.9834 - precision: 0.0751 - recall: 0.7818 - auc: 0.9232 - prc: 0.3221 - val_loss: 0.0359 - val_tp: 65.0000 - val_fp: 340.0000 - val_tn: 45156.0000 - val_fn: 8.0000 - val_accuracy: 0.9924 - val_precision: 0.1605 - val_recall: 0.8904 - val_auc: 0.9682 - val_prc: 0.6937
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3891 - tp: 248.0000 - fp: 3702.0000 - tn: 178267.0000 - fn: 59.0000 - accuracy: 0.9794 - precision: 0.0628 - recall: 0.8078 - auc: 0.9330 - prc: 0.2914 - val_loss: 0.0445 - val_tp: 65.0000 - val_fp: 480.0000 - val_tn: 45016.0000 - val_fn: 8.0000 - val_accuracy: 0.9893 - val_precision: 0.1193 - val_recall: 0.8904 - val_auc: 0.9730 - val_prc: 0.6750
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3713 - tp: 250.0000 - fp: 4448.0000 - tn: 177521.0000 - fn: 57.0000 - accuracy: 0.9753 - precision: 0.0532 - recall: 0.8143 - auc: 0.9376 - prc: 0.2792 - val_loss: 0.0525 - val_tp: 65.0000 - val_fp: 574.0000 - val_tn: 44922.0000 - val_fn: 8.0000 - val_accuracy: 0.9872 - val_precision: 0.1017 - val_recall: 0.8904 - val_auc: 0.9736 - val_prc: 0.6481
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3426 - tp: 256.0000 - fp: 4998.0000 - tn: 176971.0000 - fn: 51.0000 - accuracy: 0.9723 - precision: 0.0487 - recall: 0.8339 - auc: 0.9410 - prc: 0.2558 - val_loss: 0.0596 - val_tp: 65.0000 - val_fp: 634.0000 - val_tn: 44862.0000 - val_fn: 8.0000 - val_accuracy: 0.9859 - val_precision: 0.0930 - val_recall: 0.8904 - val_auc: 0.9757 - val_prc: 0.6446
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3459 - tp: 257.0000 - fp: 5798.0000 - tn: 176171.0000 - fn: 50.0000 - accuracy: 0.9679 - precision: 0.0424 - recall: 0.8371 - auc: 0.9376 - prc: 0.2224 - val_loss: 0.0683 - val_tp: 65.0000 - val_fp: 704.0000 - val_tn: 44792.0000 - val_fn: 8.0000 - val_accuracy: 0.9844 - val_precision: 0.0845 - val_recall: 0.8904 - val_auc: 0.9759 - val_prc: 0.6251
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2606 - tp: 264.0000 - fp: 6222.0000 - tn: 175747.0000 - fn: 43.0000 - accuracy: 0.9656 - precision: 0.0407 - recall: 0.8599 - auc: 0.9630 - prc: 0.2147 - val_loss: 0.0760 - val_tp: 65.0000 - val_fp: 793.0000 - val_tn: 44703.0000 - val_fn: 8.0000 - val_accuracy: 0.9824 - val_precision: 0.0758 - val_recall: 0.8904 - val_auc: 0.9763 - val_prc: 0.6065
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3085 - tp: 262.0000 - fp: 6795.0000 - tn: 175174.0000 - fn: 45.0000 - accuracy: 0.9625 - precision: 0.0371 - recall: 0.8534 - auc: 0.9468 - prc: 0.1905 - val_loss: 0.0809 - val_tp: 65.0000 - val_fp: 850.0000 - val_tn: 44646.0000 - val_fn: 8.0000 - val_accuracy: 0.9812 - val_precision: 0.0710 - val_recall: 0.8904 - val_auc: 0.9787 - val_prc: 0.5945
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2677 - tp: 271.0000 - fp: 6960.0000 - tn: 175009.0000 - fn: 36.0000 - accuracy: 0.9616 - precision: 0.0375 - recall: 0.8827 - auc: 0.9564 - prc: 0.2065 - val_loss: 0.0819 - val_tp: 66.0000 - val_fp: 865.0000 - val_tn: 44631.0000 - val_fn: 7.0000 - val_accuracy: 0.9809 - val_precision: 0.0709 - val_recall: 0.9041 - val_auc: 0.9798 - val_prc: 0.5886
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2582 - tp: 267.0000 - fp: 7029.0000 - tn: 174940.0000 - fn: 40.0000 - accuracy: 0.9612 - precision: 0.0366 - recall: 0.8697 - auc: 0.9625 - prc: 0.1952 - val_loss: 0.0874 - val_tp: 66.0000 - val_fp: 945.0000 - val_tn: 44551.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0653 - val_recall: 0.9041 - val_auc: 0.9803 - val_prc: 0.5823
Restoring model weights from the end of the best epoch.
Epoch 00014: early stopping

Check training history

plot_metrics(weighted_history)

[Figure: loss, AUPRC, precision and recall over the training epochs of the class-weighted model]

Evaluate metrics

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.018783362582325935
tp :  92.0
fp :  126.0
tn :  56724.0
fn :  20.0
accuracy :  0.9974368810653687
precision :  0.4220183491706848
recall :  0.8214285969734192
auc :  0.972480058670044
prc :  0.7655318975448608

Legitimate Transactions Detected (True Negatives):  56724
Legitimate Transactions Incorrectly Detected (False Positives):  126
Fraudulent Transactions Missed (False Negatives):  20
Fraudulent Transactions Detected (True Positives):  92
Total Fraudulent Transactions:  112

[Figure: confusion matrix for the class-weighted model at threshold 0.50]

Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower precision, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8fa447fa10>

[Figure: ROC curves for the baseline and class-weighted models]

Plot the AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8f8c476a90>

[Figure: precision-recall curves for the baseline and class-weighted models]

Oversampling

Oversample the minority class

A related approach would be to resample the dataset by oversampling the minority class.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Using NumPy

You can balance the dataset manually by choosing the right number of random indices from the positive examples:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181969, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363938, 29)

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Each dataset provides (feature, label) pairs:

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-5.          5.         -5.          4.46719561 -5.         -3.44901198
 -5.          5.         -3.44979494 -5.          4.80280444 -5.
  0.77383648 -5.         -0.05990851 -5.         -5.         -5.
  1.15234596  2.14273815  2.39970996 -2.48422618 -1.84760481  0.25277203
  3.30008849 -0.45731082  3.52404799  1.28027245  0.82059281]

Label:  1

Merge the two together using experimental.sample_from_datasets:

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/experimental/ops/interleave_ops.py:260: RandomDataset.__init__ (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.random(...)`.
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.501953125

To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

Train on the oversampled data

Now try training the model with the resampled dataset instead of using class weights to see how these methods compare.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 10s 29ms/step - loss: 0.3662 - tp: 245639.0000 - fp: 60552.0000 - tn: 281108.0000 - fn: 39007.0000 - accuracy: 0.8410 - precision: 0.8022 - recall: 0.8630 - auc: 0.9323 - prc: 0.9387 - val_loss: 0.1981 - val_tp: 67.0000 - val_fp: 997.0000 - val_tn: 44499.0000 - val_fn: 6.0000 - val_accuracy: 0.9780 - val_precision: 0.0630 - val_recall: 0.9178 - val_auc: 0.9764 - val_prc: 0.7167
Epoch 2/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1941 - tp: 257104.0000 - fp: 15251.0000 - tn: 269769.0000 - fn: 27220.0000 - accuracy: 0.9254 - precision: 0.9440 - recall: 0.9043 - auc: 0.9757 - prc: 0.9803 - val_loss: 0.1061 - val_tp: 67.0000 - val_fp: 766.0000 - val_tn: 44730.0000 - val_fn: 6.0000 - val_accuracy: 0.9831 - val_precision: 0.0804 - val_recall: 0.9178 - val_auc: 0.9819 - val_prc: 0.7260
Epoch 3/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1496 - tp: 261767.0000 - fp: 10789.0000 - tn: 273107.0000 - fn: 23681.0000 - accuracy: 0.9395 - precision: 0.9604 - recall: 0.9170 - auc: 0.9862 - prc: 0.9878 - val_loss: 0.0789 - val_tp: 67.0000 - val_fp: 698.0000 - val_tn: 44798.0000 - val_fn: 6.0000 - val_accuracy: 0.9846 - val_precision: 0.0876 - val_recall: 0.9178 - val_auc: 0.9815 - val_prc: 0.7064
Epoch 4/100
278/278 [==============================] - 7s 27ms/step - loss: 0.1260 - tp: 265011.0000 - fp: 9194.0000 - tn: 275097.0000 - fn: 20042.0000 - accuracy: 0.9486 - precision: 0.9665 - recall: 0.9297 - auc: 0.9907 - prc: 0.9913 - val_loss: 0.0679 - val_tp: 67.0000 - val_fp: 685.0000 - val_tn: 44811.0000 - val_fn: 6.0000 - val_accuracy: 0.9848 - val_precision: 0.0891 - val_recall: 0.9178 - val_auc: 0.9798 - val_prc: 0.7004
Epoch 5/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1134 - tp: 265989.0000 - fp: 8467.0000 - tn: 276634.0000 - fn: 18254.0000 - accuracy: 0.9531 - precision: 0.9691 - recall: 0.9358 - auc: 0.9927 - prc: 0.9929 - val_loss: 0.0610 - val_tp: 67.0000 - val_fp: 634.0000 - val_tn: 44862.0000 - val_fn: 6.0000 - val_accuracy: 0.9860 - val_precision: 0.0956 - val_recall: 0.9178 - val_auc: 0.9808 - val_prc: 0.6920
Epoch 6/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1038 - tp: 267560.0000 - fp: 8017.0000 - tn: 277110.0000 - fn: 16657.0000 - accuracy: 0.9567 - precision: 0.9709 - recall: 0.9414 - auc: 0.9941 - prc: 0.9940 - val_loss: 0.0550 - val_tp: 67.0000 - val_fp: 589.0000 - val_tn: 44907.0000 - val_fn: 6.0000 - val_accuracy: 0.9869 - val_precision: 0.1021 - val_recall: 0.9178 - val_auc: 0.9770 - val_prc: 0.6902
Epoch 7/100
278/278 [==============================] - 7s 26ms/step - loss: 0.0963 - tp: 270477.0000 - fp: 7709.0000 - tn: 276155.0000 - fn: 15003.0000 - accuracy: 0.9601 - precision: 0.9723 - recall: 0.9474 - auc: 0.9950 - prc: 0.9948 - val_loss: 0.0507 - val_tp: 66.0000 - val_fp: 565.0000 - val_tn: 44931.0000 - val_fn: 7.0000 - val_accuracy: 0.9874 - val_precision: 0.1046 - val_recall: 0.9041 - val_auc: 0.9731 - val_prc: 0.6893
Epoch 8/100
278/278 [==============================] - 7s 26ms/step - loss: 0.0896 - tp: 270719.0000 - fp: 7589.0000 - tn: 277241.0000 - fn: 13795.0000 - accuracy: 0.9624 - precision: 0.9727 - recall: 0.9515 - auc: 0.9956 - prc: 0.9954 - val_loss: 0.0460 - val_tp: 66.0000 - val_fp: 545.0000 - val_tn: 44951.0000 - val_fn: 7.0000 - val_accuracy: 0.9879 - val_precision: 0.1080 - val_recall: 0.9041 - val_auc: 0.9691 - val_prc: 0.6827
Epoch 9/100
278/278 [==============================] - 8s 27ms/step - loss: 0.0843 - tp: 271441.0000 - fp: 7453.0000 - tn: 277829.0000 - fn: 12621.0000 - accuracy: 0.9647 - precision: 0.9733 - recall: 0.9556 - auc: 0.9961 - prc: 0.9957 - val_loss: 0.0431 - val_tp: 66.0000 - val_fp: 532.0000 - val_tn: 44964.0000 - val_fn: 7.0000 - val_accuracy: 0.9882 - val_precision: 0.1104 - val_recall: 0.9041 - val_auc: 0.9705 - val_prc: 0.6772
Epoch 10/100
278/278 [==============================] - 8s 27ms/step - loss: 0.0805 - tp: 272988.0000 - fp: 7291.0000 - tn: 277435.0000 - fn: 11630.0000 - accuracy: 0.9668 - precision: 0.9740 - recall: 0.9591 - auc: 0.9964 - prc: 0.9960 - val_loss: 0.0396 - val_tp: 66.0000 - val_fp: 512.0000 - val_tn: 44984.0000 - val_fn: 7.0000 - val_accuracy: 0.9886 - val_precision: 0.1142 - val_recall: 0.9041 - val_auc: 0.9658 - val_prc: 0.6773
Epoch 11/100
278/278 [==============================] - 7s 27ms/step - loss: 0.0767 - tp: 273819.0000 - fp: 7294.0000 - tn: 277508.0000 - fn: 10723.0000 - accuracy: 0.9684 - precision: 0.9741 - recall: 0.9623 - auc: 0.9967 - prc: 0.9963 - val_loss: 0.0373 - val_tp: 66.0000 - val_fp: 503.0000 - val_tn: 44993.0000 - val_fn: 7.0000 - val_accuracy: 0.9888 - val_precision: 0.1160 - val_recall: 0.9041 - val_auc: 0.9664 - val_prc: 0.6767
Epoch 12/100
278/278 [==============================] - 7s 27ms/step - loss: 0.0739 - tp: 275132.0000 - fp: 7228.0000 - tn: 277235.0000 - fn: 9749.0000 - accuracy: 0.9702 - precision: 0.9744 - recall: 0.9658 - auc: 0.9969 - prc: 0.9965 - val_loss: 0.0343 - val_tp: 66.0000 - val_fp: 470.0000 - val_tn: 45026.0000 - val_fn: 7.0000 - val_accuracy: 0.9895 - val_precision: 0.1231 - val_recall: 0.9041 - val_auc: 0.9668 - val_prc: 0.6762
Restoring model weights from the end of the best epoch.
Epoch 00012: early stopping

If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to class weighting.

But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they are shown in many different batches each time with a small weight.

This smoother gradient signal makes it easier to train the model.
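
To make that contrast concrete, here is a rough back-of-the-envelope comparison. It is only a sketch, assuming the BATCH_SIZE of 2048 used in this tutorial, the class counts computed at the start, and the total / (2 * pos) class-weight formula from the class-weighting section:

pos, neg = 492, 284315      # class counts from the start of the tutorial
BATCH_SIZE = 2048           # assumed batch size

# Class weighting: batches keep the original imbalance, so each batch holds
# only a few positives, each carrying a large weight.
pos_per_weighted_batch = BATCH_SIZE * pos / (pos + neg)  # ~3.5 positives
weight_per_positive = (pos + neg) / (2.0 * pos)          # ~289.4

# Oversampling: batches are ~50/50, so positives appear in every batch,
# each contributing an ordinary, weight-1 gradient term.
pos_per_resampled_batch = BATCH_SIZE * 0.5               # 1024 positives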

Check training history

Note that the metric distributions will be different here, because the training data has a totally different distribution from the validation and test data.

plot_metrics(resampled_history)

png

Re-train

Because training is easier on the balanced data, the above training procedure may overfit quickly.

So break up the epochs to give tf.keras.callbacks.EarlyStopping finer control over when to stop training.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 68ms/step - loss: 0.6946 - tp: 14386.0000 - fp: 10452.0000 - tn: 55426.0000 - fn: 6265.0000 - accuracy: 0.8068 - precision: 0.5792 - recall: 0.6966 - auc: 0.8975 - prc: 0.7453 - val_loss: 0.7114 - val_tp: 64.0000 - val_fp: 20870.0000 - val_tn: 24626.0000 - val_fn: 9.0000 - val_accuracy: 0.5418 - val_precision: 0.0031 - val_recall: 0.8767 - val_auc: 0.8921 - val_prc: 0.1605
Epoch 2/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.5451 - tp: 16759.0000 - fp: 8992.0000 - tn: 11553.0000 - fn: 3656.0000 - accuracy: 0.6912 - precision: 0.6508 - recall: 0.8209 - auc: 0.8250 - prc: 0.8660 - val_loss: 0.6286 - val_tp: 67.0000 - val_fp: 15830.0000 - val_tn: 29666.0000 - val_fn: 6.0000 - val_accuracy: 0.6525 - val_precision: 0.0042 - val_recall: 0.9178 - val_auc: 0.9345 - val_prc: 0.4115
Epoch 3/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.4722 - tp: 17553.0000 - fp: 7543.0000 - tn: 12817.0000 - fn: 3047.0000 - accuracy: 0.7415 - precision: 0.6994 - recall: 0.8521 - auc: 0.8730 - prc: 0.9057 - val_loss: 0.5467 - val_tp: 68.0000 - val_fp: 10829.0000 - val_tn: 34667.0000 - val_fn: 5.0000 - val_accuracy: 0.7623 - val_precision: 0.0062 - val_recall: 0.9315 - val_auc: 0.9502 - val_prc: 0.5284
Epoch 4/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.4315 - tp: 17523.0000 - fp: 6460.0000 - tn: 14156.0000 - fn: 2821.0000 - accuracy: 0.7734 - precision: 0.7306 - recall: 0.8613 - auc: 0.8947 - prc: 0.9217 - val_loss: 0.4765 - val_tp: 68.0000 - val_fp: 6977.0000 - val_tn: 38519.0000 - val_fn: 5.0000 - val_accuracy: 0.8468 - val_precision: 0.0097 - val_recall: 0.9315 - val_auc: 0.9561 - val_prc: 0.5963
Epoch 5/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3909 - tp: 17737.0000 - fp: 5141.0000 - tn: 15380.0000 - fn: 2702.0000 - accuracy: 0.8085 - precision: 0.7753 - recall: 0.8678 - auc: 0.9118 - prc: 0.9359 - val_loss: 0.4196 - val_tp: 68.0000 - val_fp: 4514.0000 - val_tn: 40982.0000 - val_fn: 5.0000 - val_accuracy: 0.9008 - val_precision: 0.0148 - val_recall: 0.9315 - val_auc: 0.9600 - val_prc: 0.6291
Epoch 6/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.3580 - tp: 17998.0000 - fp: 4183.0000 - tn: 16117.0000 - fn: 2662.0000 - accuracy: 0.8329 - precision: 0.8114 - recall: 0.8712 - auc: 0.9243 - prc: 0.9457 - val_loss: 0.3752 - val_tp: 68.0000 - val_fp: 3155.0000 - val_tn: 42341.0000 - val_fn: 5.0000 - val_accuracy: 0.9307 - val_precision: 0.0211 - val_recall: 0.9315 - val_auc: 0.9628 - val_prc: 0.6393
Epoch 7/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.3351 - tp: 17892.0000 - fp: 3636.0000 - tn: 16909.0000 - fn: 2523.0000 - accuracy: 0.8496 - precision: 0.8311 - recall: 0.8764 - auc: 0.9339 - prc: 0.9515 - val_loss: 0.3395 - val_tp: 67.0000 - val_fp: 2475.0000 - val_tn: 43021.0000 - val_fn: 6.0000 - val_accuracy: 0.9456 - val_precision: 0.0264 - val_recall: 0.9178 - val_auc: 0.9651 - val_prc: 0.6572
Epoch 8/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.3159 - tp: 18122.0000 - fp: 2971.0000 - tn: 17345.0000 - fn: 2522.0000 - accuracy: 0.8659 - precision: 0.8591 - recall: 0.8778 - auc: 0.9397 - prc: 0.9563 - val_loss: 0.3089 - val_tp: 67.0000 - val_fp: 1942.0000 - val_tn: 43554.0000 - val_fn: 6.0000 - val_accuracy: 0.9573 - val_precision: 0.0333 - val_recall: 0.9178 - val_auc: 0.9674 - val_prc: 0.6870
Epoch 9/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2961 - tp: 18169.0000 - fp: 2566.0000 - tn: 17880.0000 - fn: 2345.0000 - accuracy: 0.8801 - precision: 0.8762 - recall: 0.8857 - auc: 0.9465 - prc: 0.9614 - val_loss: 0.2849 - val_tp: 67.0000 - val_fp: 1683.0000 - val_tn: 43813.0000 - val_fn: 6.0000 - val_accuracy: 0.9629 - val_precision: 0.0383 - val_recall: 0.9178 - val_auc: 0.9698 - val_prc: 0.6957
Epoch 10/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2832 - tp: 18368.0000 - fp: 2309.0000 - tn: 17928.0000 - fn: 2355.0000 - accuracy: 0.8861 - precision: 0.8883 - recall: 0.8864 - auc: 0.9502 - prc: 0.9640 - val_loss: 0.2642 - val_tp: 67.0000 - val_fp: 1461.0000 - val_tn: 44035.0000 - val_fn: 6.0000 - val_accuracy: 0.9678 - val_precision: 0.0438 - val_recall: 0.9178 - val_auc: 0.9718 - val_prc: 0.7029
Epoch 11/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2712 - tp: 18303.0000 - fp: 2138.0000 - tn: 18247.0000 - fn: 2272.0000 - accuracy: 0.8923 - precision: 0.8954 - recall: 0.8896 - auc: 0.9543 - prc: 0.9664 - val_loss: 0.2452 - val_tp: 67.0000 - val_fp: 1274.0000 - val_tn: 44222.0000 - val_fn: 6.0000 - val_accuracy: 0.9719 - val_precision: 0.0500 - val_recall: 0.9178 - val_auc: 0.9731 - val_prc: 0.7022
Epoch 12/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2553 - tp: 18294.0000 - fp: 1794.0000 - tn: 18668.0000 - fn: 2204.0000 - accuracy: 0.9024 - precision: 0.9107 - recall: 0.8925 - auc: 0.9598 - prc: 0.9698 - val_loss: 0.2281 - val_tp: 67.0000 - val_fp: 1144.0000 - val_tn: 44352.0000 - val_fn: 6.0000 - val_accuracy: 0.9748 - val_precision: 0.0553 - val_recall: 0.9178 - val_auc: 0.9740 - val_prc: 0.7053
Epoch 13/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2493 - tp: 18208.0000 - fp: 1770.0000 - tn: 18781.0000 - fn: 2201.0000 - accuracy: 0.9031 - precision: 0.9114 - recall: 0.8922 - auc: 0.9608 - prc: 0.9702 - val_loss: 0.2139 - val_tp: 67.0000 - val_fp: 1091.0000 - val_tn: 44405.0000 - val_fn: 6.0000 - val_accuracy: 0.9759 - val_precision: 0.0579 - val_recall: 0.9178 - val_auc: 0.9752 - val_prc: 0.7139
Epoch 14/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2385 - tp: 18355.0000 - fp: 1625.0000 - tn: 18873.0000 - fn: 2107.0000 - accuracy: 0.9089 - precision: 0.9187 - recall: 0.8970 - auc: 0.9641 - prc: 0.9727 - val_loss: 0.2001 - val_tp: 67.0000 - val_fp: 1033.0000 - val_tn: 44463.0000 - val_fn: 6.0000 - val_accuracy: 0.9772 - val_precision: 0.0609 - val_recall: 0.9178 - val_auc: 0.9758 - val_prc: 0.7173
Epoch 15/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2329 - tp: 18289.0000 - fp: 1523.0000 - tn: 19026.0000 - fn: 2122.0000 - accuracy: 0.9110 - precision: 0.9231 - recall: 0.8960 - auc: 0.9655 - prc: 0.9734 - val_loss: 0.1877 - val_tp: 67.0000 - val_fp: 971.0000 - val_tn: 44525.0000 - val_fn: 6.0000 - val_accuracy: 0.9786 - val_precision: 0.0645 - val_recall: 0.9178 - val_auc: 0.9763 - val_prc: 0.7148
Epoch 16/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2226 - tp: 18563.0000 - fp: 1321.0000 - tn: 18974.0000 - fn: 2102.0000 - accuracy: 0.9164 - precision: 0.9336 - recall: 0.8983 - auc: 0.9680 - prc: 0.9756 - val_loss: 0.1783 - val_tp: 67.0000 - val_fp: 975.0000 - val_tn: 44521.0000 - val_fn: 6.0000 - val_accuracy: 0.9785 - val_precision: 0.0643 - val_recall: 0.9178 - val_auc: 0.9771 - val_prc: 0.7154
Epoch 17/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2158 - tp: 18569.0000 - fp: 1310.0000 - tn: 19047.0000 - fn: 2034.0000 - accuracy: 0.9184 - precision: 0.9341 - recall: 0.9013 - auc: 0.9701 - prc: 0.9770 - val_loss: 0.1698 - val_tp: 67.0000 - val_fp: 990.0000 - val_tn: 44506.0000 - val_fn: 6.0000 - val_accuracy: 0.9781 - val_precision: 0.0634 - val_recall: 0.9178 - val_auc: 0.9780 - val_prc: 0.7211
Epoch 18/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2110 - tp: 18643.0000 - fp: 1231.0000 - tn: 19057.0000 - fn: 2029.0000 - accuracy: 0.9204 - precision: 0.9381 - recall: 0.9018 - auc: 0.9718 - prc: 0.9778 - val_loss: 0.1614 - val_tp: 67.0000 - val_fp: 976.0000 - val_tn: 44520.0000 - val_fn: 6.0000 - val_accuracy: 0.9785 - val_precision: 0.0642 - val_recall: 0.9178 - val_auc: 0.9783 - val_prc: 0.7265
Epoch 19/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2055 - tp: 18684.0000 - fp: 1145.0000 - tn: 19169.0000 - fn: 1962.0000 - accuracy: 0.9241 - precision: 0.9423 - recall: 0.9050 - auc: 0.9732 - prc: 0.9790 - val_loss: 0.1532 - val_tp: 67.0000 - val_fp: 947.0000 - val_tn: 44549.0000 - val_fn: 6.0000 - val_accuracy: 0.9791 - val_precision: 0.0661 - val_recall: 0.9178 - val_auc: 0.9790 - val_prc: 0.7311
Epoch 20/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2009 - tp: 18468.0000 - fp: 1123.0000 - tn: 19349.0000 - fn: 2020.0000 - accuracy: 0.9233 - precision: 0.9427 - recall: 0.9014 - auc: 0.9738 - prc: 0.9790 - val_loss: 0.1464 - val_tp: 67.0000 - val_fp: 922.0000 - val_tn: 44574.0000 - val_fn: 6.0000 - val_accuracy: 0.9796 - val_precision: 0.0677 - val_recall: 0.9178 - val_auc: 0.9792 - val_prc: 0.7351
Epoch 21/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1941 - tp: 18454.0000 - fp: 1111.0000 - tn: 19450.0000 - fn: 1945.0000 - accuracy: 0.9254 - precision: 0.9432 - recall: 0.9047 - auc: 0.9756 - prc: 0.9803 - val_loss: 0.1398 - val_tp: 67.0000 - val_fp: 906.0000 - val_tn: 44590.0000 - val_fn: 6.0000 - val_accuracy: 0.9800 - val_precision: 0.0689 - val_recall: 0.9178 - val_auc: 0.9795 - val_prc: 0.7375
Epoch 22/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1907 - tp: 18512.0000 - fp: 1068.0000 - tn: 19474.0000 - fn: 1906.0000 - accuracy: 0.9274 - precision: 0.9455 - recall: 0.9067 - auc: 0.9768 - prc: 0.9810 - val_loss: 0.1331 - val_tp: 67.0000 - val_fp: 842.0000 - val_tn: 44654.0000 - val_fn: 6.0000 - val_accuracy: 0.9814 - val_precision: 0.0737 - val_recall: 0.9178 - val_auc: 0.9796 - val_prc: 0.7422
Epoch 23/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1848 - tp: 18486.0000 - fp: 978.0000 - tn: 19564.0000 - fn: 1932.0000 - accuracy: 0.9290 - precision: 0.9498 - recall: 0.9054 - auc: 0.9782 - prc: 0.9819 - val_loss: 0.1275 - val_tp: 67.0000 - val_fp: 823.0000 - val_tn: 44673.0000 - val_fn: 6.0000 - val_accuracy: 0.9818 - val_precision: 0.0753 - val_recall: 0.9178 - val_auc: 0.9802 - val_prc: 0.7448
Epoch 24/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1825 - tp: 18489.0000 - fp: 948.0000 - tn: 19605.0000 - fn: 1918.0000 - accuracy: 0.9300 - precision: 0.9512 - recall: 0.9060 - auc: 0.9785 - prc: 0.9821 - val_loss: 0.1225 - val_tp: 67.0000 - val_fp: 813.0000 - val_tn: 44683.0000 - val_fn: 6.0000 - val_accuracy: 0.9820 - val_precision: 0.0761 - val_recall: 0.9178 - val_auc: 0.9804 - val_prc: 0.7453
Epoch 25/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1763 - tp: 18693.0000 - fp: 931.0000 - tn: 19429.0000 - fn: 1907.0000 - accuracy: 0.9307 - precision: 0.9526 - recall: 0.9074 - auc: 0.9797 - prc: 0.9834 - val_loss: 0.1191 - val_tp: 67.0000 - val_fp: 828.0000 - val_tn: 44668.0000 - val_fn: 6.0000 - val_accuracy: 0.9817 - val_precision: 0.0749 - val_recall: 0.9178 - val_auc: 0.9806 - val_prc: 0.7456
Epoch 26/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1743 - tp: 18703.0000 - fp: 894.0000 - tn: 19455.0000 - fn: 1908.0000 - accuracy: 0.9316 - precision: 0.9544 - recall: 0.9074 - auc: 0.9809 - prc: 0.9840 - val_loss: 0.1157 - val_tp: 67.0000 - val_fp: 839.0000 - val_tn: 44657.0000 - val_fn: 6.0000 - val_accuracy: 0.9815 - val_precision: 0.0740 - val_recall: 0.9178 - val_auc: 0.9813 - val_prc: 0.7360
Epoch 27/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1672 - tp: 18932.0000 - fp: 892.0000 - tn: 19335.0000 - fn: 1801.0000 - accuracy: 0.9343 - precision: 0.9550 - recall: 0.9131 - auc: 0.9825 - prc: 0.9855 - val_loss: 0.1115 - val_tp: 67.0000 - val_fp: 813.0000 - val_tn: 44683.0000 - val_fn: 6.0000 - val_accuracy: 0.9820 - val_precision: 0.0761 - val_recall: 0.9178 - val_auc: 0.9813 - val_prc: 0.7263
Epoch 28/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1691 - tp: 18568.0000 - fp: 876.0000 - tn: 19651.0000 - fn: 1865.0000 - accuracy: 0.9331 - precision: 0.9549 - recall: 0.9087 - auc: 0.9822 - prc: 0.9846 - val_loss: 0.1075 - val_tp: 67.0000 - val_fp: 799.0000 - val_tn: 44697.0000 - val_fn: 6.0000 - val_accuracy: 0.9823 - val_precision: 0.0774 - val_recall: 0.9178 - val_auc: 0.9815 - val_prc: 0.7260
Epoch 29/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1652 - tp: 18595.0000 - fp: 887.0000 - tn: 19640.0000 - fn: 1838.0000 - accuracy: 0.9335 - precision: 0.9545 - recall: 0.9100 - auc: 0.9831 - prc: 0.9853 - val_loss: 0.1035 - val_tp: 67.0000 - val_fp: 758.0000 - val_tn: 44738.0000 - val_fn: 6.0000 - val_accuracy: 0.9832 - val_precision: 0.0812 - val_recall: 0.9178 - val_auc: 0.9815 - val_prc: 0.7281
Epoch 30/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1602 - tp: 18632.0000 - fp: 800.0000 - tn: 19708.0000 - fn: 1820.0000 - accuracy: 0.9360 - precision: 0.9588 - recall: 0.9110 - auc: 0.9839 - prc: 0.9861 - val_loss: 0.1000 - val_tp: 67.0000 - val_fp: 712.0000 - val_tn: 44784.0000 - val_fn: 6.0000 - val_accuracy: 0.9842 - val_precision: 0.0860 - val_recall: 0.9178 - val_auc: 0.9821 - val_prc: 0.7190
Epoch 31/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1603 - tp: 18561.0000 - fp: 851.0000 - tn: 19711.0000 - fn: 1837.0000 - accuracy: 0.9344 - precision: 0.9562 - recall: 0.9099 - auc: 0.9837 - prc: 0.9859 - val_loss: 0.0973 - val_tp: 67.0000 - val_fp: 716.0000 - val_tn: 44780.0000 - val_fn: 6.0000 - val_accuracy: 0.9842 - val_precision: 0.0856 - val_recall: 0.9178 - val_auc: 0.9819 - val_prc: 0.7213
Epoch 32/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1554 - tp: 18548.0000 - fp: 793.0000 - tn: 19817.0000 - fn: 1802.0000 - accuracy: 0.9366 - precision: 0.9590 - recall: 0.9114 - auc: 0.9851 - prc: 0.9868 - val_loss: 0.0950 - val_tp: 67.0000 - val_fp: 719.0000 - val_tn: 44777.0000 - val_fn: 6.0000 - val_accuracy: 0.9841 - val_precision: 0.0852 - val_recall: 0.9178 - val_auc: 0.9820 - val_prc: 0.7237
Epoch 33/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1540 - tp: 18700.0000 - fp: 787.0000 - tn: 19692.0000 - fn: 1781.0000 - accuracy: 0.9373 - precision: 0.9596 - recall: 0.9130 - auc: 0.9855 - prc: 0.9871 - val_loss: 0.0934 - val_tp: 67.0000 - val_fp: 733.0000 - val_tn: 44763.0000 - val_fn: 6.0000 - val_accuracy: 0.9838 - val_precision: 0.0838 - val_recall: 0.9178 - val_auc: 0.9819 - val_prc: 0.7235
Epoch 34/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1513 - tp: 18732.0000 - fp: 775.0000 - tn: 19728.0000 - fn: 1725.0000 - accuracy: 0.9390 - precision: 0.9603 - recall: 0.9157 - auc: 0.9859 - prc: 0.9875 - val_loss: 0.0916 - val_tp: 67.0000 - val_fp: 743.0000 - val_tn: 44753.0000 - val_fn: 6.0000 - val_accuracy: 0.9836 - val_precision: 0.0827 - val_recall: 0.9178 - val_auc: 0.9820 - val_prc: 0.7236
Epoch 35/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1515 - tp: 18653.0000 - fp: 782.0000 - tn: 19793.0000 - fn: 1732.0000 - accuracy: 0.9386 - precision: 0.9598 - recall: 0.9150 - auc: 0.9860 - prc: 0.9874 - val_loss: 0.0892 - val_tp: 67.0000 - val_fp: 727.0000 - val_tn: 44769.0000 - val_fn: 6.0000 - val_accuracy: 0.9839 - val_precision: 0.0844 - val_recall: 0.9178 - val_auc: 0.9821 - val_prc: 0.7238
Restoring model weights from the end of the best epoch.
Epoch 00035: early stopping

Re-check training history

plot_metrics(resampled_history)

png

Evaluate metrics

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.1160934641957283
tp :  102.0
fp :  1007.0
tn :  55843.0
fn :  10.0
accuracy :  0.9821459650993347
precision :  0.09197475016117096
recall :  0.9107142686843872
auc :  0.9811792373657227
prc :  0.809499204158783

Legitimate Transactions Detected (True Negatives):  55843
Legitimate Transactions Incorrectly Detected (False Positives):  1007
Fraudulent Transactions Missed (False Negatives):  10
Fraudulent Transactions Detected (True Positives):  102
Total Fraudulent Transactions:  112

png

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8f8c1e3150>

png

Plot the AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f8f7d6a0fd0>

png

Applying this tutorial to your problem

Imbalanced data classification is an inherently difficult task since there are so few samples to learn from. You should always start with the data first and do your best to collect as many samples as possible, and give substantial thought to what features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.