Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Clasificación sobre datos desequilibrados

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este tutorial demuestra cómo clasificar un conjunto de datos altamente desequilibrado en el que la cantidad de ejemplos en una clase supera en gran medida a los ejemplos en otra. Trabajará con el conjunto de datos de detección de fraude de tarjetas de crédito alojado en Kaggle. El objetivo es detectar tan solo 492 transacciones fraudulentas de un total de 284.807 transacciones. Utilizará Keras para definir el modelo y los pesos de clase para ayudar al modelo a aprender de los datos desequilibrados. .

Este tutorial contiene código completo para:

  • Cargue un archivo CSV usando Pandas.
  • Cree conjuntos de entrenamiento, validación y prueba.
  • Defina y entrene un modelo usando Keras (incluido el establecimiento de pesos de clase).
  • Evalúe el modelo utilizando varias métricas (incluida la precisión y la recuperación).
  • Pruebe técnicas comunes para lidiar con datos desequilibrados como:
    • Ponderación de clase
    • Sobremuestreo

Preparar

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Procesamiento y exploración de datos

Descargue el conjunto de datos de fraude de tarjetas de crédito de Kaggle

Pandas es una biblioteca de Python con muchas utilidades útiles para cargar y trabajar con datos estructurados y se puede usar para descargar archivos CSV en un marco de datos.

file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examinar el desequilibrio de la etiqueta de clase

Veamos el desequilibrio del conjunto de datos:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


Esto muestra la pequeña fracción de muestras positivas.

Limpiar, dividir y normalizar los datos

Los datos sin procesar tienen algunos problemas. Primero, las columnas Time y Amount son demasiado variables para usarlas directamente. Suelta la columna de Time (ya que no está claro qué significa) y toma el registro de la columna de Amount para reducir su rango.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

Divida el conjunto de datos en conjuntos de entrenamiento, validación y prueba. El conjunto de validación se utiliza durante el ajuste del modelo para evaluar la pérdida y cualquier métrica; sin embargo, el modelo no se ajusta a estos datos. El conjunto de prueba no se usa por completo durante la fase de entrenamiento y solo se usa al final para evaluar qué tan bien se generaliza el modelo a nuevos datos. Esto es especialmente importante con conjuntos de datos desequilibrados donde el sobreajuste es una preocupación significativa por la falta de datos de entrenamiento.

# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

Normalice las funciones de entrada con sklearn StandardScaler. Esto establecerá la media en 0 y la desviación estándar en 1.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Mira la distribución de datos

A continuación, compare las distribuciones de los ejemplos positivos y negativos de algunas características. Buenas preguntas que debe hacerse en este punto son:

  • ¿Tienen sentido estas distribuciones?
    • Si. Ha normalizado la entrada y estos se concentran principalmente en el rango +/- 2 .
  • ¿Puedes ver la diferencia entre las distribuciones?
    • Sí, los ejemplos positivos contienen una tasa mucho más alta de valores extremos.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning

png

png

Definir el modelo y las métricas

Defina una función que cree una red neuronal simple con una capa oculta densamente conectada, una capa de abandono para reducir el sobreajuste y una capa sigmoidea de salida que devuelva la probabilidad de que una transacción sea fraudulenta:

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

Comprender métricas útiles

Tenga en cuenta que hay algunas métricas definidas anteriormente que el modelo puede calcular y que serán útiles al evaluar el rendimiento.

  • Los falsos negativos y los falsos positivos son muestras que se clasificaron incorrectamente
  • Los verdaderos negativos y verdaderos positivos son muestras que se clasificaron correctamente
  • La precisión es el porcentaje de ejemplos clasificados correctamente> $ \ frac {\ text {muestras verdaderas}} {\ text {muestras totales}} $
  • La precisión es el porcentaje de positivos pronosticados que fueron clasificados correctamente> $ \ frac {\ text {verdaderos positivos}} {\ text {verdaderos positivos + falsos positivos}} $
  • El recuerdo es el porcentaje de positivos reales que se clasificaron correctamente> $ \ frac {\ text {verdaderos positivos}} {\ text {verdaderos positivos + falsos negativos}} $
  • AUC se refiere al área bajo la curva de una curva característica de funcionamiento del receptor (ROC-AUC). Esta métrica es igual a la probabilidad de que un clasificador clasifique una muestra aleatoria positiva más alta que una muestra aleatoria negativa.

Lee mas:

Modelo de línea de base

Construye el modelo

Ahora cree y entrene su modelo usando la función que se definió anteriormente. Tenga en cuenta que el modelo se ajusta con un tamaño de lote mayor que el predeterminado de 2048, esto es importante para garantizar que cada lote tenga una probabilidad decente de contener algunas muestras positivas. Si el tamaño del lote fuera demasiado pequeño, probablemente no tendrían transacciones fraudulentas de las que aprender.

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Prueba de ejecución del modelo:

model.predict(train_features[:10])
array([[0.6983068 ],
       [0.7546284 ],
       [0.73785573],
       [0.7908986 ],
       [0.51232255],
       [0.752192  ],
       [0.7387281 ],
       [0.9410955 ],
       [0.809352  ],
       [0.6911539 ]], dtype=float32)

Opcional: establezca el sesgo inicial correcto.

Estas suposiciones iniciales no son buenas. Sabes que el conjunto de datos está desequilibrado. Establezca el sesgo de la capa de salida para reflejar eso (Ver: Una receta para entrenar redes neuronales: "init well" ). Esto puede ayudar con la convergencia inicial.

Con la inicialización de sesgo predeterminada, la pérdida debe ser de aproximadamente math.log(2) = 0.69314

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.5998

El sesgo correcto para establecer se puede derivar de:

$$ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) $$
$$ b_0 = -log_e(1/p_0 - 1) $$
$$ b_0 = log_e(pos/neg)$$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

Establezca eso como el sesgo inicial, y el modelo dará conjeturas iniciales mucho más razonables.

Debe estar cerca de: pos/total = 0.0018

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00168876],
       [0.00081124],
       [0.00087036],
       [0.00241473],
       [0.00133016],
       [0.00121771],
       [0.00079989],
       [0.00079692],
       [0.00257652],
       [0.00104385]], dtype=float32)

Con esta inicialización, la pérdida inicial debería ser aproximadamente:

$$-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317$$
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0174

Esta pérdida inicial es aproximadamente 50 veces menor que si hubiera sido con una inicialización ingenua.

De esta manera, el modelo no necesita pasar las primeras épocas simplemente aprendiendo que los ejemplos positivos son poco probables. Esto también facilita la lectura de gráficos de la pérdida durante el entrenamiento.

Punto de control de los pesos iniciales

Para que las distintas ejecuciones de entrenamiento sean más comparables, mantenga los pesos de este modelo inicial en un archivo de punto de control y cárguelos en cada modelo antes del entrenamiento.

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

Confirme que la corrección de sesgo ayuda

Antes de continuar, confirme rápidamente que la inicialización cuidadosa del sesgo realmente ayudó.

Entrene el modelo durante 20 épocas, con y sin esta inicialización cuidadosa, y compare las pérdidas:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

png

La figura anterior lo deja claro: en términos de pérdida de validación, en este problema, esta inicialización cuidadosa da una clara ventaja.

