![]() | ![]() | ![]() | ![]() |
Este tutorial demuestra cómo clasificar un conjunto de datos altamente desequilibrado en el que la cantidad de ejemplos en una clase supera en gran medida a los ejemplos en otra. Trabajará con el conjunto de datos de detección de fraude de tarjetas de crédito alojado en Kaggle. El objetivo es detectar tan solo 492 transacciones fraudulentas de un total de 284.807 transacciones. Utilizará Keras para definir el modelo y los pesos de clase para ayudar al modelo a aprender de los datos desequilibrados. .
Este tutorial contiene código completo para:
- Cargue un archivo CSV usando Pandas.
- Cree conjuntos de entrenamiento, validación y prueba.
- Defina y entrene un modelo usando Keras (incluido el establecimiento de pesos de clase).
- Evalúe el modelo utilizando varias métricas (incluida la precisión y la recuperación).
- Pruebe técnicas comunes para lidiar con datos desequilibrados como:
- Ponderación de clase
- Sobremuestreo
Preparar
import tensorflow as tf
from tensorflow import keras
import os
import tempfile
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Procesamiento y exploración de datos
Descargue el conjunto de datos de fraude de tarjetas de crédito de Kaggle
Pandas es una biblioteca de Python con muchas utilidades útiles para cargar y trabajar con datos estructurados y se puede usar para descargar archivos CSV en un marco de datos.
file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
Examinar el desequilibrio de la etiqueta de clase
Veamos el desequilibrio del conjunto de datos:
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format(
total, pos, 100 * pos / total))
Examples: Total: 284807 Positive: 492 (0.17% of total)
Esto muestra la pequeña fracción de muestras positivas.
Limpiar, dividir y normalizar los datos
Los datos sin procesar tienen algunos problemas. Primero, las columnas Time
y Amount
son demasiado variables para usarlas directamente. Suelta la columna de Time
(ya que no está claro qué significa) y toma el registro de la columna de Amount
para reducir su rango.
cleaned_df = raw_df.copy()
# You don't want the `Time` column.
cleaned_df.pop('Time')
# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)
Divida el conjunto de datos en conjuntos de entrenamiento, validación y prueba. El conjunto de validación se utiliza durante el ajuste del modelo para evaluar la pérdida y cualquier métrica; sin embargo, el modelo no se ajusta a estos datos. El conjunto de prueba no se usa por completo durante la fase de entrenamiento y solo se usa al final para evaluar qué tan bien se generaliza el modelo a nuevos datos. Esto es especialmente importante con conjuntos de datos desequilibrados donde el sobreajuste es una preocupación significativa por la falta de datos de entrenamiento.
# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)
# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))
train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
Normalice las funciones de entrada con sklearn StandardScaler. Esto establecerá la media en 0 y la desviación estándar en 1.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)
print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,) Validation labels shape: (45569,) Test labels shape: (56962,) Training features shape: (182276, 29) Validation features shape: (45569, 29) Test features shape: (56962, 29)
Mira la distribución de datos
A continuación, compare las distribuciones de los ejemplos positivos y negativos de algunas características. Buenas preguntas que debe hacerse en este punto son:
- ¿Tienen sentido estas distribuciones?
- Si. Ha normalizado la entrada y estos se concentran principalmente en el rango
+/- 2
.
- Si. Ha normalizado la entrada y estos se concentran principalmente en el rango
- ¿Puedes ver la diferencia entre las distribuciones?
- Sí, los ejemplos positivos contienen una tasa mucho más alta de valores extremos.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)
sns.jointplot(pos_df['V5'], pos_df['V6'],
kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")
sns.jointplot(neg_df['V5'], neg_df['V6'],
kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. FutureWarning /home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. FutureWarning
Definir el modelo y las métricas
Defina una función que cree una red neuronal simple con una capa oculta densamente conectada, una capa de abandono para reducir el sobreajuste y una capa sigmoidea de salida que devuelva la probabilidad de que una transacción sea fraudulenta:
METRICS = [
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.BinaryAccuracy(name='accuracy'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall'),
keras.metrics.AUC(name='auc'),
]
def make_model(metrics=METRICS, output_bias=None):
if output_bias is not None:
output_bias = tf.keras.initializers.Constant(output_bias)
model = keras.Sequential([
keras.layers.Dense(
16, activation='relu',
input_shape=(train_features.shape[-1],)),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid',
bias_initializer=output_bias),
])
model.compile(
optimizer=keras.optimizers.Adam(lr=1e-3),
loss=keras.losses.BinaryCrossentropy(),
metrics=metrics)
return model
Comprender métricas útiles
Tenga en cuenta que hay algunas métricas definidas anteriormente que el modelo puede calcular y que serán útiles al evaluar el rendimiento.
- Los falsos negativos y los falsos positivos son muestras que se clasificaron incorrectamente
- Los verdaderos negativos y verdaderos positivos son muestras que se clasificaron correctamente
- La precisión es el porcentaje de ejemplos clasificados correctamente> $ \ frac {\ text {muestras verdaderas}} {\ text {muestras totales}} $
- La precisión es el porcentaje de positivos pronosticados que fueron clasificados correctamente> $ \ frac {\ text {verdaderos positivos}} {\ text {verdaderos positivos + falsos positivos}} $
- El recuerdo es el porcentaje de positivos reales que se clasificaron correctamente> $ \ frac {\ text {verdaderos positivos}} {\ text {verdaderos positivos + falsos negativos}} $
- AUC se refiere al área bajo la curva de una curva característica de funcionamiento del receptor (ROC-AUC). Esta métrica es igual a la probabilidad de que un clasificador clasifique una muestra aleatoria positiva más alta que una muestra aleatoria negativa.
Lee mas:
Modelo de línea de base
Construye el modelo
Ahora cree y entrene su modelo usando la función que se definió anteriormente. Tenga en cuenta que el modelo se ajusta con un tamaño de lote mayor que el predeterminado de 2048, esto es importante para garantizar que cada lote tenga una probabilidad decente de contener algunas muestras positivas. Si el tamaño del lote fuera demasiado pequeño, probablemente no tendrían transacciones fraudulentas de las que aprender.
EPOCHS = 100
BATCH_SIZE = 2048
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_auc',
verbose=1,
patience=10,
mode='max',
restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 480 _________________________________________________________________ dropout (Dropout) (None, 16) 0 _________________________________________________________________ dense_1 (Dense) (None, 1) 17 ================================================================= Total params: 497 Trainable params: 497 Non-trainable params: 0 _________________________________________________________________
Prueba de ejecución del modelo:
model.predict(train_features[:10])
array([[0.6983068 ], [0.7546284 ], [0.73785573], [0.7908986 ], [0.51232255], [0.752192 ], [0.7387281 ], [0.9410955 ], [0.809352 ], [0.6911539 ]], dtype=float32)
Opcional: establezca el sesgo inicial correcto.
Estas suposiciones iniciales no son buenas. Sabes que el conjunto de datos está desequilibrado. Establezca el sesgo de la capa de salida para reflejar eso (Ver: Una receta para entrenar redes neuronales: "init well" ). Esto puede ayudar con la convergencia inicial.
Con la inicialización de sesgo predeterminada, la pérdida debe ser de aproximadamente math.log(2) = 0.69314
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.5998
El sesgo correcto para establecer se puede derivar de:
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
Establezca eso como el sesgo inicial, y el modelo dará conjeturas iniciales mucho más razonables.
Debe estar cerca de: pos/total = 0.0018
model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00168876], [0.00081124], [0.00087036], [0.00241473], [0.00133016], [0.00121771], [0.00079989], [0.00079692], [0.00257652], [0.00104385]], dtype=float32)
Con esta inicialización, la pérdida inicial debería ser aproximadamente:
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0174
Esta pérdida inicial es aproximadamente 50 veces menor que si hubiera sido con una inicialización ingenua.
De esta manera, el modelo no necesita pasar las primeras épocas simplemente aprendiendo que los ejemplos positivos son poco probables. Esto también facilita la lectura de gráficos de la pérdida durante el entrenamiento.
Punto de control de los pesos iniciales
Para que las distintas ejecuciones de entrenamiento sean más comparables, mantenga los pesos de este modelo inicial en un archivo de punto de control y cárguelos en cada modelo antes del entrenamiento.
initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)
Confirme que la corrección de sesgo ayuda
Antes de continuar, confirme rápidamente que la inicialización cuidadosa del sesgo realmente ayudó.
Entrene el modelo durante 20 épocas, con y sin esta inicialización cuidadosa, y compare las pérdidas:
model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
def plot_loss(history, label, n):
# Use a log scale on y-axis to show the wide range of values.
plt.semilogy(history.epoch, history.history['loss'],
color=colors[n], label='Train ' + label)
plt.semilogy(history.epoch, history.history['val_loss'],
color=colors[n], label='Val ' + label,
linestyle="--")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
La figura anterior lo deja claro: en términos de pérdida de validación, en este problema, esta inicialización cuidadosa da una clara ventaja.
Entrena el modelo
model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(val_features, val_labels))
Epoch 1/100 90/90 [==============================] - 3s 15ms/step - loss: 0.0186 - tp: 64.0000 - fp: 25.1978 - tn: 139431.9780 - fn: 188.3956 - accuracy: 0.9985 - precision: 0.7214 - recall: 0.3037 - auc: 0.6752 - val_loss: 0.0109 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45489.0000 - val_fn: 80.0000 - val_accuracy: 0.9982 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.6977 Epoch 2/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0085 - tp: 25.4505 - fp: 5.6703 - tn: 93973.1538 - fn: 136.2967 - accuracy: 0.9985 - precision: 0.6137 - recall: 0.1108 - auc: 0.8194 - val_loss: 0.0051 - val_tp: 27.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 53.0000 - val_accuracy: 0.9986 - val_precision: 0.7500 - val_recall: 0.3375 - val_auc: 0.9184 Epoch 3/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0057 - tp: 68.4505 - fp: 9.9231 - tn: 93970.2198 - fn: 91.9780 - accuracy: 0.9989 - precision: 0.8896 - recall: 0.4278 - auc: 0.9133 - val_loss: 0.0043 - val_tp: 44.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 36.0000 - val_accuracy: 0.9990 - val_precision: 0.8148 - val_recall: 0.5500 - val_auc: 0.9185 Epoch 4/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0051 - tp: 86.9451 - fp: 17.4505 - tn: 93959.1429 - fn: 77.0330 - accuracy: 0.9990 - precision: 0.8422 - recall: 0.5397 - auc: 0.9191 - val_loss: 0.0040 - val_tp: 51.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8361 - val_recall: 0.6375 - val_auc: 0.9248 Epoch 5/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 83.8571 - fp: 13.0110 - tn: 93965.0879 - fn: 78.6154 - accuracy: 0.9990 - precision: 0.8890 - recall: 0.5114 - auc: 0.9212 - val_loss: 0.0039 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 25.0000 - val_accuracy: 0.9992 - val_precision: 0.8462 - val_recall: 0.6875 - val_auc: 0.9247 Epoch 6/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0049 - tp: 83.9890 - fp: 11.9451 - tn: 93976.1648 - fn: 68.4725 - accuracy: 0.9992 - precision: 0.8793 - recall: 0.5693 - auc: 0.9022 - val_loss: 0.0038 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 7/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 89.0549 - fp: 19.6484 - tn: 93958.4835 - fn: 73.3846 - accuracy: 0.9991 - precision: 0.8187 - recall: 0.5711 - auc: 0.8824 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 27.0000 - val_accuracy: 0.9992 - val_precision: 0.8413 - val_recall: 0.6625 - val_auc: 0.9309 Epoch 8/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 92.4396 - fp: 11.4835 - tn: 93970.6374 - fn: 66.0110 - accuracy: 0.9992 - precision: 0.9188 - recall: 0.5617 - auc: 0.9298 - val_loss: 0.0036 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9309 Epoch 9/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 100.4396 - fp: 14.7802 - tn: 93963.6044 - fn: 61.7473 - accuracy: 0.9992 - precision: 0.8725 - recall: 0.6228 - auc: 0.9167 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 10/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 94.8132 - fp: 12.7253 - tn: 93958.3187 - fn: 74.7143 - accuracy: 0.9991 - precision: 0.8935 - recall: 0.5710 - auc: 0.9196 - val_loss: 0.0035 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247 Epoch 11/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0047 - tp: 97.0440 - fp: 16.3956 - tn: 93958.9231 - fn: 68.2088 - accuracy: 0.9991 - precision: 0.8591 - recall: 0.5946 - auc: 0.9160 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 12/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 101.0769 - fp: 13.1868 - tn: 93958.7143 - fn: 67.5934 - accuracy: 0.9992 - precision: 0.8799 - recall: 0.6127 - auc: 0.9202 - val_loss: 0.0034 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247 Epoch 13/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0043 - tp: 98.0769 - fp: 16.0330 - tn: 93961.6703 - fn: 64.7912 - accuracy: 0.9992 - precision: 0.8536 - recall: 0.6112 - auc: 0.9154 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 21.0000 - val_accuracy: 0.9993 - val_precision: 0.8676 - val_recall: 0.7375 - val_auc: 0.9247 Epoch 14/100 90/90 [==============================] - 1s 9ms/step - loss: 0.0050 - tp: 93.5495 - fp: 15.4286 - tn: 93961.4615 - fn: 70.1319 - accuracy: 0.9991 - precision: 0.8590 - recall: 0.5563 - auc: 0.8916 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Epoch 15/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0042 - tp: 90.2198 - fp: 15.4286 - tn: 93968.0989 - fn: 66.8242 - accuracy: 0.9992 - precision: 0.8524 - recall: 0.5813 - auc: 0.9270 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Epoch 16/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 96.4835 - fp: 14.7363 - tn: 93960.4396 - fn: 68.9121 - accuracy: 0.9991 - precision: 0.8754 - recall: 0.5727 - auc: 0.9218 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8611 - val_recall: 0.7750 - val_auc: 0.9247 Epoch 17/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0044 - tp: 98.7033 - fp: 16.4835 - tn: 93958.9231 - fn: 66.4615 - accuracy: 0.9991 - precision: 0.8674 - recall: 0.5985 - auc: 0.9108 - val_loss: 0.0032 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Restoring model weights from the end of the best epoch. Epoch 00017: early stopping
Consultar historial de entrenamiento
En esta sección, producirá gráficos de la precisión y pérdida de su modelo en el conjunto de entrenamiento y validación. Estos son útiles para verificar el sobreajuste, sobre el cual puede obtener más información en este tutorial .
Además, puede producir estos gráficos para cualquiera de las métricas que creó anteriormente. Los falsos negativos se incluyen como ejemplo.
def plot_metrics(history):
metrics = ['loss', 'auc', 'precision', 'recall']
for n, metric in enumerate(metrics):
name = metric.replace("_"," ").capitalize()
plt.subplot(2,2,n+1)
plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
plt.plot(history.epoch, history.history['val_'+metric],
color=colors[0], linestyle="--", label='Val')
plt.xlabel('Epoch')
plt.ylabel(name)
if metric == 'loss':
plt.ylim([0, plt.ylim()[1]])
elif metric == 'auc':
plt.ylim([0.8,1])
else:
plt.ylim([0,1])
plt.legend()
plot_metrics(baseline_history)
Evaluar métricas
Puede utilizar una matriz de confusión para resumir las etiquetas reales frente a las predichas, donde el eje X es la etiqueta predicha y el eje Y es la etiqueta real.
train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
cm = confusion_matrix(labels, predictions > p)
plt.figure(figsize=(5,5))
sns.heatmap(cm, annot=True, fmt="d")
plt.title('Confusion matrix @{:.2f}'.format(p))
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
print('Total Fraudulent Transactions: ', np.sum(cm[1]))
Evalúe su modelo en el conjunto de datos de prueba y muestre los resultados de las métricas que creó anteriormente.
baseline_results = model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_baseline)
loss : 0.0034731989726424217 tp : 56.0 fp : 12.0 tn : 56855.0 fn : 39.0 accuracy : 0.9991046786308289 precision : 0.8235294222831726 recall : 0.5894736647605896 auc : 0.9418253898620605 Legitimate Transactions Detected (True Negatives): 56855 Legitimate Transactions Incorrectly Detected (False Positives): 12 Fraudulent Transactions Missed (False Negatives): 39 Fraudulent Transactions Detected (True Positives): 56 Total Fraudulent Transactions: 95
Si el modelo hubiera predicho todo perfectamente, esta sería una matriz diagonal donde los valores fuera de la diagonal principal, que indican predicciones incorrectas, serían cero. En este caso, la matriz muestra que tiene relativamente pocos falsos positivos, lo que significa que hubo relativamente pocas transacciones legítimas que se marcaron incorrectamente. Sin embargo, es probable que desee tener incluso menos falsos negativos a pesar del costo de aumentar el número de falsos positivos. Esta compensación puede ser preferible porque los falsos negativos permitirían que se realicen transacciones fraudulentas, mientras que los falsos positivos pueden hacer que se envíe un correo electrónico a un cliente para pedirle que verifique la actividad de su tarjeta.
Trazar la República de China
Ahora traza la República de China . Este gráfico es útil porque muestra, de un vistazo, el rango de rendimiento que puede alcanzar el modelo simplemente ajustando el umbral de salida.
def plot_roc(name, labels, predictions, **kwargs):
fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)
plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
plt.xlabel('False positives [%]')
plt.ylabel('True positives [%]')
plt.xlim([-0.5,20])
plt.ylim([80,100.5])
plt.grid(True)
ax = plt.gca()
ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4786ca748>
Parece que la precisión es relativamente alta, pero la recuperación y el área bajo la curva ROC (AUC) no son tan altas como le gustaría. Los clasificadores a menudo enfrentan desafíos cuando intentan maximizar tanto la precisión como la recuperación, lo cual es especialmente cierto cuando se trabaja con conjuntos de datos desequilibrados. Es importante considerar los costos de los diferentes tipos de errores en el contexto del problema que le preocupa. En este ejemplo, un falso negativo (se pierde una transacción fraudulenta) puede tener un costo financiero, mientras que un falso positivo (una transacción se marca incorrectamente como fraudulenta) puede disminuir la felicidad del usuario.
Pesos de clase
Calcular pesos de clase
El objetivo es identificar transacciones fraudulentas, pero no tiene muchas de esas muestras positivas con las que trabajar, por lo que le conviene que el clasificador tenga en cuenta los pocos ejemplos disponibles. Puede hacer esto pasando pesos de Keras para cada clase a través de un parámetro. Esto hará que el modelo "preste más atención" a ejemplos de una clase subrepresentada.
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0
weight_for_1 = (1 / pos)*(total)/2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50 Weight for class 1: 289.44
Entrena un modelo con pesos de clase
Ahora intente volver a entrenar y evaluar el modelo con ponderaciones de clase para ver cómo afecta eso a las predicciones.
weighted_model = make_model()
weighted_model.load_weights(initial_weights)
weighted_history = weighted_model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(val_features, val_labels),
# The class weights go here
class_weight=class_weight)
Epoch 1/100 90/90 [==============================] - 3s 15ms/step - loss: 3.7596 - tp: 56.9890 - fp: 76.4286 - tn: 150769.4286 - fn: 199.7253 - accuracy: 0.9983 - precision: 0.4816 - recall: 0.2595 - auc: 0.7156 - val_loss: 0.0096 - val_tp: 2.0000 - val_fp: 1.0000 - val_tn: 45488.0000 - val_fn: 78.0000 - val_accuracy: 0.9983 - val_precision: 0.6667 - val_recall: 0.0250 - val_auc: 0.8999 Epoch 2/100 90/90 [==============================] - 1s 7ms/step - loss: 1.5426 - tp: 57.2308 - fp: 247.2088 - tn: 93722.7582 - fn: 113.3736 - accuracy: 0.9964 - precision: 0.1813 - recall: 0.2976 - auc: 0.8189 - val_loss: 0.0089 - val_tp: 54.0000 - val_fp: 18.0000 - val_tn: 45471.0000 - val_fn: 26.0000 - val_accuracy: 0.9990 - val_precision: 0.7500 - val_recall: 0.6750 - val_auc: 0.9306 Epoch 3/100 90/90 [==============================] - 1s 7ms/step - loss: 0.8711 - tp: 95.8352 - fp: 494.1209 - tn: 93479.8681 - fn: 70.7473 - accuracy: 0.9943 - precision: 0.1692 - recall: 0.6059 - auc: 0.8912 - val_loss: 0.0121 - val_tp: 66.0000 - val_fp: 32.0000 - val_tn: 45457.0000 - val_fn: 14.0000 - val_accuracy: 0.9990 - val_precision: 0.6735 - val_recall: 0.8250 - val_auc: 0.9426 Epoch 4/100 90/90 [==============================] - 1s 7ms/step - loss: 0.6835 - tp: 108.5165 - fp: 794.1648 - tn: 93183.5055 - fn: 54.3846 - accuracy: 0.9912 - precision: 0.1191 - recall: 0.6530 - auc: 0.8987 - val_loss: 0.0163 - val_tp: 67.0000 - val_fp: 54.0000 - val_tn: 45435.0000 - val_fn: 13.0000 - val_accuracy: 0.9985 - val_precision: 0.5537 - val_recall: 0.8375 - val_auc: 0.9556 Epoch 5/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4713 - tp: 126.3626 - fp: 1149.3407 - tn: 92827.5275 - fn: 37.3407 - accuracy: 0.9878 - precision: 0.0992 - recall: 0.7828 - auc: 0.9329 - val_loss: 0.0214 - val_tp: 67.0000 - val_fp: 95.0000 - val_tn: 45394.0000 - val_fn: 13.0000 - val_accuracy: 0.9976 - val_precision: 0.4136 - val_recall: 0.8375 - val_auc: 0.9588 Epoch 6/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4194 - tp: 125.7912 - fp: 1550.7253 - tn: 92430.8791 - fn: 33.1758 - accuracy: 0.9837 - precision: 0.0769 - recall: 0.7990 - auc: 0.9373 - val_loss: 0.0270 - val_tp: 67.0000 - val_fp: 147.0000 - val_tn: 45342.0000 - val_fn: 13.0000 - val_accuracy: 0.9965 - val_precision: 0.3131 - val_recall: 0.8375 - val_auc: 0.9626 Epoch 7/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4226 - tp: 127.0659 - fp: 2000.6374 - tn: 91978.2857 - fn: 34.5824 - accuracy: 0.9788 - precision: 0.0567 - recall: 0.7672 - auc: 0.9351 - val_loss: 0.0348 - val_tp: 67.0000 - val_fp: 224.0000 - val_tn: 45265.0000 - val_fn: 13.0000 - val_accuracy: 0.9948 - val_precision: 0.2302 - val_recall: 0.8375 - val_auc: 0.9656 Epoch 8/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2990 - tp: 137.7473 - fp: 2564.6154 - tn: 91411.7363 - fn: 26.4725 - accuracy: 0.9729 - precision: 0.0528 - recall: 0.8483 - auc: 0.9609 - val_loss: 0.0457 - val_tp: 69.0000 - val_fp: 406.0000 - val_tn: 45083.0000 - val_fn: 11.0000 - val_accuracy: 0.9908 - val_precision: 0.1453 - val_recall: 0.8625 - val_auc: 0.9691 Epoch 9/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3165 - tp: 125.0330 - fp: 3192.7473 - tn: 90795.8462 - fn: 26.9451 - accuracy: 0.9662 - precision: 0.0375 - recall: 0.8237 - auc: 0.9518 - val_loss: 0.0568 - val_tp: 69.0000 - val_fp: 654.0000 - val_tn: 44835.0000 - val_fn: 11.0000 - val_accuracy: 0.9854 - val_precision: 0.0954 - val_recall: 0.8625 - val_auc: 0.9703 Epoch 10/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2917 - tp: 139.2527 - fp: 3709.6154 - tn: 90270.3626 - fn: 21.3407 - accuracy: 0.9606 - precision: 0.0356 - recall: 0.8714 - auc: 0.9546 - val_loss: 0.0640 - val_tp: 70.0000 - val_fp: 755.0000 - val_tn: 44734.0000 - val_fn: 10.0000 - val_accuracy: 0.9832 - val_precision: 0.0848 - val_recall: 0.8750 - val_auc: 0.9713 Epoch 11/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3109 - tp: 155.7363 - fp: 3838.3736 - tn: 90123.9670 - fn: 22.4945 - accuracy: 0.9590 - precision: 0.0415 - recall: 0.8730 - auc: 0.9555 - val_loss: 0.0703 - val_tp: 71.0000 - val_fp: 835.0000 - val_tn: 44654.0000 - val_fn: 9.0000 - val_accuracy: 0.9815 - val_precision: 0.0784 - val_recall: 0.8875 - val_auc: 0.9719 Epoch 12/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3007 - tp: 134.9121 - fp: 4045.0549 - tn: 89938.1758 - fn: 22.4286 - accuracy: 0.9566 - precision: 0.0320 - recall: 0.8712 - auc: 0.9494 - val_loss: 0.0756 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9000 - val_auc: 0.9716 Epoch 13/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2083 - tp: 143.2308 - fp: 4041.6593 - tn: 89938.9670 - fn: 16.7143 - accuracy: 0.9567 - precision: 0.0360 - recall: 0.9154 - auc: 0.9734 - val_loss: 0.0765 - val_tp: 72.0000 - val_fp: 916.0000 - val_tn: 44573.0000 - val_fn: 8.0000 - val_accuracy: 0.9797 - val_precision: 0.0729 - val_recall: 0.9000 - val_auc: 0.9726 Epoch 14/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2448 - tp: 143.6703 - fp: 4119.2967 - tn: 89857.6154 - fn: 19.9890 - accuracy: 0.9562 - precision: 0.0341 - recall: 0.8944 - auc: 0.9655 - val_loss: 0.0811 - val_tp: 72.0000 - val_fp: 992.0000 - val_tn: 44497.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0677 - val_recall: 0.9000 - val_auc: 0.9724 Epoch 15/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2362 - tp: 141.0769 - fp: 4205.1429 - tn: 89776.9670 - fn: 17.3846 - accuracy: 0.9545 - precision: 0.0316 - recall: 0.8889 - auc: 0.9665 - val_loss: 0.0835 - val_tp: 72.0000 - val_fp: 1019.0000 - val_tn: 44470.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0660 - val_recall: 0.9000 - val_auc: 0.9729 Epoch 16/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2100 - tp: 135.4945 - fp: 4278.8242 - tn: 89707.5495 - fn: 18.7033 - accuracy: 0.9542 - precision: 0.0289 - recall: 0.8980 - auc: 0.9717 - val_loss: 0.0922 - val_tp: 72.0000 - val_fp: 1117.0000 - val_tn: 44372.0000 - val_fn: 8.0000 - val_accuracy: 0.9753 - val_precision: 0.0606 - val_recall: 0.9000 - val_auc: 0.9728 Epoch 17/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2061 - tp: 147.9451 - fp: 4488.5604 - tn: 89486.7912 - fn: 17.2747 - accuracy: 0.9519 - precision: 0.0328 - recall: 0.9026 - auc: 0.9754 - val_loss: 0.0911 - val_tp: 72.0000 - val_fp: 1104.0000 - val_tn: 44385.0000 - val_fn: 8.0000 - val_accuracy: 0.9756 - val_precision: 0.0612 - val_recall: 0.9000 - val_auc: 0.9730 Epoch 18/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3032 - tp: 143.0989 - fp: 4367.9890 - tn: 89610.3956 - fn: 19.0879 - accuracy: 0.9529 - precision: 0.0312 - recall: 0.8782 - auc: 0.9486 - val_loss: 0.0878 - val_tp: 72.0000 - val_fp: 1037.0000 - val_tn: 44452.0000 - val_fn: 8.0000 - val_accuracy: 0.9771 - val_precision: 0.0649 - val_recall: 0.9000 - val_auc: 0.9761 Epoch 19/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2713 - tp: 148.5165 - fp: 4079.8462 - tn: 89894.0110 - fn: 18.1978 - accuracy: 0.9565 - precision: 0.0361 - recall: 0.8861 - auc: 0.9635 - val_loss: 0.0868 - val_tp: 72.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9762 Epoch 20/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2623 - tp: 144.5055 - fp: 3991.8681 - tn: 89984.6484 - fn: 19.5495 - accuracy: 0.9575 - precision: 0.0332 - recall: 0.8837 - auc: 0.9587 - val_loss: 0.0888 - val_tp: 72.0000 - val_fp: 1030.0000 - val_tn: 44459.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0653 - val_recall: 0.9000 - val_auc: 0.9761 Epoch 21/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3103 - tp: 136.6154 - fp: 3966.2527 - tn: 90015.5934 - fn: 22.1099 - accuracy: 0.9577 - precision: 0.0309 - recall: 0.8330 - auc: 0.9373 - val_loss: 0.0886 - val_tp: 72.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9764 Epoch 22/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2215 - tp: 141.6484 - fp: 3872.5055 - tn: 90108.2527 - fn: 18.1648 - accuracy: 0.9584 - precision: 0.0347 - recall: 0.8935 - auc: 0.9698 - val_loss: 0.0862 - val_tp: 72.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 8.0000 - val_accuracy: 0.9785 - val_precision: 0.0690 - val_recall: 0.9000 - val_auc: 0.9776 Epoch 23/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2214 - tp: 141.6593 - fp: 3877.3626 - tn: 90102.6593 - fn: 18.8901 - accuracy: 0.9585 - precision: 0.0354 - recall: 0.8872 - auc: 0.9685 - val_loss: 0.0842 - val_tp: 72.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9792 - val_precision: 0.0713 - val_recall: 0.9000 - val_auc: 0.9777 Epoch 24/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1873 - tp: 138.3956 - fp: 3647.7473 - tn: 90337.8681 - fn: 16.5604 - accuracy: 0.9611 - precision: 0.0340 - recall: 0.9088 - auc: 0.9743 - val_loss: 0.0843 - val_tp: 73.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 7.0000 - val_accuracy: 0.9793 - val_precision: 0.0722 - val_recall: 0.9125 - val_auc: 0.9779 Epoch 25/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2038 - tp: 140.3407 - fp: 3553.5824 - tn: 90428.6813 - fn: 17.9670 - accuracy: 0.9620 - precision: 0.0383 - recall: 0.8946 - auc: 0.9739 - val_loss: 0.0854 - val_tp: 73.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0719 - val_recall: 0.9125 - val_auc: 0.9777 Epoch 26/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2003 - tp: 151.4725 - fp: 3495.4725 - tn: 90479.3516 - fn: 14.2747 - accuracy: 0.9625 - precision: 0.0406 - recall: 0.9196 - auc: 0.9732 - val_loss: 0.0819 - val_tp: 73.0000 - val_fp: 895.0000 - val_tn: 44594.0000 - val_fn: 7.0000 - val_accuracy: 0.9802 - val_precision: 0.0754 - val_recall: 0.9125 - val_auc: 0.9778 Epoch 27/100 90/90 [==============================] - 1s 9ms/step - loss: 0.2111 - tp: 138.2857 - fp: 3438.2088 - tn: 90547.7253 - fn: 16.3516 - accuracy: 0.9635 - precision: 0.0382 - recall: 0.9021 - auc: 0.9690 - val_loss: 0.0865 - val_tp: 73.0000 - val_fp: 940.0000 - val_tn: 44549.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0721 - val_recall: 0.9125 - val_auc: 0.9775 Epoch 28/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2115 - tp: 146.6813 - fp: 3796.8022 - tn: 90178.5604 - fn: 18.5275 - accuracy: 0.9595 - precision: 0.0357 - recall: 0.8799 - auc: 0.9740 - val_loss: 0.0914 - val_tp: 73.0000 - val_fp: 996.0000 - val_tn: 44493.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0683 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 29/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2085 - tp: 142.0220 - fp: 3757.9341 - tn: 90222.3956 - fn: 18.2198 - accuracy: 0.9597 - precision: 0.0368 - recall: 0.8941 - auc: 0.9738 - val_loss: 0.0890 - val_tp: 73.0000 - val_fp: 962.0000 - val_tn: 44527.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0705 - val_recall: 0.9125 - val_auc: 0.9775 Epoch 30/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1559 - tp: 153.3077 - fp: 3571.6044 - tn: 90404.6264 - fn: 11.0330 - accuracy: 0.9617 - precision: 0.0420 - recall: 0.9426 - auc: 0.9811 - val_loss: 0.0852 - val_tp: 73.0000 - val_fp: 901.0000 - val_tn: 44588.0000 - val_fn: 7.0000 - val_accuracy: 0.9801 - val_precision: 0.0749 - val_recall: 0.9125 - val_auc: 0.9778 Epoch 31/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2635 - tp: 142.1209 - fp: 3672.6264 - tn: 90304.4725 - fn: 21.3516 - accuracy: 0.9613 - precision: 0.0376 - recall: 0.8706 - auc: 0.9641 - val_loss: 0.0930 - val_tp: 73.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 7.0000 - val_accuracy: 0.9777 - val_precision: 0.0673 - val_recall: 0.9125 - val_auc: 0.9773 Epoch 32/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1858 - tp: 155.9121 - fp: 3632.0440 - tn: 90338.0769 - fn: 14.5385 - accuracy: 0.9609 - precision: 0.0410 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.0897 - val_tp: 73.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0699 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 33/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2525 - tp: 152.4945 - fp: 3743.6264 - tn: 90226.1209 - fn: 18.3297 - accuracy: 0.9599 - precision: 0.0398 - recall: 0.8872 - auc: 0.9646 - val_loss: 0.0878 - val_tp: 73.0000 - val_fp: 930.0000 - val_tn: 44559.0000 - val_fn: 7.0000 - val_accuracy: 0.9794 - val_precision: 0.0728 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 34/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2286 - tp: 142.4615 - fp: 3510.0659 - tn: 90469.1209 - fn: 18.9231 - accuracy: 0.9626 - precision: 0.0379 - recall: 0.8608 - auc: 0.9694 - val_loss: 0.0839 - val_tp: 73.0000 - val_fp: 889.0000 - val_tn: 44600.0000 - val_fn: 7.0000 - val_accuracy: 0.9803 - val_precision: 0.0759 - val_recall: 0.9125 - val_auc: 0.9773 Restoring model weights from the end of the best epoch. Epoch 00034: early stopping
Consultar historial de entrenamiento
plot_metrics(weighted_history)
Evaluar métricas
train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_weighted)
loss : 0.08238842338323593 tp : 89.0 fp : 1186.0 tn : 55681.0 fn : 6.0 accuracy : 0.9790737628936768 precision : 0.06980392336845398 recall : 0.9368420839309692 auc : 0.98465895652771 Legitimate Transactions Detected (True Negatives): 55681 Legitimate Transactions Incorrectly Detected (False Positives): 1186 Fraudulent Transactions Missed (False Negatives): 6 Fraudulent Transactions Detected (True Positives): 89 Total Fraudulent Transactions: 95
Aquí puede ver que con las ponderaciones de clase, la exactitud y la precisión son menores porque hay más falsos positivos, pero a la inversa, la recuperación y el AUC son mayores porque el modelo también encontró más verdaderos positivos. A pesar de tener una precisión menor, este modelo tiene una mayor recuperación (e identifica más transacciones fraudulentas). Por supuesto, ambos tipos de error tienen un costo (tampoco querrá molestar a los usuarios marcando demasiadas transacciones legítimas como fraudulentas). Considere cuidadosamente las compensaciones entre estos diferentes tipos de errores para su aplicación.
Trazar la República de China
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4fc124e48>
Sobremuestreo
Sobremuestrear la clase minoritaria
Un enfoque relacionado sería volver a muestrear el conjunto de datos sobremuestreando la clase minoritaria.
pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]
pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
Usando NumPy
Puede equilibrar el conjunto de datos manualmente eligiendo el número correcto de índices aleatorios de los ejemplos positivos:
ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
res_pos_features.shape
(181959, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]
resampled_features.shape
(363918, 29)
Usando tf.data
Si está utilizando tf.data
la forma más fácil de producir ejemplos equilibrados es comenzar con un conjunto de datos positive
y negative
y fusionarlos. Consulte la guía tf.data para obtener más ejemplos.
BUFFER_SIZE = 100000
def make_ds(features, labels):
ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
ds = ds.shuffle(BUFFER_SIZE).repeat()
return ds
pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
Cada conjunto de datos proporciona pares (feature, label)
:
for features, label in pos_ds.take(1):
print("Features:\n", features.numpy())
print()
print("Label: ", label.numpy())
Features: [-2.19450702e+00 1.14157653e+00 -1.53868568e+00 -3.35790031e-01 -8.45956971e-01 -1.57966490e+00 -1.68436379e+00 2.25259298e-01 2.62655847e-01 -3.51815449e+00 2.45653756e+00 -3.62123216e+00 -1.02812838e+00 -5.00000000e+00 1.79960408e+00 -3.72675038e+00 -5.00000000e+00 -2.22257320e+00 -1.24126989e-02 -9.14798839e-01 7.33686754e-01 -9.34731229e-02 -1.74683429e+00 4.44394133e-01 -3.98093307e-02 -1.99985035e+00 -2.27011783e+00 1.83260825e-03 5.69563522e-01] Label: 1
Combina los dos juntos usando experimental.sample_from_datasets
:
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
print(label.numpy().mean())
0.5
Para usar este conjunto de datos, necesitará la cantidad de pasos por época.
La definición de "época" en este caso es menos clara. Supongamos que es la cantidad de lotes necesarios para ver cada ejemplo negativo una vez:
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
Entrene con los datos sobremuestreados
Ahora intente entrenar el modelo con el conjunto de datos remuestreados en lugar de usar ponderaciones de clase para ver cómo se comparan estos métodos.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)
resampled_history = resampled_model.fit(
resampled_ds,
epochs=EPOCHS,
steps_per_epoch=resampled_steps_per_epoch,
callbacks=[early_stopping],
validation_data=val_ds)
Epoch 1/100 278/278 [==============================] - 8s 24ms/step - loss: 0.8118 - tp: 98159.0036 - fp: 34959.3333 - tn: 165642.9498 - fn: 44913.3728 - accuracy: 0.7582 - precision: 0.6573 - recall: 0.5935 - auc: 0.7814 - val_loss: 0.1785 - val_tp: 70.0000 - val_fp: 1177.0000 - val_tn: 44312.0000 - val_fn: 10.0000 - val_accuracy: 0.9740 - val_precision: 0.0561 - val_recall: 0.8750 - val_auc: 0.9722 Epoch 2/100 278/278 [==============================] - 6s 21ms/step - loss: 0.2142 - tp: 126477.2616 - fp: 8422.5305 - tn: 134479.1362 - fn: 17333.7312 - accuracy: 0.9075 - precision: 0.9341 - recall: 0.8774 - auc: 0.9705 - val_loss: 0.1007 - val_tp: 71.0000 - val_fp: 919.0000 - val_tn: 44570.0000 - val_fn: 9.0000 - val_accuracy: 0.9796 - val_precision: 0.0717 - val_recall: 0.8875 - val_auc: 0.9718 Epoch 3/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1652 - tp: 128160.7455 - fp: 6029.4875 - tn: 137362.8602 - fn: 15159.5663 - accuracy: 0.9255 - precision: 0.9544 - recall: 0.8936 - auc: 0.9831 - val_loss: 0.0775 - val_tp: 71.0000 - val_fp: 802.0000 - val_tn: 44687.0000 - val_fn: 9.0000 - val_accuracy: 0.9822 - val_precision: 0.0813 - val_recall: 0.8875 - val_auc: 0.9727 Epoch 4/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1451 - tp: 129069.5376 - fp: 5223.8925 - tn: 138326.3763 - fn: 14092.8530 - accuracy: 0.9322 - precision: 0.9611 - recall: 0.9007 - auc: 0.9872 - val_loss: 0.0676 - val_tp: 71.0000 - val_fp: 756.0000 - val_tn: 44733.0000 - val_fn: 9.0000 - val_accuracy: 0.9832 - val_precision: 0.0859 - val_recall: 0.8875 - val_auc: 0.9766 Epoch 5/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1328 - tp: 130130.0502 - fp: 5189.8996 - tn: 138364.5878 - fn: 13028.1219 - accuracy: 0.9360 - precision: 0.9611 - recall: 0.9085 - auc: 0.9895 - val_loss: 0.0616 - val_tp: 72.0000 - val_fp: 763.0000 - val_tn: 44726.0000 - val_fn: 8.0000 - val_accuracy: 0.9831 - val_precision: 0.0862 - val_recall: 0.9000 - val_auc: 0.9748 Epoch 6/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1230 - tp: 131300.7706 - fp: 4925.0179 - tn: 138645.5125 - fn: 11841.3584 - accuracy: 0.9413 - precision: 0.9635 - recall: 0.9170 - auc: 0.9912 - val_loss: 0.0566 - val_tp: 72.0000 - val_fp: 754.0000 - val_tn: 44735.0000 - val_fn: 8.0000 - val_accuracy: 0.9833 - val_precision: 0.0872 - val_recall: 0.9000 - val_auc: 0.9759 Epoch 7/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1150 - tp: 132477.4875 - fp: 4822.2509 - tn: 138449.6989 - fn: 10963.2222 - accuracy: 0.9445 - precision: 0.9647 - recall: 0.9228 - auc: 0.9925 - val_loss: 0.0516 - val_tp: 72.0000 - val_fp: 711.0000 - val_tn: 44778.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0920 - val_recall: 0.9000 - val_auc: 0.9777 Epoch 8/100 278/278 [==============================] - 6s 23ms/step - loss: 0.1069 - tp: 132886.1828 - fp: 4656.9570 - tn: 138848.8029 - fn: 10320.7168 - accuracy: 0.9474 - precision: 0.9660 - recall: 0.9275 - auc: 0.9936 - val_loss: 0.0462 - val_tp: 71.0000 - val_fp: 667.0000 - val_tn: 44822.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.0962 - val_recall: 0.8875 - val_auc: 0.9695 Epoch 9/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1006 - tp: 133579.5950 - fp: 4598.8602 - tn: 138524.1039 - fn: 10010.1004 - accuracy: 0.9489 - precision: 0.9668 - recall: 0.9298 - auc: 0.9944 - val_loss: 0.0419 - val_tp: 71.0000 - val_fp: 639.0000 - val_tn: 44850.0000 - val_fn: 9.0000 - val_accuracy: 0.9858 - val_precision: 0.1000 - val_recall: 0.8875 - val_auc: 0.9703 Epoch 10/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0955 - tp: 134632.2903 - fp: 4540.9355 - tn: 138577.5520 - fn: 8961.8817 - accuracy: 0.9521 - precision: 0.9676 - recall: 0.9358 - auc: 0.9950 - val_loss: 0.0396 - val_tp: 71.0000 - val_fp: 634.0000 - val_tn: 44855.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1007 - val_recall: 0.8875 - val_auc: 0.9664 Epoch 11/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0908 - tp: 140213.1900 - fp: 5993.7849 - tn: 136973.6380 - fn: 3532.0466 - accuracy: 0.9663 - precision: 0.9588 - recall: 0.9747 - auc: 0.9955 - val_loss: 0.0359 - val_tp: 71.0000 - val_fp: 596.0000 - val_tn: 44893.0000 - val_fn: 9.0000 - val_accuracy: 0.9867 - val_precision: 0.1064 - val_recall: 0.8875 - val_auc: 0.9669 Epoch 12/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0857 - tp: 140261.6918 - fp: 5979.1470 - tn: 137657.2294 - fn: 2814.5914 - accuracy: 0.9691 - precision: 0.9589 - recall: 0.9800 - auc: 0.9959 - val_loss: 0.0337 - val_tp: 71.0000 - val_fp: 580.0000 - val_tn: 44909.0000 - val_fn: 9.0000 - val_accuracy: 0.9871 - val_precision: 0.1091 - val_recall: 0.8875 - val_auc: 0.9677 Epoch 13/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0818 - tp: 140796.7133 - fp: 5945.5305 - tn: 137458.5520 - fn: 2511.8638 - accuracy: 0.9704 - precision: 0.9596 - recall: 0.9822 - auc: 0.9962 - val_loss: 0.0318 - val_tp: 71.0000 - val_fp: 558.0000 - val_tn: 44931.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1129 - val_recall: 0.8875 - val_auc: 0.9682 Epoch 14/100 278/278 [==============================] - 7s 24ms/step - loss: 0.0793 - tp: 140997.9176 - fp: 6076.8746 - tn: 137562.6846 - fn: 2075.1828 - accuracy: 0.9714 - precision: 0.9586 - recall: 0.9853 - auc: 0.9964 - val_loss: 0.0303 - val_tp: 71.0000 - val_fp: 555.0000 - val_tn: 44934.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1134 - val_recall: 0.8875 - val_auc: 0.9687 Epoch 15/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0759 - tp: 141966.7312 - fp: 6100.1147 - tn: 136957.0108 - fn: 1688.8029 - accuracy: 0.9729 - precision: 0.9589 - recall: 0.9883 - auc: 0.9966 - val_loss: 0.0292 - val_tp: 71.0000 - val_fp: 541.0000 - val_tn: 44948.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.1160 - val_recall: 0.8875 - val_auc: 0.9692 Epoch 16/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0727 - tp: 141879.4731 - fp: 6020.2007 - tn: 137272.9140 - fn: 1540.0717 - accuracy: 0.9736 - precision: 0.9594 - recall: 0.9891 - auc: 0.9968 - val_loss: 0.0276 - val_tp: 71.0000 - val_fp: 504.0000 - val_tn: 44985.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1235 - val_recall: 0.8875 - val_auc: 0.9697 Epoch 17/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0704 - tp: 142140.7563 - fp: 6019.2258 - tn: 137225.3978 - fn: 1327.2796 - accuracy: 0.9745 - precision: 0.9597 - recall: 0.9907 - auc: 0.9969 - val_loss: 0.0269 - val_tp: 71.0000 - val_fp: 495.0000 - val_tn: 44994.0000 - val_fn: 9.0000 - val_accuracy: 0.9889 - val_precision: 0.1254 - val_recall: 0.8875 - val_auc: 0.9699 Restoring model weights from the end of the best epoch. Epoch 00017: early stopping
Si el proceso de entrenamiento tuviera en cuenta el conjunto de datos completo en cada actualización de gradiente, este sobremuestreo sería básicamente idéntico a la ponderación de la clase.
Pero al entrenar el modelo por lotes, como lo hizo aquí, los datos sobremuestreados proporcionan una señal de gradiente más suave: en lugar de que cada ejemplo positivo se muestre en un lote con un gran peso, se muestran en muchos lotes diferentes cada vez con un pequeño peso.
Esta señal de gradiente más suave facilita el entrenamiento del modelo.
Consultar historial de entrenamiento
Tenga en cuenta que las distribuciones de métricas serán diferentes aquí, porque los datos de entrenamiento tienen una distribución totalmente diferente de los datos de validación y prueba.
plot_metrics(resampled_history)
Reentrenar
Debido a que el entrenamiento es más fácil con los datos balanceados, el procedimiento de entrenamiento anterior puede sobreajustarse rápidamente.
Así que divida las épocas para devolver las callbacks.EarlyStopping
Temprano Detener un control más callbacks.EarlyStopping
sobre cuándo dejar de entrenar.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
resampled_history = resampled_model.fit(
resampled_ds,
# These are not real epochs
steps_per_epoch=20,
epochs=10*EPOCHS,
callbacks=[early_stopping],
validation_data=(val_ds))
Epoch 1/1000 20/20 [==============================] - 3s 59ms/step - loss: 1.6583 - tp: 2292.0000 - fp: 4139.7619 - tn: 52605.9524 - fn: 8961.7619 - accuracy: 0.8197 - precision: 0.3317 - recall: 0.1958 - auc: 0.7857 - val_loss: 0.5873 - val_tp: 9.0000 - val_fp: 13094.0000 - val_tn: 32395.0000 - val_fn: 71.0000 - val_accuracy: 0.7111 - val_precision: 6.8687e-04 - val_recall: 0.1125 - val_auc: 0.2616 Epoch 2/1000 20/20 [==============================] - 1s 27ms/step - loss: 1.0305 - tp: 4655.6667 - fp: 3775.2857 - tn: 7528.4286 - fn: 6471.0952 - accuracy: 0.5318 - precision: 0.5347 - recall: 0.3960 - auc: 0.4700 - val_loss: 0.5752 - val_tp: 45.0000 - val_fp: 12493.0000 - val_tn: 32996.0000 - val_fn: 35.0000 - val_accuracy: 0.7251 - val_precision: 0.0036 - val_recall: 0.5625 - val_auc: 0.7427 Epoch 3/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.6911 - tp: 7407.9048 - fp: 3609.9048 - tn: 7705.3333 - fn: 3707.3333 - accuracy: 0.6637 - precision: 0.6637 - recall: 0.6487 - auc: 0.6971 - val_loss: 0.5367 - val_tp: 73.0000 - val_fp: 10392.0000 - val_tn: 35097.0000 - val_fn: 7.0000 - val_accuracy: 0.7718 - val_precision: 0.0070 - val_recall: 0.9125 - val_auc: 0.9205 Epoch 4/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.5257 - tp: 8758.5238 - fp: 3223.3333 - tn: 8031.5714 - fn: 2417.0476 - accuracy: 0.7432 - precision: 0.7273 - recall: 0.7793 - auc: 0.8200 - val_loss: 0.4801 - val_tp: 73.0000 - val_fp: 7752.0000 - val_tn: 37737.0000 - val_fn: 7.0000 - val_accuracy: 0.8297 - val_precision: 0.0093 - val_recall: 0.9125 - val_auc: 0.9436 Epoch 5/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.4499 - tp: 9241.0952 - fp: 2761.1905 - tn: 8463.3333 - fn: 1964.8571 - accuracy: 0.7862 - precision: 0.7667 - recall: 0.8219 - auc: 0.8704 - val_loss: 0.4242 - val_tp: 72.0000 - val_fp: 5662.0000 - val_tn: 39827.0000 - val_fn: 8.0000 - val_accuracy: 0.8756 - val_precision: 0.0126 - val_recall: 0.9000 - val_auc: 0.9524 Epoch 6/1000 20/20 [==============================] - 1s 28ms/step - loss: 0.4029 - tp: 9361.0000 - fp: 2311.8571 - tn: 8961.6190 - fn: 1796.0000 - accuracy: 0.8140 - precision: 0.7965 - recall: 0.8389 - auc: 0.8968 - val_loss: 0.3750 - val_tp: 71.0000 - val_fp: 4105.0000 - val_tn: 41384.0000 - val_fn: 9.0000 - val_accuracy: 0.9097 - val_precision: 0.0170 - val_recall: 0.8875 - val_auc: 0.9583 Epoch 7/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.3601 - tp: 9456.8571 - fp: 1812.6667 - tn: 9376.6190 - fn: 1784.3333 - accuracy: 0.8386 - precision: 0.8376 - recall: 0.8413 - auc: 0.9181 - val_loss: 0.3324 - val_tp: 70.0000 - val_fp: 3047.0000 - val_tn: 42442.0000 - val_fn: 10.0000 - val_accuracy: 0.9329 - val_precision: 0.0225 - val_recall: 0.8750 - val_auc: 0.9626 Epoch 8/1000 20/20 [==============================] - 1s 28ms/step - loss: 0.3373 - tp: 9506.3810 - fp: 1623.5714 - tn: 9568.3333 - fn: 1732.1905 - accuracy: 0.8490 - precision: 0.8520 - recall: 0.8453 - auc: 0.9266 - val_loss: 0.2959 - val_tp: 69.0000 - val_fp: 2358.0000 - val_tn: 43131.0000 - val_fn: 11.0000 - val_accuracy: 0.9480 - val_precision: 0.0284 - val_recall: 0.8625 - val_auc: 0.9660 Epoch 9/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.3120 - tp: 9635.0000 - fp: 1347.7619 - tn: 9797.4286 - fn: 1650.2857 - accuracy: 0.8662 - precision: 0.8776 - recall: 0.8534 - auc: 0.9366 - val_loss: 0.2670 - val_tp: 69.0000 - val_fp: 1943.0000 - val_tn: 43546.0000 - val_fn: 11.0000 - val_accuracy: 0.9571 - val_precision: 0.0343 - val_recall: 0.8625 - val_auc: 0.9689 Epoch 10/1000 20/20 [==============================] - 1s 28ms/step - loss: 0.2923 - tp: 9632.5714 - fp: 1221.3333 - tn: 9978.5714 - fn: 1598.0000 - accuracy: 0.8725 - precision: 0.8864 - recall: 0.8559 - auc: 0.9446 - val_loss: 0.2417 - val_tp: 69.0000 - val_fp: 1609.0000 - val_tn: 43880.0000 - val_fn: 11.0000 - val_accuracy: 0.9644 - val_precision: 0.0411 - val_recall: 0.8625 - val_auc: 0.9706 Epoch 11/1000 20/20 [==============================] - 1s 29ms/step - loss: 0.2765 - tp: 9632.7619 - fp: 1069.8095 - tn: 10155.4762 - fn: 1572.4286 - accuracy: 0.8806 - precision: 0.8976 - recall: 0.8584 - auc: 0.9505 - val_loss: 0.2211 - val_tp: 69.0000 - val_fp: 1424.0000 - val_tn: 44065.0000 - val_fn: 11.0000 - val_accuracy: 0.9685 - val_precision: 0.0462 - val_recall: 0.8625 - val_auc: 0.9716 Epoch 12/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2624 - tp: 9721.6190 - fp: 949.2857 - tn: 10243.6190 - fn: 1515.9524 - accuracy: 0.8901 - precision: 0.9102 - recall: 0.8658 - auc: 0.9556 - val_loss: 0.2038 - val_tp: 69.0000 - val_fp: 1299.0000 - val_tn: 44190.0000 - val_fn: 11.0000 - val_accuracy: 0.9713 - val_precision: 0.0504 - val_recall: 0.8625 - val_auc: 0.9722 Epoch 13/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2491 - tp: 9862.2381 - fp: 851.1905 - tn: 10195.6667 - fn: 1521.3810 - accuracy: 0.8933 - precision: 0.9207 - recall: 0.8656 - auc: 0.9597 - val_loss: 0.1904 - val_tp: 70.0000 - val_fp: 1230.0000 - val_tn: 44259.0000 - val_fn: 10.0000 - val_accuracy: 0.9728 - val_precision: 0.0538 - val_recall: 0.8750 - val_auc: 0.9726 Epoch 14/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2384 - tp: 9752.6190 - fp: 806.5714 - tn: 10438.3333 - fn: 1432.9524 - accuracy: 0.9006 - precision: 0.9238 - recall: 0.8724 - auc: 0.9637 - val_loss: 0.1781 - val_tp: 70.0000 - val_fp: 1186.0000 - val_tn: 44303.0000 - val_fn: 10.0000 - val_accuracy: 0.9738 - val_precision: 0.0557 - val_recall: 0.8750 - val_auc: 0.9724 Epoch 15/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.2332 - tp: 9727.1905 - fp: 787.1905 - tn: 10508.9524 - fn: 1407.1429 - accuracy: 0.9024 - precision: 0.9255 - recall: 0.8737 - auc: 0.9651 - val_loss: 0.1664 - val_tp: 70.0000 - val_fp: 1130.0000 - val_tn: 44359.0000 - val_fn: 10.0000 - val_accuracy: 0.9750 - val_precision: 0.0583 - val_recall: 0.8750 - val_auc: 0.9728 Epoch 16/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2242 - tp: 9812.1905 - fp: 731.9524 - tn: 10481.3333 - fn: 1405.0000 - accuracy: 0.9048 - precision: 0.9310 - recall: 0.8745 - auc: 0.9677 - val_loss: 0.1561 - val_tp: 70.0000 - val_fp: 1085.0000 - val_tn: 44404.0000 - val_fn: 10.0000 - val_accuracy: 0.9760 - val_precision: 0.0606 - val_recall: 0.8750 - val_auc: 0.9725 Epoch 17/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2197 - tp: 9864.8571 - fp: 691.0000 - tn: 10487.2381 - fn: 1387.3810 - accuracy: 0.9072 - precision: 0.9350 - recall: 0.8766 - auc: 0.9690 - val_loss: 0.1475 - val_tp: 70.0000 - val_fp: 1042.0000 - val_tn: 44447.0000 - val_fn: 10.0000 - val_accuracy: 0.9769 - val_precision: 0.0629 - val_recall: 0.8750 - val_auc: 0.9724 Epoch 18/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.2110 - tp: 9788.9524 - fp: 651.8571 - tn: 10585.9524 - fn: 1403.7143 - accuracy: 0.9087 - precision: 0.9392 - recall: 0.8742 - auc: 0.9716 - val_loss: 0.1401 - val_tp: 70.0000 - val_fp: 1020.0000 - val_tn: 44469.0000 - val_fn: 10.0000 - val_accuracy: 0.9774 - val_precision: 0.0642 - val_recall: 0.8750 - val_auc: 0.9718 Epoch 19/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.2035 - tp: 9779.1905 - fp: 619.4286 - tn: 10699.0952 - fn: 1332.7619 - accuracy: 0.9117 - precision: 0.9402 - recall: 0.8784 - auc: 0.9733 - val_loss: 0.1340 - val_tp: 70.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 10.0000 - val_accuracy: 0.9776 - val_precision: 0.0648 - val_recall: 0.8750 - val_auc: 0.9721 Epoch 20/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.1993 - tp: 9782.3333 - fp: 592.7619 - tn: 10738.1905 - fn: 1317.1905 - accuracy: 0.9156 - precision: 0.9430 - recall: 0.8823 - auc: 0.9748 - val_loss: 0.1282 - val_tp: 70.0000 - val_fp: 999.0000 - val_tn: 44490.0000 - val_fn: 10.0000 - val_accuracy: 0.9779 - val_precision: 0.0655 - val_recall: 0.8750 - val_auc: 0.9721 Epoch 21/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.1995 - tp: 9862.0476 - fp: 625.4762 - tn: 10605.0952 - fn: 1337.8571 - accuracy: 0.9119 - precision: 0.9394 - recall: 0.8805 - auc: 0.9747 - val_loss: 0.1228 - val_tp: 70.0000 - val_fp: 965.0000 - val_tn: 44524.0000 - val_fn: 10.0000 - val_accuracy: 0.9786 - val_precision: 0.0676 - val_recall: 0.8750 - val_auc: 0.9718 Epoch 22/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.1936 - tp: 9804.0476 - fp: 582.0952 - tn: 10755.0000 - fn: 1289.3333 - accuracy: 0.9168 - precision: 0.9433 - recall: 0.8842 - auc: 0.9764 - val_loss: 0.1179 - val_tp: 70.0000 - val_fp: 944.0000 - val_tn: 44545.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0690 - val_recall: 0.8750 - val_auc: 0.9713 Epoch 23/1000 20/20 [==============================] - 1s 31ms/step - loss: 0.1889 - tp: 9952.6667 - fp: 552.3333 - tn: 10589.3810 - fn: 1336.0952 - accuracy: 0.9155 - precision: 0.9472 - recall: 0.8812 - auc: 0.9774 - val_loss: 0.1146 - val_tp: 70.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0692 - val_recall: 0.8750 - val_auc: 0.9715 Epoch 24/1000 20/20 [==============================] - 1s 33ms/step - loss: 0.1828 - tp: 9880.4286 - fp: 540.2381 - tn: 10753.0000 - fn: 1256.8095 - accuracy: 0.9207 - precision: 0.9483 - recall: 0.8883 - auc: 0.9787 - val_loss: 0.1116 - val_tp: 71.0000 - val_fp: 961.0000 - val_tn: 44528.0000 - val_fn: 9.0000 - val_accuracy: 0.9787 - val_precision: 0.0688 - val_recall: 0.8875 - val_auc: 0.9712 Epoch 25/1000 20/20 [==============================] - 1s 32ms/step - loss: 0.1829 - tp: 9871.3333 - fp: 555.4286 - tn: 10747.9048 - fn: 1255.8095 - accuracy: 0.9200 - precision: 0.9470 - recall: 0.8879 - auc: 0.9791 - val_loss: 0.1082 - val_tp: 71.0000 - val_fp: 953.0000 - val_tn: 44536.0000 - val_fn: 9.0000 - val_accuracy: 0.9789 - val_precision: 0.0693 - val_recall: 0.8875 - val_auc: 0.9713 Restoring model weights from the end of the best epoch. Epoch 00025: early stopping
Vuelva a verificar el historial de entrenamiento
plot_metrics(resampled_history)
Evaluar métricas
train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_resampled)
loss : 0.16524264216423035 tp : 91.0 fp : 1376.0 tn : 55491.0 fn : 4.0 accuracy : 0.9757733345031738 precision : 0.06203135475516319 recall : 0.9578947424888611 auc : 0.9829339385032654 Legitimate Transactions Detected (True Negatives): 55491 Legitimate Transactions Incorrectly Detected (False Positives): 1376 Fraudulent Transactions Missed (False Negatives): 4 Fraudulent Transactions Detected (True Positives): 91 Total Fraudulent Transactions: 95
Trazar la República de China
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa0682848d0>
Aplicando este tutorial a su problema
La clasificación de datos desequilibrada es una tarea intrínsecamente difícil, ya que hay muy pocas muestras de las que aprender. Siempre debe comenzar con los datos primero y hacer todo lo posible para recopilar tantas muestras como sea posible y pensar en profundidad sobre qué características pueden ser relevantes para que el modelo pueda aprovechar al máximo su clase minoritaria. En algún momento, su modelo puede tener dificultades para mejorar y producir los resultados que desea, por lo que es importante tener en cuenta el contexto de su problema y las compensaciones entre los diferentes tipos de errores.