Entrena el modelo

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 15ms/step - loss: 0.0186 - tp: 64.0000 - fp: 25.1978 - tn: 139431.9780 - fn: 188.3956 - accuracy: 0.9985 - precision: 0.7214 - recall: 0.3037 - auc: 0.6752 - val_loss: 0.0109 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45489.0000 - val_fn: 80.0000 - val_accuracy: 0.9982 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.6977
Epoch 2/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0085 - tp: 25.4505 - fp: 5.6703 - tn: 93973.1538 - fn: 136.2967 - accuracy: 0.9985 - precision: 0.6137 - recall: 0.1108 - auc: 0.8194 - val_loss: 0.0051 - val_tp: 27.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 53.0000 - val_accuracy: 0.9986 - val_precision: 0.7500 - val_recall: 0.3375 - val_auc: 0.9184
Epoch 3/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0057 - tp: 68.4505 - fp: 9.9231 - tn: 93970.2198 - fn: 91.9780 - accuracy: 0.9989 - precision: 0.8896 - recall: 0.4278 - auc: 0.9133 - val_loss: 0.0043 - val_tp: 44.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 36.0000 - val_accuracy: 0.9990 - val_precision: 0.8148 - val_recall: 0.5500 - val_auc: 0.9185
Epoch 4/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0051 - tp: 86.9451 - fp: 17.4505 - tn: 93959.1429 - fn: 77.0330 - accuracy: 0.9990 - precision: 0.8422 - recall: 0.5397 - auc: 0.9191 - val_loss: 0.0040 - val_tp: 51.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8361 - val_recall: 0.6375 - val_auc: 0.9248
Epoch 5/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 83.8571 - fp: 13.0110 - tn: 93965.0879 - fn: 78.6154 - accuracy: 0.9990 - precision: 0.8890 - recall: 0.5114 - auc: 0.9212 - val_loss: 0.0039 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 25.0000 - val_accuracy: 0.9992 - val_precision: 0.8462 - val_recall: 0.6875 - val_auc: 0.9247
Epoch 6/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0049 - tp: 83.9890 - fp: 11.9451 - tn: 93976.1648 - fn: 68.4725 - accuracy: 0.9992 - precision: 0.8793 - recall: 0.5693 - auc: 0.9022 - val_loss: 0.0038 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 7/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 89.0549 - fp: 19.6484 - tn: 93958.4835 - fn: 73.3846 - accuracy: 0.9991 - precision: 0.8187 - recall: 0.5711 - auc: 0.8824 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 27.0000 - val_accuracy: 0.9992 - val_precision: 0.8413 - val_recall: 0.6625 - val_auc: 0.9309
Epoch 8/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 92.4396 - fp: 11.4835 - tn: 93970.6374 - fn: 66.0110 - accuracy: 0.9992 - precision: 0.9188 - recall: 0.5617 - auc: 0.9298 - val_loss: 0.0036 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9309
Epoch 9/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 100.4396 - fp: 14.7802 - tn: 93963.6044 - fn: 61.7473 - accuracy: 0.9992 - precision: 0.8725 - recall: 0.6228 - auc: 0.9167 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 10/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 94.8132 - fp: 12.7253 - tn: 93958.3187 - fn: 74.7143 - accuracy: 0.9991 - precision: 0.8935 - recall: 0.5710 - auc: 0.9196 - val_loss: 0.0035 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247
Epoch 11/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0047 - tp: 97.0440 - fp: 16.3956 - tn: 93958.9231 - fn: 68.2088 - accuracy: 0.9991 - precision: 0.8591 - recall: 0.5946 - auc: 0.9160 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 12/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 101.0769 - fp: 13.1868 - tn: 93958.7143 - fn: 67.5934 - accuracy: 0.9992 - precision: 0.8799 - recall: 0.6127 - auc: 0.9202 - val_loss: 0.0034 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247
Epoch 13/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0043 - tp: 98.0769 - fp: 16.0330 - tn: 93961.6703 - fn: 64.7912 - accuracy: 0.9992 - precision: 0.8536 - recall: 0.6112 - auc: 0.9154 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 21.0000 - val_accuracy: 0.9993 - val_precision: 0.8676 - val_recall: 0.7375 - val_auc: 0.9247
Epoch 14/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0050 - tp: 93.5495 - fp: 15.4286 - tn: 93961.4615 - fn: 70.1319 - accuracy: 0.9991 - precision: 0.8590 - recall: 0.5563 - auc: 0.8916 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Epoch 15/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0042 - tp: 90.2198 - fp: 15.4286 - tn: 93968.0989 - fn: 66.8242 - accuracy: 0.9992 - precision: 0.8524 - recall: 0.5813 - auc: 0.9270 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 96.4835 - fp: 14.7363 - tn: 93960.4396 - fn: 68.9121 - accuracy: 0.9991 - precision: 0.8754 - recall: 0.5727 - auc: 0.9218 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8611 - val_recall: 0.7750 - val_auc: 0.9247
Epoch 17/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0044 - tp: 98.7033 - fp: 16.4835 - tn: 93958.9231 - fn: 66.4615 - accuracy: 0.9991 - precision: 0.8674 - recall: 0.5985 - auc: 0.9108 - val_loss: 0.0032 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping

Consultar historial de entrenamiento

En esta sección, producirá gráficos de la precisión y pérdida de su modelo en el conjunto de entrenamiento y validación. Estos son útiles para verificar el sobreajuste, sobre el cual puede obtener más información en este tutorial .

Además, puede producir estos gráficos para cualquiera de las métricas que creó anteriormente. Los falsos negativos se incluyen como ejemplo.

def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

png

Evaluar métricas

Puede utilizar una matriz de confusión para resumir las etiquetas reales frente a las predichas, donde el eje X es la etiqueta predicha y el eje Y es la etiqueta real.

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evalúe su modelo en el conjunto de datos de prueba y muestre los resultados de las métricas que creó anteriormente.

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0034731989726424217
tp :  56.0
fp :  12.0
tn :  56855.0
fn :  39.0
accuracy :  0.9991046786308289
precision :  0.8235294222831726
recall :  0.5894736647605896
auc :  0.9418253898620605

Legitimate Transactions Detected (True Negatives):  56855
Legitimate Transactions Incorrectly Detected (False Positives):  12
Fraudulent Transactions Missed (False Negatives):  39
Fraudulent Transactions Detected (True Positives):  56
Total Fraudulent Transactions:  95

png

Si el modelo hubiera predicho todo perfectamente, esta sería una matriz diagonal donde los valores fuera de la diagonal principal, que indican predicciones incorrectas, serían cero. En este caso, la matriz muestra que tiene relativamente pocos falsos positivos, lo que significa que hubo relativamente pocas transacciones legítimas que se marcaron incorrectamente. Sin embargo, es probable que desee tener incluso menos falsos negativos a pesar del costo de aumentar el número de falsos positivos. Esta compensación puede ser preferible porque los falsos negativos permitirían que se realicen transacciones fraudulentas, mientras que los falsos positivos pueden hacer que se envíe un correo electrónico a un cliente para pedirle que verifique la actividad de su tarjeta.

Trazar la República de China

Ahora traza la República de China . Este gráfico es útil porque muestra, de un vistazo, el rango de rendimiento que puede alcanzar el modelo simplemente ajustando el umbral de salida.

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4786ca748>

png

Parece que la precisión es relativamente alta, pero la recuperación y el área bajo la curva ROC (AUC) no son tan altas como le gustaría. Los clasificadores a menudo enfrentan desafíos cuando intentan maximizar tanto la precisión como la recuperación, lo cual es especialmente cierto cuando se trabaja con conjuntos de datos desequilibrados. Es importante considerar los costos de los diferentes tipos de errores en el contexto del problema que le preocupa. En este ejemplo, un falso negativo (se pierde una transacción fraudulenta) puede tener un costo financiero, mientras que un falso positivo (una transacción se marca incorrectamente como fraudulenta) puede disminuir la felicidad del usuario.

Pesos de clase

Calcular pesos de clase

El objetivo es identificar transacciones fraudulentas, pero no tiene muchas de esas muestras positivas con las que trabajar, por lo que le conviene que el clasificador tenga en cuenta los pocos ejemplos disponibles. Puede hacer esto pasando pesos de Keras para cada clase a través de un parámetro. Esto hará que el modelo "preste más atención" a ejemplos de una clase subrepresentada.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44

Entrena un modelo con pesos de clase

Ahora intente volver a entrenar y evaluar el modelo con ponderaciones de clase para ver cómo afecta eso a las predicciones.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 3s 15ms/step - loss: 3.7596 - tp: 56.9890 - fp: 76.4286 - tn: 150769.4286 - fn: 199.7253 - accuracy: 0.9983 - precision: 0.4816 - recall: 0.2595 - auc: 0.7156 - val_loss: 0.0096 - val_tp: 2.0000 - val_fp: 1.0000 - val_tn: 45488.0000 - val_fn: 78.0000 - val_accuracy: 0.9983 - val_precision: 0.6667 - val_recall: 0.0250 - val_auc: 0.8999
Epoch 2/100
90/90 [==============================] - 1s 7ms/step - loss: 1.5426 - tp: 57.2308 - fp: 247.2088 - tn: 93722.7582 - fn: 113.3736 - accuracy: 0.9964 - precision: 0.1813 - recall: 0.2976 - auc: 0.8189 - val_loss: 0.0089 - val_tp: 54.0000 - val_fp: 18.0000 - val_tn: 45471.0000 - val_fn: 26.0000 - val_accuracy: 0.9990 - val_precision: 0.7500 - val_recall: 0.6750 - val_auc: 0.9306
Epoch 3/100
90/90 [==============================] - 1s 7ms/step - loss: 0.8711 - tp: 95.8352 - fp: 494.1209 - tn: 93479.8681 - fn: 70.7473 - accuracy: 0.9943 - precision: 0.1692 - recall: 0.6059 - auc: 0.8912 - val_loss: 0.0121 - val_tp: 66.0000 - val_fp: 32.0000 - val_tn: 45457.0000 - val_fn: 14.0000 - val_accuracy: 0.9990 - val_precision: 0.6735 - val_recall: 0.8250 - val_auc: 0.9426
Epoch 4/100
90/90 [==============================] - 1s 7ms/step - loss: 0.6835 - tp: 108.5165 - fp: 794.1648 - tn: 93183.5055 - fn: 54.3846 - accuracy: 0.9912 - precision: 0.1191 - recall: 0.6530 - auc: 0.8987 - val_loss: 0.0163 - val_tp: 67.0000 - val_fp: 54.0000 - val_tn: 45435.0000 - val_fn: 13.0000 - val_accuracy: 0.9985 - val_precision: 0.5537 - val_recall: 0.8375 - val_auc: 0.9556
Epoch 5/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4713 - tp: 126.3626 - fp: 1149.3407 - tn: 92827.5275 - fn: 37.3407 - accuracy: 0.9878 - precision: 0.0992 - recall: 0.7828 - auc: 0.9329 - val_loss: 0.0214 - val_tp: 67.0000 - val_fp: 95.0000 - val_tn: 45394.0000 - val_fn: 13.0000 - val_accuracy: 0.9976 - val_precision: 0.4136 - val_recall: 0.8375 - val_auc: 0.9588
Epoch 6/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4194 - tp: 125.7912 - fp: 1550.7253 - tn: 92430.8791 - fn: 33.1758 - accuracy: 0.9837 - precision: 0.0769 - recall: 0.7990 - auc: 0.9373 - val_loss: 0.0270 - val_tp: 67.0000 - val_fp: 147.0000 - val_tn: 45342.0000 - val_fn: 13.0000 - val_accuracy: 0.9965 - val_precision: 0.3131 - val_recall: 0.8375 - val_auc: 0.9626
Epoch 7/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4226 - tp: 127.0659 - fp: 2000.6374 - tn: 91978.2857 - fn: 34.5824 - accuracy: 0.9788 - precision: 0.0567 - recall: 0.7672 - auc: 0.9351 - val_loss: 0.0348 - val_tp: 67.0000 - val_fp: 224.0000 - val_tn: 45265.0000 - val_fn: 13.0000 - val_accuracy: 0.9948 - val_precision: 0.2302 - val_recall: 0.8375 - val_auc: 0.9656
Epoch 8/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2990 - tp: 137.7473 - fp: 2564.6154 - tn: 91411.7363 - fn: 26.4725 - accuracy: 0.9729 - precision: 0.0528 - recall: 0.8483 - auc: 0.9609 - val_loss: 0.0457 - val_tp: 69.0000 - val_fp: 406.0000 - val_tn: 45083.0000 - val_fn: 11.0000 - val_accuracy: 0.9908 - val_precision: 0.1453 - val_recall: 0.8625 - val_auc: 0.9691
Epoch 9/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3165 - tp: 125.0330 - fp: 3192.7473 - tn: 90795.8462 - fn: 26.9451 - accuracy: 0.9662 - precision: 0.0375 - recall: 0.8237 - auc: 0.9518 - val_loss: 0.0568 - val_tp: 69.0000 - val_fp: 654.0000 - val_tn: 44835.0000 - val_fn: 11.0000 - val_accuracy: 0.9854 - val_precision: 0.0954 - val_recall: 0.8625 - val_auc: 0.9703
Epoch 10/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2917 - tp: 139.2527 - fp: 3709.6154 - tn: 90270.3626 - fn: 21.3407 - accuracy: 0.9606 - precision: 0.0356 - recall: 0.8714 - auc: 0.9546 - val_loss: 0.0640 - val_tp: 70.0000 - val_fp: 755.0000 - val_tn: 44734.0000 - val_fn: 10.0000 - val_accuracy: 0.9832 - val_precision: 0.0848 - val_recall: 0.8750 - val_auc: 0.9713
Epoch 11/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3109 - tp: 155.7363 - fp: 3838.3736 - tn: 90123.9670 - fn: 22.4945 - accuracy: 0.9590 - precision: 0.0415 - recall: 0.8730 - auc: 0.9555 - val_loss: 0.0703 - val_tp: 71.0000 - val_fp: 835.0000 - val_tn: 44654.0000 - val_fn: 9.0000 - val_accuracy: 0.9815 - val_precision: 0.0784 - val_recall: 0.8875 - val_auc: 0.9719
Epoch 12/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3007 - tp: 134.9121 - fp: 4045.0549 - tn: 89938.1758 - fn: 22.4286 - accuracy: 0.9566 - precision: 0.0320 - recall: 0.8712 - auc: 0.9494 - val_loss: 0.0756 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9000 - val_auc: 0.9716
Epoch 13/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2083 - tp: 143.2308 - fp: 4041.6593 - tn: 89938.9670 - fn: 16.7143 - accuracy: 0.9567 - precision: 0.0360 - recall: 0.9154 - auc: 0.9734 - val_loss: 0.0765 - val_tp: 72.0000 - val_fp: 916.0000 - val_tn: 44573.0000 - val_fn: 8.0000 - val_accuracy: 0.9797 - val_precision: 0.0729 - val_recall: 0.9000 - val_auc: 0.9726
Epoch 14/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2448 - tp: 143.6703 - fp: 4119.2967 - tn: 89857.6154 - fn: 19.9890 - accuracy: 0.9562 - precision: 0.0341 - recall: 0.8944 - auc: 0.9655 - val_loss: 0.0811 - val_tp: 72.0000 - val_fp: 992.0000 - val_tn: 44497.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0677 - val_recall: 0.9000 - val_auc: 0.9724
Epoch 15/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2362 - tp: 141.0769 - fp: 4205.1429 - tn: 89776.9670 - fn: 17.3846 - accuracy: 0.9545 - precision: 0.0316 - recall: 0.8889 - auc: 0.9665 - val_loss: 0.0835 - val_tp: 72.0000 - val_fp: 1019.0000 - val_tn: 44470.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0660 - val_recall: 0.9000 - val_auc: 0.9729
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2100 - tp: 135.4945 - fp: 4278.8242 - tn: 89707.5495 - fn: 18.7033 - accuracy: 0.9542 - precision: 0.0289 - recall: 0.8980 - auc: 0.9717 - val_loss: 0.0922 - val_tp: 72.0000 - val_fp: 1117.0000 - val_tn: 44372.0000 - val_fn: 8.0000 - val_accuracy: 0.9753 - val_precision: 0.0606 - val_recall: 0.9000 - val_auc: 0.9728
Epoch 17/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2061 - tp: 147.9451 - fp: 4488.5604 - tn: 89486.7912 - fn: 17.2747 - accuracy: 0.9519 - precision: 0.0328 - recall: 0.9026 - auc: 0.9754 - val_loss: 0.0911 - val_tp: 72.0000 - val_fp: 1104.0000 - val_tn: 44385.0000 - val_fn: 8.0000 - val_accuracy: 0.9756 - val_precision: 0.0612 - val_recall: 0.9000 - val_auc: 0.9730
Epoch 18/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3032 - tp: 143.0989 - fp: 4367.9890 - tn: 89610.3956 - fn: 19.0879 - accuracy: 0.9529 - precision: 0.0312 - recall: 0.8782 - auc: 0.9486 - val_loss: 0.0878 - val_tp: 72.0000 - val_fp: 1037.0000 - val_tn: 44452.0000 - val_fn: 8.0000 - val_accuracy: 0.9771 - val_precision: 0.0649 - val_recall: 0.9000 - val_auc: 0.9761
Epoch 19/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2713 - tp: 148.5165 - fp: 4079.8462 - tn: 89894.0110 - fn: 18.1978 - accuracy: 0.9565 - precision: 0.0361 - recall: 0.8861 - auc: 0.9635 - val_loss: 0.0868 - val_tp: 72.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9762
Epoch 20/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2623 - tp: 144.5055 - fp: 3991.8681 - tn: 89984.6484 - fn: 19.5495 - accuracy: 0.9575 - precision: 0.0332 - recall: 0.8837 - auc: 0.9587 - val_loss: 0.0888 - val_tp: 72.0000 - val_fp: 1030.0000 - val_tn: 44459.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0653 - val_recall: 0.9000 - val_auc: 0.9761
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3103 - tp: 136.6154 - fp: 3966.2527 - tn: 90015.5934 - fn: 22.1099 - accuracy: 0.9577 - precision: 0.0309 - recall: 0.8330 - auc: 0.9373 - val_loss: 0.0886 - val_tp: 72.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9764
Epoch 22/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2215 - tp: 141.6484 - fp: 3872.5055 - tn: 90108.2527 - fn: 18.1648 - accuracy: 0.9584 - precision: 0.0347 - recall: 0.8935 - auc: 0.9698 - val_loss: 0.0862 - val_tp: 72.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 8.0000 - val_accuracy: 0.9785 - val_precision: 0.0690 - val_recall: 0.9000 - val_auc: 0.9776
Epoch 23/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2214 - tp: 141.6593 - fp: 3877.3626 - tn: 90102.6593 - fn: 18.8901 - accuracy: 0.9585 - precision: 0.0354 - recall: 0.8872 - auc: 0.9685 - val_loss: 0.0842 - val_tp: 72.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9792 - val_precision: 0.0713 - val_recall: 0.9000 - val_auc: 0.9777
Epoch 24/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1873 - tp: 138.3956 - fp: 3647.7473 - tn: 90337.8681 - fn: 16.5604 - accuracy: 0.9611 - precision: 0.0340 - recall: 0.9088 - auc: 0.9743 - val_loss: 0.0843 - val_tp: 73.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 7.0000 - val_accuracy: 0.9793 - val_precision: 0.0722 - val_recall: 0.9125 - val_auc: 0.9779
Epoch 25/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2038 - tp: 140.3407 - fp: 3553.5824 - tn: 90428.6813 - fn: 17.9670 - accuracy: 0.9620 - precision: 0.0383 - recall: 0.8946 - auc: 0.9739 - val_loss: 0.0854 - val_tp: 73.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0719 - val_recall: 0.9125 - val_auc: 0.9777
Epoch 26/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2003 - tp: 151.4725 - fp: 3495.4725 - tn: 90479.3516 - fn: 14.2747 - accuracy: 0.9625 - precision: 0.0406 - recall: 0.9196 - auc: 0.9732 - val_loss: 0.0819 - val_tp: 73.0000 - val_fp: 895.0000 - val_tn: 44594.0000 - val_fn: 7.0000 - val_accuracy: 0.9802 - val_precision: 0.0754 - val_recall: 0.9125 - val_auc: 0.9778
Epoch 27/100
90/90 [==============================] - 1s 9ms/step - loss: 0.2111 - tp: 138.2857 - fp: 3438.2088 - tn: 90547.7253 - fn: 16.3516 - accuracy: 0.9635 - precision: 0.0382 - recall: 0.9021 - auc: 0.9690 - val_loss: 0.0865 - val_tp: 73.0000 - val_fp: 940.0000 - val_tn: 44549.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0721 - val_recall: 0.9125 - val_auc: 0.9775
Epoch 28/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2115 - tp: 146.6813 - fp: 3796.8022 - tn: 90178.5604 - fn: 18.5275 - accuracy: 0.9595 - precision: 0.0357 - recall: 0.8799 - auc: 0.9740 - val_loss: 0.0914 - val_tp: 73.0000 - val_fp: 996.0000 - val_tn: 44493.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0683 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 29/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2085 - tp: 142.0220 - fp: 3757.9341 - tn: 90222.3956 - fn: 18.2198 - accuracy: 0.9597 - precision: 0.0368 - recall: 0.8941 - auc: 0.9738 - val_loss: 0.0890 - val_tp: 73.0000 - val_fp: 962.0000 - val_tn: 44527.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0705 - val_recall: 0.9125 - val_auc: 0.9775
Epoch 30/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1559 - tp: 153.3077 - fp: 3571.6044 - tn: 90404.6264 - fn: 11.0330 - accuracy: 0.9617 - precision: 0.0420 - recall: 0.9426 - auc: 0.9811 - val_loss: 0.0852 - val_tp: 73.0000 - val_fp: 901.0000 - val_tn: 44588.0000 - val_fn: 7.0000 - val_accuracy: 0.9801 - val_precision: 0.0749 - val_recall: 0.9125 - val_auc: 0.9778
Epoch 31/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2635 - tp: 142.1209 - fp: 3672.6264 - tn: 90304.4725 - fn: 21.3516 - accuracy: 0.9613 - precision: 0.0376 - recall: 0.8706 - auc: 0.9641 - val_loss: 0.0930 - val_tp: 73.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 7.0000 - val_accuracy: 0.9777 - val_precision: 0.0673 - val_recall: 0.9125 - val_auc: 0.9773
Epoch 32/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1858 - tp: 155.9121 - fp: 3632.0440 - tn: 90338.0769 - fn: 14.5385 - accuracy: 0.9609 - precision: 0.0410 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.0897 - val_tp: 73.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0699 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 33/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2525 - tp: 152.4945 - fp: 3743.6264 - tn: 90226.1209 - fn: 18.3297 - accuracy: 0.9599 - precision: 0.0398 - recall: 0.8872 - auc: 0.9646 - val_loss: 0.0878 - val_tp: 73.0000 - val_fp: 930.0000 - val_tn: 44559.0000 - val_fn: 7.0000 - val_accuracy: 0.9794 - val_precision: 0.0728 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 34/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2286 - tp: 142.4615 - fp: 3510.0659 - tn: 90469.1209 - fn: 18.9231 - accuracy: 0.9626 - precision: 0.0379 - recall: 0.8608 - auc: 0.9694 - val_loss: 0.0839 - val_tp: 73.0000 - val_fp: 889.0000 - val_tn: 44600.0000 - val_fn: 7.0000 - val_accuracy: 0.9803 - val_precision: 0.0759 - val_recall: 0.9125 - val_auc: 0.9773
Restoring model weights from the end of the best epoch.
Epoch 00034: early stopping

Consultar historial de entrenamiento

plot_metrics(weighted_history)

png

Evaluar métricas

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.08238842338323593
tp :  89.0
fp :  1186.0
tn :  55681.0
fn :  6.0
accuracy :  0.9790737628936768
precision :  0.06980392336845398
recall :  0.9368420839309692
auc :  0.98465895652771

Legitimate Transactions Detected (True Negatives):  55681
Legitimate Transactions Incorrectly Detected (False Positives):  1186
Fraudulent Transactions Missed (False Negatives):  6
Fraudulent Transactions Detected (True Positives):  89
Total Fraudulent Transactions:  95

png

Aquí puede ver que con las ponderaciones de clase, la exactitud y la precisión son menores porque hay más falsos positivos, pero a la inversa, la recuperación y el AUC son mayores porque el modelo también encontró más verdaderos positivos. A pesar de tener una precisión menor, este modelo tiene una mayor recuperación (e identifica más transacciones fraudulentas). Por supuesto, ambos tipos de error tienen un costo (tampoco querrá molestar a los usuarios marcando demasiadas transacciones legítimas como fraudulentas). Considere cuidadosamente las compensaciones entre estos diferentes tipos de errores para su aplicación.

Trazar la República de China

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4fc124e48>

png

Sobremuestreo

Sobremuestrear la clase minoritaria

Un enfoque relacionado sería volver a muestrear el conjunto de datos sobremuestreando la clase minoritaria.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Usando NumPy

Puede equilibrar el conjunto de datos manualmente eligiendo el número correcto de índices aleatorios de los ejemplos positivos:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181959, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363918, 29)

Usando tf.data

Si está utilizando tf.data la forma más fácil de producir ejemplos equilibrados es comenzar con un conjunto de datos positive y negative y fusionarlos. Consulte la guía tf.data para obtener más ejemplos.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Cada conjunto de datos proporciona pares (feature, label) :

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-2.19450702e+00  1.14157653e+00 -1.53868568e+00 -3.35790031e-01
 -8.45956971e-01 -1.57966490e+00 -1.68436379e+00  2.25259298e-01
  2.62655847e-01 -3.51815449e+00  2.45653756e+00 -3.62123216e+00
 -1.02812838e+00 -5.00000000e+00  1.79960408e+00 -3.72675038e+00
 -5.00000000e+00 -2.22257320e+00 -1.24126989e-02 -9.14798839e-01
  7.33686754e-01 -9.34731229e-02 -1.74683429e+00  4.44394133e-01
 -3.98093307e-02 -1.99985035e+00 -2.27011783e+00  1.83260825e-03
  5.69563522e-01]

Label:  1

Combina los dos juntos usando experimental.sample_from_datasets :

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.5

Para usar este conjunto de datos, necesitará la cantidad de pasos por época.

La definición de "época" en este caso es menos clara. Supongamos que es la cantidad de lotes necesarios para ver cada ejemplo negativo una vez:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

Entrene con los datos sobremuestreados

Ahora intente entrenar el modelo con el conjunto de datos remuestreados en lugar de usar ponderaciones de clase para ver cómo se comparan estos métodos.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 8s 24ms/step - loss: 0.8118 - tp: 98159.0036 - fp: 34959.3333 - tn: 165642.9498 - fn: 44913.3728 - accuracy: 0.7582 - precision: 0.6573 - recall: 0.5935 - auc: 0.7814 - val_loss: 0.1785 - val_tp: 70.0000 - val_fp: 1177.0000 - val_tn: 44312.0000 - val_fn: 10.0000 - val_accuracy: 0.9740 - val_precision: 0.0561 - val_recall: 0.8750 - val_auc: 0.9722
Epoch 2/100
278/278 [==============================] - 6s 21ms/step - loss: 0.2142 - tp: 126477.2616 - fp: 8422.5305 - tn: 134479.1362 - fn: 17333.7312 - accuracy: 0.9075 - precision: 0.9341 - recall: 0.8774 - auc: 0.9705 - val_loss: 0.1007 - val_tp: 71.0000 - val_fp: 919.0000 - val_tn: 44570.0000 - val_fn: 9.0000 - val_accuracy: 0.9796 - val_precision: 0.0717 - val_recall: 0.8875 - val_auc: 0.9718
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1652 - tp: 128160.7455 - fp: 6029.4875 - tn: 137362.8602 - fn: 15159.5663 - accuracy: 0.9255 - precision: 0.9544 - recall: 0.8936 - auc: 0.9831 - val_loss: 0.0775 - val_tp: 71.0000 - val_fp: 802.0000 - val_tn: 44687.0000 - val_fn: 9.0000 - val_accuracy: 0.9822 - val_precision: 0.0813 - val_recall: 0.8875 - val_auc: 0.9727
Epoch 4/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1451 - tp: 129069.5376 - fp: 5223.8925 - tn: 138326.3763 - fn: 14092.8530 - accuracy: 0.9322 - precision: 0.9611 - recall: 0.9007 - auc: 0.9872 - val_loss: 0.0676 - val_tp: 71.0000 - val_fp: 756.0000 - val_tn: 44733.0000 - val_fn: 9.0000 - val_accuracy: 0.9832 - val_precision: 0.0859 - val_recall: 0.8875 - val_auc: 0.9766
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1328 - tp: 130130.0502 - fp: 5189.8996 - tn: 138364.5878 - fn: 13028.1219 - accuracy: 0.9360 - precision: 0.9611 - recall: 0.9085 - auc: 0.9895 - val_loss: 0.0616 - val_tp: 72.0000 - val_fp: 763.0000 - val_tn: 44726.0000 - val_fn: 8.0000 - val_accuracy: 0.9831 - val_precision: 0.0862 - val_recall: 0.9000 - val_auc: 0.9748
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1230 - tp: 131300.7706 - fp: 4925.0179 - tn: 138645.5125 - fn: 11841.3584 - accuracy: 0.9413 - precision: 0.9635 - recall: 0.9170 - auc: 0.9912 - val_loss: 0.0566 - val_tp: 72.0000 - val_fp: 754.0000 - val_tn: 44735.0000 - val_fn: 8.0000 - val_accuracy: 0.9833 - val_precision: 0.0872 - val_recall: 0.9000 - val_auc: 0.9759
Epoch 7/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1150 - tp: 132477.4875 - fp: 4822.2509 - tn: 138449.6989 - fn: 10963.2222 - accuracy: 0.9445 - precision: 0.9647 - recall: 0.9228 - auc: 0.9925 - val_loss: 0.0516 - val_tp: 72.0000 - val_fp: 711.0000 - val_tn: 44778.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0920 - val_recall: 0.9000 - val_auc: 0.9777
Epoch 8/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1069 - tp: 132886.1828 - fp: 4656.9570 - tn: 138848.8029 - fn: 10320.7168 - accuracy: 0.9474 - precision: 0.9660 - recall: 0.9275 - auc: 0.9936 - val_loss: 0.0462 - val_tp: 71.0000 - val_fp: 667.0000 - val_tn: 44822.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.0962 - val_recall: 0.8875 - val_auc: 0.9695
Epoch 9/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1006 - tp: 133579.5950 - fp: 4598.8602 - tn: 138524.1039 - fn: 10010.1004 - accuracy: 0.9489 - precision: 0.9668 - recall: 0.9298 - auc: 0.9944 - val_loss: 0.0419 - val_tp: 71.0000 - val_fp: 639.0000 - val_tn: 44850.0000 - val_fn: 9.0000 - val_accuracy: 0.9858 - val_precision: 0.1000 - val_recall: 0.8875 - val_auc: 0.9703
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0955 - tp: 134632.2903 - fp: 4540.9355 - tn: 138577.5520 - fn: 8961.8817 - accuracy: 0.9521 - precision: 0.9676 - recall: 0.9358 - auc: 0.9950 - val_loss: 0.0396 - val_tp: 71.0000 - val_fp: 634.0000 - val_tn: 44855.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1007 - val_recall: 0.8875 - val_auc: 0.9664
Epoch 11/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0908 - tp: 140213.1900 - fp: 5993.7849 - tn: 136973.6380 - fn: 3532.0466 - accuracy: 0.9663 - precision: 0.9588 - recall: 0.9747 - auc: 0.9955 - val_loss: 0.0359 - val_tp: 71.0000 - val_fp: 596.0000 - val_tn: 44893.0000 - val_fn: 9.0000 - val_accuracy: 0.9867 - val_precision: 0.1064 - val_recall: 0.8875 - val_auc: 0.9669
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0857 - tp: 140261.6918 - fp: 5979.1470 - tn: 137657.2294 - fn: 2814.5914 - accuracy: 0.9691 - precision: 0.9589 - recall: 0.9800 - auc: 0.9959 - val_loss: 0.0337 - val_tp: 71.0000 - val_fp: 580.0000 - val_tn: 44909.0000 - val_fn: 9.0000 - val_accuracy: 0.9871 - val_precision: 0.1091 - val_recall: 0.8875 - val_auc: 0.9677
Epoch 13/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0818 - tp: 140796.7133 - fp: 5945.5305 - tn: 137458.5520 - fn: 2511.8638 - accuracy: 0.9704 - precision: 0.9596 - recall: 0.9822 - auc: 0.9962 - val_loss: 0.0318 - val_tp: 71.0000 - val_fp: 558.0000 - val_tn: 44931.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1129 - val_recall: 0.8875 - val_auc: 0.9682
Epoch 14/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0793 - tp: 140997.9176 - fp: 6076.8746 - tn: 137562.6846 - fn: 2075.1828 - accuracy: 0.9714 - precision: 0.9586 - recall: 0.9853 - auc: 0.9964 - val_loss: 0.0303 - val_tp: 71.0000 - val_fp: 555.0000 - val_tn: 44934.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1134 - val_recall: 0.8875 - val_auc: 0.9687
Epoch 15/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0759 - tp: 141966.7312 - fp: 6100.1147 - tn: 136957.0108 - fn: 1688.8029 - accuracy: 0.9729 - precision: 0.9589 - recall: 0.9883 - auc: 0.9966 - val_loss: 0.0292 - val_tp: 71.0000 - val_fp: 541.0000 - val_tn: 44948.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.1160 - val_recall: 0.8875 - val_auc: 0.9692
Epoch 16/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0727 - tp: 141879.4731 - fp: 6020.2007 - tn: 137272.9140 - fn: 1540.0717 - accuracy: 0.9736 - precision: 0.9594 - recall: 0.9891 - auc: 0.9968 - val_loss: 0.0276 - val_tp: 71.0000 - val_fp: 504.0000 - val_tn: 44985.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1235 - val_recall: 0.8875 - val_auc: 0.9697
Epoch 17/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0704 - tp: 142140.7563 - fp: 6019.2258 - tn: 137225.3978 - fn: 1327.2796 - accuracy: 0.9745 - precision: 0.9597 - recall: 0.9907 - auc: 0.9969 - val_loss: 0.0269 - val_tp: 71.0000 - val_fp: 495.0000 - val_tn: 44994.0000 - val_fn: 9.0000 - val_accuracy: 0.9889 - val_precision: 0.1254 - val_recall: 0.8875 - val_auc: 0.9699
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping

Si el proceso de entrenamiento tuviera en cuenta el conjunto de datos completo en cada actualización de gradiente, este sobremuestreo sería básicamente idéntico a la ponderación de la clase.

Pero al entrenar el modelo por lotes, como lo hizo aquí, los datos sobremuestreados proporcionan una señal de gradiente más suave: en lugar de que cada ejemplo positivo se muestre en un lote con un gran peso, se muestran en muchos lotes diferentes cada vez con un pequeño peso.

Esta señal de gradiente más suave facilita el entrenamiento del modelo.

Consultar historial de entrenamiento

Tenga en cuenta que las distribuciones de métricas serán diferentes aquí, porque los datos de entrenamiento tienen una distribución totalmente diferente de los datos de validación y prueba.

plot_metrics(resampled_history)

png

Reentrenar

Debido a que el entrenamiento es más fácil con los datos balanceados, el procedimiento de entrenamiento anterior puede sobreajustarse rápidamente.

Así que divida las épocas para devolver las callbacks.EarlyStopping Temprano Detener un control más callbacks.EarlyStopping sobre cuándo dejar de entrenar.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 59ms/step - loss: 1.6583 - tp: 2292.0000 - fp: 4139.7619 - tn: 52605.9524 - fn: 8961.7619 - accuracy: 0.8197 - precision: 0.3317 - recall: 0.1958 - auc: 0.7857 - val_loss: 0.5873 - val_tp: 9.0000 - val_fp: 13094.0000 - val_tn: 32395.0000 - val_fn: 71.0000 - val_accuracy: 0.7111 - val_precision: 6.8687e-04 - val_recall: 0.1125 - val_auc: 0.2616
Epoch 2/1000
20/20 [==============================] - 1s 27ms/step - loss: 1.0305 - tp: 4655.6667 - fp: 3775.2857 - tn: 7528.4286 - fn: 6471.0952 - accuracy: 0.5318 - precision: 0.5347 - recall: 0.3960 - auc: 0.4700 - val_loss: 0.5752 - val_tp: 45.0000 - val_fp: 12493.0000 - val_tn: 32996.0000 - val_fn: 35.0000 - val_accuracy: 0.7251 - val_precision: 0.0036 - val_recall: 0.5625 - val_auc: 0.7427
Epoch 3/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.6911 - tp: 7407.9048 - fp: 3609.9048 - tn: 7705.3333 - fn: 3707.3333 - accuracy: 0.6637 - precision: 0.6637 - recall: 0.6487 - auc: 0.6971 - val_loss: 0.5367 - val_tp: 73.0000 - val_fp: 10392.0000 - val_tn: 35097.0000 - val_fn: 7.0000 - val_accuracy: 0.7718 - val_precision: 0.0070 - val_recall: 0.9125 - val_auc: 0.9205
Epoch 4/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.5257 - tp: 8758.5238 - fp: 3223.3333 - tn: 8031.5714 - fn: 2417.0476 - accuracy: 0.7432 - precision: 0.7273 - recall: 0.7793 - auc: 0.8200 - val_loss: 0.4801 - val_tp: 73.0000 - val_fp: 7752.0000 - val_tn: 37737.0000 - val_fn: 7.0000 - val_accuracy: 0.8297 - val_precision: 0.0093 - val_recall: 0.9125 - val_auc: 0.9436
Epoch 5/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.4499 - tp: 9241.0952 - fp: 2761.1905 - tn: 8463.3333 - fn: 1964.8571 - accuracy: 0.7862 - precision: 0.7667 - recall: 0.8219 - auc: 0.8704 - val_loss: 0.4242 - val_tp: 72.0000 - val_fp: 5662.0000 - val_tn: 39827.0000 - val_fn: 8.0000 - val_accuracy: 0.8756 - val_precision: 0.0126 - val_recall: 0.9000 - val_auc: 0.9524
Epoch 6/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4029 - tp: 9361.0000 - fp: 2311.8571 - tn: 8961.6190 - fn: 1796.0000 - accuracy: 0.8140 - precision: 0.7965 - recall: 0.8389 - auc: 0.8968 - val_loss: 0.3750 - val_tp: 71.0000 - val_fp: 4105.0000 - val_tn: 41384.0000 - val_fn: 9.0000 - val_accuracy: 0.9097 - val_precision: 0.0170 - val_recall: 0.8875 - val_auc: 0.9583
Epoch 7/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3601 - tp: 9456.8571 - fp: 1812.6667 - tn: 9376.6190 - fn: 1784.3333 - accuracy: 0.8386 - precision: 0.8376 - recall: 0.8413 - auc: 0.9181 - val_loss: 0.3324 - val_tp: 70.0000 - val_fp: 3047.0000 - val_tn: 42442.0000 - val_fn: 10.0000 - val_accuracy: 0.9329 - val_precision: 0.0225 - val_recall: 0.8750 - val_auc: 0.9626
Epoch 8/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3373 - tp: 9506.3810 - fp: 1623.5714 - tn: 9568.3333 - fn: 1732.1905 - accuracy: 0.8490 - precision: 0.8520 - recall: 0.8453 - auc: 0.9266 - val_loss: 0.2959 - val_tp: 69.0000 - val_fp: 2358.0000 - val_tn: 43131.0000 - val_fn: 11.0000 - val_accuracy: 0.9480 - val_precision: 0.0284 - val_recall: 0.8625 - val_auc: 0.9660
Epoch 9/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3120 - tp: 9635.0000 - fp: 1347.7619 - tn: 9797.4286 - fn: 1650.2857 - accuracy: 0.8662 - precision: 0.8776 - recall: 0.8534 - auc: 0.9366 - val_loss: 0.2670 - val_tp: 69.0000 - val_fp: 1943.0000 - val_tn: 43546.0000 - val_fn: 11.0000 - val_accuracy: 0.9571 - val_precision: 0.0343 - val_recall: 0.8625 - val_auc: 0.9689
Epoch 10/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2923 - tp: 9632.5714 - fp: 1221.3333 - tn: 9978.5714 - fn: 1598.0000 - accuracy: 0.8725 - precision: 0.8864 - recall: 0.8559 - auc: 0.9446 - val_loss: 0.2417 - val_tp: 69.0000 - val_fp: 1609.0000 - val_tn: 43880.0000 - val_fn: 11.0000 - val_accuracy: 0.9644 - val_precision: 0.0411 - val_recall: 0.8625 - val_auc: 0.9706
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2765 - tp: 9632.7619 - fp: 1069.8095 - tn: 10155.4762 - fn: 1572.4286 - accuracy: 0.8806 - precision: 0.8976 - recall: 0.8584 - auc: 0.9505 - val_loss: 0.2211 - val_tp: 69.0000 - val_fp: 1424.0000 - val_tn: 44065.0000 - val_fn: 11.0000 - val_accuracy: 0.9685 - val_precision: 0.0462 - val_recall: 0.8625 - val_auc: 0.9716
Epoch 12/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2624 - tp: 9721.6190 - fp: 949.2857 - tn: 10243.6190 - fn: 1515.9524 - accuracy: 0.8901 - precision: 0.9102 - recall: 0.8658 - auc: 0.9556 - val_loss: 0.2038 - val_tp: 69.0000 - val_fp: 1299.0000 - val_tn: 44190.0000 - val_fn: 11.0000 - val_accuracy: 0.9713 - val_precision: 0.0504 - val_recall: 0.8625 - val_auc: 0.9722
Epoch 13/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2491 - tp: 9862.2381 - fp: 851.1905 - tn: 10195.6667 - fn: 1521.3810 - accuracy: 0.8933 - precision: 0.9207 - recall: 0.8656 - auc: 0.9597 - val_loss: 0.1904 - val_tp: 70.0000 - val_fp: 1230.0000 - val_tn: 44259.0000 - val_fn: 10.0000 - val_accuracy: 0.9728 - val_precision: 0.0538 - val_recall: 0.8750 - val_auc: 0.9726
Epoch 14/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2384 - tp: 9752.6190 - fp: 806.5714 - tn: 10438.3333 - fn: 1432.9524 - accuracy: 0.9006 - precision: 0.9238 - recall: 0.8724 - auc: 0.9637 - val_loss: 0.1781 - val_tp: 70.0000 - val_fp: 1186.0000 - val_tn: 44303.0000 - val_fn: 10.0000 - val_accuracy: 0.9738 - val_precision: 0.0557 - val_recall: 0.8750 - val_auc: 0.9724
Epoch 15/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2332 - tp: 9727.1905 - fp: 787.1905 - tn: 10508.9524 - fn: 1407.1429 - accuracy: 0.9024 - precision: 0.9255 - recall: 0.8737 - auc: 0.9651 - val_loss: 0.1664 - val_tp: 70.0000 - val_fp: 1130.0000 - val_tn: 44359.0000 - val_fn: 10.0000 - val_accuracy: 0.9750 - val_precision: 0.0583 - val_recall: 0.8750 - val_auc: 0.9728
Epoch 16/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2242 - tp: 9812.1905 - fp: 731.9524 - tn: 10481.3333 - fn: 1405.0000 - accuracy: 0.9048 - precision: 0.9310 - recall: 0.8745 - auc: 0.9677 - val_loss: 0.1561 - val_tp: 70.0000 - val_fp: 1085.0000 - val_tn: 44404.0000 - val_fn: 10.0000 - val_accuracy: 0.9760 - val_precision: 0.0606 - val_recall: 0.8750 - val_auc: 0.9725
Epoch 17/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2197 - tp: 9864.8571 - fp: 691.0000 - tn: 10487.2381 - fn: 1387.3810 - accuracy: 0.9072 - precision: 0.9350 - recall: 0.8766 - auc: 0.9690 - val_loss: 0.1475 - val_tp: 70.0000 - val_fp: 1042.0000 - val_tn: 44447.0000 - val_fn: 10.0000 - val_accuracy: 0.9769 - val_precision: 0.0629 - val_recall: 0.8750 - val_auc: 0.9724
Epoch 18/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2110 - tp: 9788.9524 - fp: 651.8571 - tn: 10585.9524 - fn: 1403.7143 - accuracy: 0.9087 - precision: 0.9392 - recall: 0.8742 - auc: 0.9716 - val_loss: 0.1401 - val_tp: 70.0000 - val_fp: 1020.0000 - val_tn: 44469.0000 - val_fn: 10.0000 - val_accuracy: 0.9774 - val_precision: 0.0642 - val_recall: 0.8750 - val_auc: 0.9718
Epoch 19/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2035 - tp: 9779.1905 - fp: 619.4286 - tn: 10699.0952 - fn: 1332.7619 - accuracy: 0.9117 - precision: 0.9402 - recall: 0.8784 - auc: 0.9733 - val_loss: 0.1340 - val_tp: 70.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 10.0000 - val_accuracy: 0.9776 - val_precision: 0.0648 - val_recall: 0.8750 - val_auc: 0.9721
Epoch 20/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1993 - tp: 9782.3333 - fp: 592.7619 - tn: 10738.1905 - fn: 1317.1905 - accuracy: 0.9156 - precision: 0.9430 - recall: 0.8823 - auc: 0.9748 - val_loss: 0.1282 - val_tp: 70.0000 - val_fp: 999.0000 - val_tn: 44490.0000 - val_fn: 10.0000 - val_accuracy: 0.9779 - val_precision: 0.0655 - val_recall: 0.8750 - val_auc: 0.9721
Epoch 21/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1995 - tp: 9862.0476 - fp: 625.4762 - tn: 10605.0952 - fn: 1337.8571 - accuracy: 0.9119 - precision: 0.9394 - recall: 0.8805 - auc: 0.9747 - val_loss: 0.1228 - val_tp: 70.0000 - val_fp: 965.0000 - val_tn: 44524.0000 - val_fn: 10.0000 - val_accuracy: 0.9786 - val_precision: 0.0676 - val_recall: 0.8750 - val_auc: 0.9718
Epoch 22/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1936 - tp: 9804.0476 - fp: 582.0952 - tn: 10755.0000 - fn: 1289.3333 - accuracy: 0.9168 - precision: 0.9433 - recall: 0.8842 - auc: 0.9764 - val_loss: 0.1179 - val_tp: 70.0000 - val_fp: 944.0000 - val_tn: 44545.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0690 - val_recall: 0.8750 - val_auc: 0.9713
Epoch 23/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1889 - tp: 9952.6667 - fp: 552.3333 - tn: 10589.3810 - fn: 1336.0952 - accuracy: 0.9155 - precision: 0.9472 - recall: 0.8812 - auc: 0.9774 - val_loss: 0.1146 - val_tp: 70.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0692 - val_recall: 0.8750 - val_auc: 0.9715
Epoch 24/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1828 - tp: 9880.4286 - fp: 540.2381 - tn: 10753.0000 - fn: 1256.8095 - accuracy: 0.9207 - precision: 0.9483 - recall: 0.8883 - auc: 0.9787 - val_loss: 0.1116 - val_tp: 71.0000 - val_fp: 961.0000 - val_tn: 44528.0000 - val_fn: 9.0000 - val_accuracy: 0.9787 - val_precision: 0.0688 - val_recall: 0.8875 - val_auc: 0.9712
Epoch 25/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1829 - tp: 9871.3333 - fp: 555.4286 - tn: 10747.9048 - fn: 1255.8095 - accuracy: 0.9200 - precision: 0.9470 - recall: 0.8879 - auc: 0.9791 - val_loss: 0.1082 - val_tp: 71.0000 - val_fp: 953.0000 - val_tn: 44536.0000 - val_fn: 9.0000 - val_accuracy: 0.9789 - val_precision: 0.0693 - val_recall: 0.8875 - val_auc: 0.9713
Restoring model weights from the end of the best epoch.
Epoch 00025: early stopping

Vuelva a verificar el historial de entrenamiento

plot_metrics(resampled_history)

png

Evaluar métricas

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.16524264216423035
tp :  91.0
fp :  1376.0
tn :  55491.0
fn :  4.0
accuracy :  0.9757733345031738
precision :  0.06203135475516319
recall :  0.9578947424888611
auc :  0.9829339385032654

Legitimate Transactions Detected (True Negatives):  55491
Legitimate Transactions Incorrectly Detected (False Positives):  1376
Fraudulent Transactions Missed (False Negatives):  4
Fraudulent Transactions Detected (True Positives):  91
Total Fraudulent Transactions:  95

png

Trazar la República de China

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa0682848d0>

png

Aplicando este tutorial a su problema

La clasificación de datos desequilibrada es una tarea intrínsecamente difícil, ya que hay muy pocas muestras de las que aprender. Siempre debe comenzar con los datos primero y hacer todo lo posible para recopilar tantas muestras como sea posible y pensar en profundidad sobre qué características pueden ser relevantes para que el modelo pueda aprovechar al máximo su clase minoritaria. En algún momento, su modelo puede tener dificultades para mejorar y producir los resultados que desea, por lo que es importante tener en cuenta el contexto de su problema y las compensaciones entre los diferentes tipos de errores.