Cette page a été traduite par l'API Cloud Translation.
Switch to English

Classification des données déséquilibrés

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Bloc - notes Télécharger

Ce tutoriel montre comment classer un ensemble de données très déséquilibrés dans lequel le nombre d'exemples dans une classe plus nombreuse considérablement les exemples dans un autre. Vous travaillerez avec la carte de crédit de détection des fraudes ensemble de données hébergé sur Kaggle. Le but est de détecter un simple 492 transactions frauduleuses de 284,807 transactions au total. Vous utiliserez Keras pour définir le modèle et poids de la classe pour aider le modèle à apprendre à partir des données déséquilibrées. .

Ce tutoriel contient le code complet à:

  • Charger un fichier CSV à l'aide Pandas.
  • Créer train, validation et jeux de tests.
  • Définir et former un modèle en utilisant Keras (y compris les poids de réglage de la classe).
  • Évaluer le modèle en utilisant différents paramètres (y compris la précision et le rappel).
  • Essayez des techniques communes pour traiter les données déséquilibrées comme:
    • pondération classe
    • suréchantillonnage

Installer

 import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
 
 mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
 

le traitement et l'exploration des données

Télécharger le jeu de données de carte de crédit Kaggle fraude

Pandas est une bibliothèque Python avec de nombreux utilitaires utiles pour le chargement et le travail avec des données structurées et peuvent être utilisées pour télécharger CSVs dans une trame de données.

 file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
 
 raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
 

Examiner le déséquilibre de l'étiquette de classe

Le regard de déposons sur le déséquilibre du jeu de données:

 neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
 
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


Cela montre la petite fraction des échantillons positifs.

Propre, divisé et normaliser les données

Les données brutes a quelques problèmes. D' abord les Time et Amount colonnes sont trop variables pour utiliser directement. Laissez tomber le Time colonne (car il est pas clair ce que cela signifie) et prendre le journal de la Amount colonne pour réduire sa gamme.

 cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps=0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)
 

Diviser l'ensemble de données en train, validation et jeux de tests. L'ensemble de validation est utilisé au cours de l'ajustement du modèle pour évaluer la perte et des mesures, mais le modèle ne convient pas avec ces données. L'ensemble de test est complètement utilisé pendant la phase de formation et sert uniquement à la fin pour évaluer dans quelle mesure le modèle se généralise à de nouvelles données. Ceci est particulièrement important avec des ensembles de données déséquilibrées où surapprentissage est une préoccupation importante du manque de données de formation.

 # Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
 

Normaliser l'entrée à l'aide de la présente StandardScaler de sklearn. Cela définira la moyenne à 0 et l'écart type à 1.

 scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)

 
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Regardez la distribution des données

Suivant comparer les distributions des exemples positifs et négatifs sur quelques caractéristiques. Les bonnes questions à vous poser à ce stade sont les suivants:

  • Est-ce que ces distributions sens?
    • Oui. Vous avez normalisé l'entrée et ce sont surtout concentrées dans les +/- 2 plage.
  • Pouvez-vous voir la différence entre les distributions?
    • Oui les exemples positifs contiennent un taux beaucoup plus élevé de valeurs extrêmes.
 pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("Negative distribution")
 

.png

.png

Définir le modèle et les mesures

Définir une fonction qui crée un simple réseau de neurones à une couche cachée connectée densly, un décrochage couche pour réduire surajustement et une couche sigmoïde de sortie qui renvoie la probabilité d'une transaction frauduleuse étant:

 METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
 

Présentation des mesures utiles

Notez qu'il ya quelques mesures définies ci-dessus qui peuvent être calculées par le modèle qui sera utile lors de l'évaluation de la performance.

  • Les faux négatifs et les faux positifs sont des échantillons qui ont été mal classés
  • Les vrais négatifs et vrais positifs sont des échantillons qui ont été correctement classés
  • La précision est le pourcentage d'exemples correctement classés> $ \ frac {\ texte {true}} échantillons {\ texte {échantillons au total}} $
  • La précision est le pourcentage de points positifs prévus qui ont été correctement classés> $ \ frac {\ texte {vrais positifs}} {\ texte {vrais positifs + faux positifs}} $
  • Le rappel est le pourcentage de points positifs réels qui ont été correctement classés> $ \ frac {\ texte {vrais positifs}} {\ text {vrais positifs + faux négatifs}} $
  • AUC se réfère à l'aire sous la courbe d'une courbe Receiver Operating Characteristic (ROC-AUC). Cette mesure est égale à la probabilité qu'un classificateur classer un échantillon positif aléatoire supérieur à un échantillon aléatoire négatif.

Lire la suite:

modèle de base

Construire le modèle

Maintenant, créer et former votre modèle en utilisant la fonction qui a été défini précédemment. Notez que le modèle est en forme avec un plus grand que la taille du lot par défaut de 2048, il est important de veiller à ce que chaque lot a une bonne chance de contenir quelques échantillons positifs. Si la taille du lot était trop petite, ils auraient probablement pas de transactions frauduleuses à apprendre.

 EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
 
 model = make_model()
model.summary()
 
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Essai de fonctionnement du modèle:

 model.predict(train_features[:10])
 
array([[0.5788107 ],
       [0.44979692],
       [0.5427961 ],
       [0.5985188 ],
       [0.7758075 ],
       [0.3417888 ],
       [0.39359283],
       [0.5399953 ],
       [0.3551327 ],
       [0.47230086]], dtype=float32)

Facultatif: Définissez le biais initial correct.

Ces estimations initiales ne sont pas grandes. Vous connaissez le jeu de données est déséquilibrée. Définir le biais de la couche de sortie pour indiquer que (Voir: Une recette pour la formation des réseaux de neurones: « Init bien » ). Cette aide peut avec la convergence initiale.

Avec l'initialisation de la perte biais par défaut devrait être d' environ math.log(2) = 0.69314

 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.7817

Le biais correct ensemble peut être dérivé de:

$$ = p_0 pos / (pos + neg) = 1 / (1 + e ^ {- b_0}) $$
$$ b_0 = -log_e (1 / p_0 - 1) $$
$$ b_0 = log_e (pos / neg) $$
 initial_bias = np.log([pos/neg])
initial_bias
 
array([-6.35935934])

Réglez que le parti pris initial, et le modèle donnera des estimations initiales beaucoup plus raisonnable.

Il devrait être près de : pos/total = 0.0018

 model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
 
array([[0.00093563],
       [0.00187903],
       [0.00109238],
       [0.00117128],
       [0.00134988],
       [0.00090826],
       [0.00099455],
       [0.00154405],
       [0.00100204],
       [0.0004291 ]], dtype=float32)

Avec cette initialisation la perte initiale doit être d'environ:

$$ - p_0log (p_0) - (1-p_0) log (1-p_0) = 0,01317 $$
 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.0146

Cette perte initiale est d'environ 50 fois moins que si aurait été avec l'initialisation naïve.

Ainsi, le modèle n'a pas besoin de passer les premières époques en train d'apprendre que des exemples positifs sont peu probables. Cela rend également plus facile à lire des parcelles de la perte au cours de la formation.

Point de contrôle des poids initiaux

Pour les différentes pistes de formation plus comparables, garder les poids de ce modèle initial dans un fichier de point de contrôle, et les charger dans chaque modèle avant la formation.

 initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)
 

Assurez-vous que le correctif de polarisation aide

Avant de poursuivre, confirmer rapidement que l'initialisation de polarisation fait attention a aidé.

Former le modèle de 20 époques, avec et sans cette initialisation attention, et de comparer les pertes:

 model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  
  plt.legend()
 
 plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
 

.png

La figure ci-dessus montre clairement: En termes de perte de validation, sur ce problème, cette initialisation attention donne un net avantage.

Former le modèle

 model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
 
Epoch 1/100
90/90 [==============================] - 1s 13ms/step - loss: 0.0112 - tp: 100.0000 - fp: 25.0000 - tn: 227419.0000 - fn: 301.0000 - accuracy: 0.9986 - precision: 0.8000 - recall: 0.2494 - auc: 0.7615 - val_loss: 0.0067 - val_tp: 15.0000 - val_fp: 2.0000 - val_tn: 45480.0000 - val_fn: 72.0000 - val_accuracy: 0.9984 - val_precision: 0.8824 - val_recall: 0.1724 - val_auc: 0.9077
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0075 - tp: 108.0000 - fp: 24.0000 - tn: 181938.0000 - fn: 206.0000 - accuracy: 0.9987 - precision: 0.8182 - recall: 0.3439 - auc: 0.8491 - val_loss: 0.0046 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8824 - val_recall: 0.5172 - val_auc: 0.9308
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0065 - tp: 138.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 176.0000 - accuracy: 0.9989 - precision: 0.8364 - recall: 0.4395 - auc: 0.8567 - val_loss: 0.0040 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 33.0000 - val_accuracy: 0.9991 - val_precision: 0.8852 - val_recall: 0.6207 - val_auc: 0.9365
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0060 - tp: 154.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 160.0000 - accuracy: 0.9989 - precision: 0.8235 - recall: 0.4904 - auc: 0.8848 - val_loss: 0.0037 - val_tp: 61.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8841 - val_recall: 0.7011 - val_auc: 0.9422
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 157.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 157.0000 - accuracy: 0.9989 - precision: 0.8135 - recall: 0.5000 - auc: 0.8982 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8857 - val_recall: 0.7126 - val_auc: 0.9422
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 152.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8261 - recall: 0.4841 - auc: 0.8934 - val_loss: 0.0033 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0052 - tp: 174.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 140.0000 - accuracy: 0.9991 - precision: 0.8529 - recall: 0.5541 - auc: 0.8983 - val_loss: 0.0032 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 161.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8342 - recall: 0.5127 - auc: 0.8983 - val_loss: 0.0031 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 167.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8186 - recall: 0.5318 - auc: 0.9064 - val_loss: 0.0030 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 156.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 158.0000 - accuracy: 0.9989 - precision: 0.8211 - recall: 0.4968 - auc: 0.9046 - val_loss: 0.0029 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9479
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 165.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8376 - recall: 0.5255 - auc: 0.9063 - val_loss: 0.0029 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 165.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8250 - recall: 0.5255 - auc: 0.9110 - val_loss: 0.0028 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9480
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 157.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5000 - auc: 0.9031 - val_loss: 0.0028 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 160.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 154.0000 - accuracy: 0.9990 - precision: 0.8205 - recall: 0.5096 - auc: 0.8934 - val_loss: 0.0027 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 168.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8235 - recall: 0.5350 - auc: 0.9031 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 169.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8492 - recall: 0.5382 - auc: 0.9143 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8498 - recall: 0.5764 - auc: 0.9144 - val_loss: 0.0027 - val_tp: 70.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8974 - val_recall: 0.8046 - val_auc: 0.9537
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8619 - recall: 0.5764 - auc: 0.9112 - val_loss: 0.0026 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 177.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 137.0000 - accuracy: 0.9991 - precision: 0.8349 - recall: 0.5637 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 176.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8462 - recall: 0.5605 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9167 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 163.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5191 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 183.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8281 - recall: 0.5828 - auc: 0.9113 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9041 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8400 - recall: 0.5350 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 179.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 135.0000 - accuracy: 0.9991 - precision: 0.8483 - recall: 0.5701 - auc: 0.9161 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 173.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 141.0000 - accuracy: 0.9990 - precision: 0.8199 - recall: 0.5510 - auc: 0.9208 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9081 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 181.0000 - fp: 39.0000 - tn: 181923.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8227 - recall: 0.5764 - auc: 0.9193 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 177.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8233 - recall: 0.5637 - auc: 0.9305 - val_loss: 0.0025 - val_tp: 67.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.9054 - val_recall: 0.7701 - val_auc: 0.9538
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 31.0000 - tn: 181931.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8442 - recall: 0.5350 - auc: 0.9161 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8309 - recall: 0.5478 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5987 - auc: 0.9162 - val_loss: 0.0025 - val_tp: 70.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9091 - val_recall: 0.8046 - val_auc: 0.9538
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 184.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8720 - recall: 0.5860 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8486 - recall: 0.5892 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 178.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8318 - recall: 0.5669 - auc: 0.9160 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 171.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 143.0000 - accuracy: 0.9990 - precision: 0.8382 - recall: 0.5446 - auc: 0.9192 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 189.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 125.0000 - accuracy: 0.9991 - precision: 0.8438 - recall: 0.6019 - auc: 0.9242 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 25.0000 - tn: 181937.0000 - fn: 129.0000 - accuracy: 0.9992 - precision: 0.8810 - recall: 0.5892 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 181.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8380 - recall: 0.5764 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8537 - recall: 0.5573 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 180.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8491 - recall: 0.5732 - auc: 0.9320 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8468 - recall: 0.5987 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 172.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8230 - recall: 0.5478 - auc: 0.9129 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8294 - recall: 0.5573 - auc: 0.9368 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9240 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 178.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8683 - recall: 0.5669 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 198.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8534 - recall: 0.6306 - auc: 0.9256 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 5.0000 - val_tn: 45477.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9315 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 49/100
85/90 [===========================>..] - ETA: 0s - loss: 0.0043 - tp: 162.0000 - fp: 29.0000 - tn: 173750.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8482 - recall: 0.5382 - auc: 0.9157Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 171.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5446 - auc: 0.9191 - val_loss: 0.0024 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 00049: early stopping

Vérification de l'historique de la formation

Dans cette section, vous produirez des parcelles de la précision de votre modèle et la perte sur l'ensemble de la formation et de validation. Ceux - ci sont utiles pour vérifier surapprentissage, que vous pouvez en savoir plus sur ce tutoriel .

De plus, vous pouvez produire ces parcelles pour l'une des mesures que vous avez créé ci-dessus. Les faux négatifs sont inclus à titre d'exemple.

 def plot_metrics(history):
  metrics =  ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()

 
 plot_metrics(baseline_history)
 

.png

évaluer les mesures

Vous pouvez utiliser une matrice de confusion pour résumer les réelles par rapport aux étiquettes prédit où l'axe X est l'étiquette prédite et l'axe Y est l'étiquette réelle.

 train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
 
 def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
 

Évaluez votre modèle sur l'ensemble de données de test et afficher les résultats pour les paramètres que vous avez créés ci-dessus.

 baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
 
loss :  0.002310588490217924
tp :  69.0
fp :  5.0
tn :  56866.0
fn :  22.0
accuracy :  0.9995260238647461
precision :  0.9324324131011963
recall :  0.7582417726516724
auc :  0.9557874202728271

Legitimate Transactions Detected (True Negatives):  56866
Legitimate Transactions Incorrectly Detected (False Positives):  5
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  69
Total Fraudulent Transactions:  91

.png

Si le modèle avait prédit tout parfaitement, ce serait une matrice diagonale où les valeurs au large de la diagonale principale, indiquant des prédictions erronées, seraient nuls. Dans ce cas, la matrice montre que vous avez relativement peu de faux positifs, ce qui signifie qu'il y avait relativement peu de transactions légitimes qui ont été incorrectement signalées. Cependant, vous voudrez probablement avoir encore moins de faux négatifs malgré le coût de l'augmentation du nombre de faux positifs. Ce compromis peut être préférable, car les faux négatifs permettrait des transactions frauduleuses à traverser, alors que les faux positifs peuvent causer un e-mail à envoyer à un client pour leur demander de vérifier leur activité de carte.

Tracer la ROC

Maintenant , tracer le ROC . Cette parcelle est utile car il montre, un coup d'oeil, la gamme de la performance du modèle peut atteindre tout en réglant le seuil de sortie.

 def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
 
 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa50c5adef0>

.png

Il semble que la précision est relativement élevé, mais le rappel et l'aire sous la courbe ROC (AUC) ne sont pas aussi haut que vous pourriez aimer. Classificateurs souvent des défis à relever en essayant de maximiser la précision et de rappel, ce qui est particulièrement vrai lorsque vous travaillez avec des ensembles de données déséquilibrés. Il est important de prendre en compte les coûts des différents types d'erreurs dans le contexte du problème que vous aimez. Dans cet exemple, un faux négatif (une transaction frauduleuse est manquée) peut avoir un coût financier, tandis qu'un faux positif (une transaction est mal signalée comme frauduleuse) peut diminuer le bonheur de l'utilisateur.

poids de classe

Calculer le poids de la classe

L'objectif est d'identifier les transactions frauduleuses, mais vous n'avez pas beaucoup de ces échantillons positifs à travailler avec, de sorte que vous voulez avoir le classificateur beaucoup de poids les quelques exemples qui sont disponibles. Vous pouvez le faire en passant des poids KERAS pour chaque classe par un paramètre. Ceux-ci entraîneront le modèle à « accorder plus d'attention » à des exemples d'une classe sous-représenté.

 # Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
 
Weight for class 0: 0.50
Weight for class 1: 289.44

Former un modèle avec des poids de classe

Maintenant, essayez une nouvelle formation et l'évaluation du modèle avec des poids de classe pour voir comment cela affecte les prévisions.

 weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
 
Epoch 1/100
90/90 [==============================] - 1s 15ms/step - loss: 2.5149 - tp: 105.0000 - fp: 66.0000 - tn: 238767.0000 - fn: 300.0000 - accuracy: 0.9985 - precision: 0.6140 - recall: 0.2593 - auc: 0.7803 - val_loss: 0.0067 - val_tp: 25.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 62.0000 - val_accuracy: 0.9985 - val_precision: 0.8065 - val_recall: 0.2874 - val_auc: 0.9211
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 1.2482 - tp: 145.0000 - fp: 124.0000 - tn: 181838.0000 - fn: 169.0000 - accuracy: 0.9984 - precision: 0.5390 - recall: 0.4618 - auc: 0.8560 - val_loss: 0.0062 - val_tp: 68.0000 - val_fp: 12.0000 - val_tn: 45470.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8500 - val_recall: 0.7816 - val_auc: 0.9408
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.8972 - tp: 177.0000 - fp: 237.0000 - tn: 181725.0000 - fn: 137.0000 - accuracy: 0.9979 - precision: 0.4275 - recall: 0.5637 - auc: 0.8876 - val_loss: 0.0079 - val_tp: 73.0000 - val_fp: 16.0000 - val_tn: 45466.0000 - val_fn: 14.0000 - val_accuracy: 0.9993 - val_precision: 0.8202 - val_recall: 0.8391 - val_auc: 0.9518
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6983 - tp: 210.0000 - fp: 387.0000 - tn: 181575.0000 - fn: 104.0000 - accuracy: 0.9973 - precision: 0.3518 - recall: 0.6688 - auc: 0.9028 - val_loss: 0.0098 - val_tp: 74.0000 - val_fp: 19.0000 - val_tn: 45463.0000 - val_fn: 13.0000 - val_accuracy: 0.9993 - val_precision: 0.7957 - val_recall: 0.8506 - val_auc: 0.9600
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6417 - tp: 220.0000 - fp: 583.0000 - tn: 181379.0000 - fn: 94.0000 - accuracy: 0.9963 - precision: 0.2740 - recall: 0.7006 - auc: 0.9084 - val_loss: 0.0119 - val_tp: 74.0000 - val_fp: 25.0000 - val_tn: 45457.0000 - val_fn: 13.0000 - val_accuracy: 0.9992 - val_precision: 0.7475 - val_recall: 0.8506 - val_auc: 0.9777
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5846 - tp: 232.0000 - fp: 977.0000 - tn: 180985.0000 - fn: 82.0000 - accuracy: 0.9942 - precision: 0.1919 - recall: 0.7389 - auc: 0.9048 - val_loss: 0.0148 - val_tp: 74.0000 - val_fp: 34.0000 - val_tn: 45448.0000 - val_fn: 13.0000 - val_accuracy: 0.9990 - val_precision: 0.6852 - val_recall: 0.8506 - val_auc: 0.9802
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5404 - tp: 234.0000 - fp: 1464.0000 - tn: 180498.0000 - fn: 80.0000 - accuracy: 0.9915 - precision: 0.1378 - recall: 0.7452 - auc: 0.9190 - val_loss: 0.0183 - val_tp: 74.0000 - val_fp: 50.0000 - val_tn: 45432.0000 - val_fn: 13.0000 - val_accuracy: 0.9986 - val_precision: 0.5968 - val_recall: 0.8506 - val_auc: 0.9823
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4714 - tp: 241.0000 - fp: 1862.0000 - tn: 180100.0000 - fn: 73.0000 - accuracy: 0.9894 - precision: 0.1146 - recall: 0.7675 - auc: 0.9252 - val_loss: 0.0225 - val_tp: 76.0000 - val_fp: 84.0000 - val_tn: 45398.0000 - val_fn: 11.0000 - val_accuracy: 0.9979 - val_precision: 0.4750 - val_recall: 0.8736 - val_auc: 0.9851
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4329 - tp: 247.0000 - fp: 2508.0000 - tn: 179454.0000 - fn: 67.0000 - accuracy: 0.9859 - precision: 0.0897 - recall: 0.7866 - auc: 0.9345 - val_loss: 0.0282 - val_tp: 76.0000 - val_fp: 170.0000 - val_tn: 45312.0000 - val_fn: 11.0000 - val_accuracy: 0.9960 - val_precision: 0.3089 - val_recall: 0.8736 - val_auc: 0.9873
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4467 - tp: 249.0000 - fp: 3175.0000 - tn: 178787.0000 - fn: 65.0000 - accuracy: 0.9822 - precision: 0.0727 - recall: 0.7930 - auc: 0.9210 - val_loss: 0.0341 - val_tp: 78.0000 - val_fp: 282.0000 - val_tn: 45200.0000 - val_fn: 9.0000 - val_accuracy: 0.9936 - val_precision: 0.2167 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3947 - tp: 260.0000 - fp: 3569.0000 - tn: 178393.0000 - fn: 54.0000 - accuracy: 0.9801 - precision: 0.0679 - recall: 0.8280 - auc: 0.9290 - val_loss: 0.0394 - val_tp: 78.0000 - val_fp: 346.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9922 - val_precision: 0.1840 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3694 - tp: 257.0000 - fp: 4294.0000 - tn: 177668.0000 - fn: 57.0000 - accuracy: 0.9761 - precision: 0.0565 - recall: 0.8185 - auc: 0.9418 - val_loss: 0.0473 - val_tp: 78.0000 - val_fp: 504.0000 - val_tn: 44978.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1340 - val_recall: 0.8966 - val_auc: 0.9879
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3479 - tp: 262.0000 - fp: 4886.0000 - tn: 177076.0000 - fn: 52.0000 - accuracy: 0.9729 - precision: 0.0509 - recall: 0.8344 - auc: 0.9403 - val_loss: 0.0539 - val_tp: 78.0000 - val_fp: 586.0000 - val_tn: 44896.0000 - val_fn: 9.0000 - val_accuracy: 0.9869 - val_precision: 0.1175 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3653 - tp: 263.0000 - fp: 5360.0000 - tn: 176602.0000 - fn: 51.0000 - accuracy: 0.9703 - precision: 0.0468 - recall: 0.8376 - auc: 0.9370 - val_loss: 0.0610 - val_tp: 78.0000 - val_fp: 664.0000 - val_tn: 44818.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.1051 - val_recall: 0.8966 - val_auc: 0.9876
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3673 - tp: 262.0000 - fp: 5820.0000 - tn: 176142.0000 - fn: 52.0000 - accuracy: 0.9678 - precision: 0.0431 - recall: 0.8344 - auc: 0.9316 - val_loss: 0.0658 - val_tp: 78.0000 - val_fp: 715.0000 - val_tn: 44767.0000 - val_fn: 9.0000 - val_accuracy: 0.9841 - val_precision: 0.0984 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3228 - tp: 262.0000 - fp: 6230.0000 - tn: 175732.0000 - fn: 52.0000 - accuracy: 0.9655 - precision: 0.0404 - recall: 0.8344 - auc: 0.9445 - val_loss: 0.0716 - val_tp: 79.0000 - val_fp: 805.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0894 - val_recall: 0.9080 - val_auc: 0.9877
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3299 - tp: 268.0000 - fp: 6572.0000 - tn: 175390.0000 - fn: 46.0000 - accuracy: 0.9637 - precision: 0.0392 - recall: 0.8535 - auc: 0.9423 - val_loss: 0.0757 - val_tp: 81.0000 - val_fp: 846.0000 - val_tn: 44636.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0874 - val_recall: 0.9310 - val_auc: 0.9878
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2522 - tp: 276.0000 - fp: 6934.0000 - tn: 175028.0000 - fn: 38.0000 - accuracy: 0.9618 - precision: 0.0383 - recall: 0.8790 - auc: 0.9610 - val_loss: 0.0779 - val_tp: 81.0000 - val_fp: 874.0000 - val_tn: 44608.0000 - val_fn: 6.0000 - val_accuracy: 0.9807 - val_precision: 0.0848 - val_recall: 0.9310 - val_auc: 0.9877
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3607 - tp: 264.0000 - fp: 6790.0000 - tn: 175172.0000 - fn: 50.0000 - accuracy: 0.9625 - precision: 0.0374 - recall: 0.8408 - auc: 0.9303 - val_loss: 0.0781 - val_tp: 81.0000 - val_fp: 865.0000 - val_tn: 44617.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0856 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 20/100
89/90 [============================>.] - ETA: 0s - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175189.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175193.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488 - val_loss: 0.0780 - val_tp: 81.0000 - val_fp: 853.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9811 - val_precision: 0.0867 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 00020: early stopping

Vérification de l'historique de la formation

 plot_metrics(weighted_history)
 

.png

évaluer les mesures

 train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
 
 weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
 
loss :  0.03226418048143387
tp :  82.0
fp :  352.0
tn :  56519.0
fn :  9.0
accuracy :  0.993662416934967
precision :  0.18894009292125702
recall :  0.901098906993866
auc :  0.9671803712844849

Legitimate Transactions Detected (True Negatives):  56519
Legitimate Transactions Incorrectly Detected (False Positives):  352
Fraudulent Transactions Missed (False Negatives):  9
Fraudulent Transactions Detected (True Positives):  82
Total Fraudulent Transactions:  91

.png

Ici, vous pouvez voir que avec des poids de classe l'exactitude et la précision sont plus faibles, car il y a plus de faux positifs, mais à l'inverse le rappel et l'ASC sont plus élevés parce que le modèle a également plus vrais positifs. En dépit d'une moindre précision, ce modèle a un rappel plus élevé (et identifie les transactions frauduleuses plus). Bien sûr, il y a un coût pour les deux types d'erreur (vous ne voudriez pas aux utilisateurs de bugs en signalant trop de transactions légitimes comme frauduleux, non plus). Examiner soigneusement les compromis entre ces différents types d'erreurs pour votre application.

Tracer la ROC

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa54c0729e8>

.png

suréchantillonnage

Suréchantillonnage la classe minoritaire

Une approche similaire serait de ré-échantillonner l'ensemble de données par suréchantillonnage la classe minoritaire.

 pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
 

L'utilisation NumPy

Vous pouvez équilibrer l'ensemble de données manuellement en choisissant le bon nombre d'indices aléatoires des exemples positifs:

 ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
 
(181962, 29)
 resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
 
(363924, 29)

L' utilisation tf.data

Si vous utilisez tf.data la meilleure façon de produire des exemples équilibré est de commencer par un positive et un negative ensemble de données, et de les fusionner. Voir le guide de tf.data pour plus d' exemples.

 BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
 

Chaque jeu de données fournit (feature, label) paires:

 for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
 
Features:
 [ 0.23104754  0.83661044 -0.31875356  1.9796369   1.28403692  0.07389102
  1.03350673 -0.11568355 -1.54396817  0.88004244 -1.66944551 -0.24324391
  0.45900013  0.14583622 -2.06637388  0.42470592 -0.94489216 -0.83112221
 -1.83416278 -0.34138858  0.14130878  0.51019975  0.08224586  0.6642136
 -1.39031637 -0.42194185  0.22525572  0.28277796 -4.86369823]

Label:  1

Fusionner les deux ensemble à l' aide experimental.sample_from_datasets :

 resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
 
 for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
 
0.49609375

Pour utiliser cet ensemble de données, vous aurez besoin du nombre d'étapes par époque.

La définition de « époque » dans ce cas est moins claire. Disons que c'est le nombre de lots requis pour voir chaque exemple négatif une fois:

 resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
 
278.0

Train sur les données suréchantillonné

Maintenant, essayez la formation du modèle avec les données rééchantillonnées définies au lieu d'utiliser des poids de classe pour voir comment ces méthodes se comparent.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
 
Epoch 1/100
278/278 [==============================] - 6s 23ms/step - loss: 0.4356 - tp: 223484.0000 - fp: 51288.0000 - tn: 290777.0000 - fn: 60757.0000 - accuracy: 0.8211 - precision: 0.8133 - recall: 0.7862 - auc: 0.8933 - val_loss: 0.2172 - val_tp: 79.0000 - val_fp: 1076.0000 - val_tn: 44406.0000 - val_fn: 8.0000 - val_accuracy: 0.9762 - val_precision: 0.0684 - val_recall: 0.9080 - val_auc: 0.9792
Epoch 2/100
278/278 [==============================] - 6s 20ms/step - loss: 0.2177 - tp: 246785.0000 - fp: 12557.0000 - tn: 271871.0000 - fn: 38131.0000 - accuracy: 0.9110 - precision: 0.9516 - recall: 0.8662 - auc: 0.9686 - val_loss: 0.1226 - val_tp: 80.0000 - val_fp: 951.0000 - val_tn: 44531.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0776 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1751 - tp: 250631.0000 - fp: 9797.0000 - tn: 275174.0000 - fn: 33742.0000 - accuracy: 0.9235 - precision: 0.9624 - recall: 0.8813 - auc: 0.9810 - val_loss: 0.0940 - val_tp: 82.0000 - val_fp: 966.0000 - val_tn: 44516.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0782 - val_recall: 0.9425 - val_auc: 0.9836
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1532 - tp: 254169.0000 - fp: 9171.0000 - tn: 275694.0000 - fn: 30310.0000 - accuracy: 0.9307 - precision: 0.9652 - recall: 0.8935 - auc: 0.9861 - val_loss: 0.0802 - val_tp: 82.0000 - val_fp: 918.0000 - val_tn: 44564.0000 - val_fn: 5.0000 - val_accuracy: 0.9797 - val_precision: 0.0820 - val_recall: 0.9425 - val_auc: 0.9847
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1372 - tp: 257034.0000 - fp: 9061.0000 - tn: 275758.0000 - fn: 27491.0000 - accuracy: 0.9358 - precision: 0.9659 - recall: 0.9034 - auc: 0.9892 - val_loss: 0.0720 - val_tp: 82.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0827 - val_recall: 0.9425 - val_auc: 0.9854
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1260 - tp: 258997.0000 - fp: 9079.0000 - tn: 275819.0000 - fn: 25449.0000 - accuracy: 0.9394 - precision: 0.9661 - recall: 0.9105 - auc: 0.9911 - val_loss: 0.0666 - val_tp: 81.0000 - val_fp: 915.0000 - val_tn: 44567.0000 - val_fn: 6.0000 - val_accuracy: 0.9798 - val_precision: 0.0813 - val_recall: 0.9310 - val_auc: 0.9856
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1167 - tp: 261100.0000 - fp: 9112.0000 - tn: 276180.0000 - fn: 22952.0000 - accuracy: 0.9437 - precision: 0.9663 - recall: 0.9192 - auc: 0.9925 - val_loss: 0.0623 - val_tp: 81.0000 - val_fp: 911.0000 - val_tn: 44571.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9858
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1082 - tp: 263945.0000 - fp: 9428.0000 - tn: 275276.0000 - fn: 20695.0000 - accuracy: 0.9471 - precision: 0.9655 - recall: 0.9273 - auc: 0.9937 - val_loss: 0.0587 - val_tp: 81.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9857
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1014 - tp: 268108.0000 - fp: 10376.0000 - tn: 274312.0000 - fn: 16548.0000 - accuracy: 0.9527 - precision: 0.9627 - recall: 0.9419 - auc: 0.9944 - val_loss: 0.0543 - val_tp: 80.0000 - val_fp: 873.0000 - val_tn: 44609.0000 - val_fn: 7.0000 - val_accuracy: 0.9807 - val_precision: 0.0839 - val_recall: 0.9195 - val_auc: 0.9857
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0951 - tp: 277520.0000 - fp: 12692.0000 - tn: 271795.0000 - fn: 7337.0000 - accuracy: 0.9648 - precision: 0.9563 - recall: 0.9742 - auc: 0.9950 - val_loss: 0.0495 - val_tp: 79.0000 - val_fp: 829.0000 - val_tn: 44653.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0870 - val_recall: 0.9080 - val_auc: 0.9855
Epoch 11/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0895 - tp: 278865.0000 - fp: 12938.0000 - tn: 271719.0000 - fn: 5822.0000 - accuracy: 0.9670 - precision: 0.9557 - recall: 0.9795 - auc: 0.9955 - val_loss: 0.0450 - val_tp: 79.0000 - val_fp: 789.0000 - val_tn: 44693.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0910 - val_recall: 0.9080 - val_auc: 0.9859
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0842 - tp: 279845.0000 - fp: 13187.0000 - tn: 272121.0000 - fn: 4191.0000 - accuracy: 0.9695 - precision: 0.9550 - recall: 0.9852 - auc: 0.9960 - val_loss: 0.0410 - val_tp: 79.0000 - val_fp: 733.0000 - val_tn: 44749.0000 - val_fn: 8.0000 - val_accuracy: 0.9837 - val_precision: 0.0973 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 13/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0792 - tp: 281765.0000 - fp: 12977.0000 - tn: 271393.0000 - fn: 3209.0000 - accuracy: 0.9716 - precision: 0.9560 - recall: 0.9887 - auc: 0.9963 - val_loss: 0.0389 - val_tp: 79.0000 - val_fp: 721.0000 - val_tn: 44761.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0988 - val_recall: 0.9080 - val_auc: 0.9814
Epoch 14/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0754 - tp: 281962.0000 - fp: 13026.0000 - tn: 272154.0000 - fn: 2202.0000 - accuracy: 0.9733 - precision: 0.9558 - recall: 0.9923 - auc: 0.9966 - val_loss: 0.0348 - val_tp: 79.0000 - val_fp: 646.0000 - val_tn: 44836.0000 - val_fn: 8.0000 - val_accuracy: 0.9856 - val_precision: 0.1090 - val_recall: 0.9080 - val_auc: 0.9763
Epoch 15/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0722 - tp: 283858.0000 - fp: 12932.0000 - tn: 271419.0000 - fn: 1135.0000 - accuracy: 0.9753 - precision: 0.9564 - recall: 0.9960 - auc: 0.9967 - val_loss: 0.0331 - val_tp: 79.0000 - val_fp: 640.0000 - val_tn: 44842.0000 - val_fn: 8.0000 - val_accuracy: 0.9858 - val_precision: 0.1099 - val_recall: 0.9080 - val_auc: 0.9714
Epoch 16/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0689 - tp: 283059.0000 - fp: 12757.0000 - tn: 273004.0000 - fn: 524.0000 - accuracy: 0.9767 - precision: 0.9569 - recall: 0.9982 - auc: 0.9970 - val_loss: 0.0308 - val_tp: 79.0000 - val_fp: 583.0000 - val_tn: 44899.0000 - val_fn: 8.0000 - val_accuracy: 0.9870 - val_precision: 0.1193 - val_recall: 0.9080 - val_auc: 0.9667
Epoch 17/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0661 - tp: 283879.0000 - fp: 12340.0000 - tn: 272779.0000 - fn: 346.0000 - accuracy: 0.9777 - precision: 0.9583 - recall: 0.9988 - auc: 0.9971 - val_loss: 0.0289 - val_tp: 79.0000 - val_fp: 542.0000 - val_tn: 44940.0000 - val_fn: 8.0000 - val_accuracy: 0.9879 - val_precision: 0.1272 - val_recall: 0.9080 - val_auc: 0.9618
Epoch 18/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0635 - tp: 284858.0000 - fp: 12157.0000 - tn: 272120.0000 - fn: 209.0000 - accuracy: 0.9783 - precision: 0.9591 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 79.0000 - val_fp: 511.0000 - val_tn: 44971.0000 - val_fn: 8.0000 - val_accuracy: 0.9886 - val_precision: 0.1339 - val_recall: 0.9080 - val_auc: 0.9621
Epoch 19/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0620 - tp: 284459.0000 - fp: 11978.0000 - tn: 272718.0000 - fn: 189.0000 - accuracy: 0.9786 - precision: 0.9596 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0261 - val_tp: 79.0000 - val_fp: 478.0000 - val_tn: 45004.0000 - val_fn: 8.0000 - val_accuracy: 0.9893 - val_precision: 0.1418 - val_recall: 0.9080 - val_auc: 0.9624
Epoch 20/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0600 - tp: 284950.0000 - fp: 11793.0000 - tn: 272572.0000 - fn: 29.0000 - accuracy: 0.9792 - precision: 0.9603 - recall: 0.9999 - auc: 0.9974 - val_loss: 0.0252 - val_tp: 79.0000 - val_fp: 463.0000 - val_tn: 45019.0000 - val_fn: 8.0000 - val_accuracy: 0.9897 - val_precision: 0.1458 - val_recall: 0.9080 - val_auc: 0.9626
Epoch 21/100
276/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 282210.0000 - fp: 11270.0000 - tn: 271768.0000 - fn: 0.0000e+00 - accuracy: 0.9801 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975Restoring model weights from the end of the best epoch.
278/278 [==============================] - 6s 22ms/step - loss: 0.0581 - tp: 284274.0000 - fp: 11360.0000 - tn: 273710.0000 - fn: 0.0000e+00 - accuracy: 0.9800 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975 - val_loss: 0.0241 - val_tp: 79.0000 - val_fp: 444.0000 - val_tn: 45038.0000 - val_fn: 8.0000 - val_accuracy: 0.9901 - val_precision: 0.1511 - val_recall: 0.9080 - val_auc: 0.9628
Epoch 00021: early stopping

Si le processus de formation examinaient l'ensemble des données sur chaque mise à jour du gradient, ce suréchantillonnage serait essentiellement identique à la pondération des classes.

Mais lors de la formation du lot sage modèle, comme vous l'avez fait ici, les données suréchantillonné fournit un signal de gradient de plus lisse: au lieu de chaque exemple positif étant représenté dans un lot avec un grand poids, ils sont présentés dans de nombreux lots différents à chaque fois avec un petit poids.

Ce signal de gradient plus lisse facilite la formation du modèle.

Vérification de l'historique de la formation

Notez que les distributions de mesures seront différentes ici, parce que les données de formation a une distribution totalement différente des données de validation et de test.

 plot_metrics(resampled_history )
 

.png

Recycler

Parce que la formation est plus facile sur les données équilibrées, la procédure de formation ci-dessus peut surajuster rapidement.

, PORTER les époques pour donner le callbacks.EarlyStopping contrôle plus fin quand arrêter la formation.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_ds))
 
Epoch 1/1000
20/20 [==============================] - 1s 60ms/step - loss: 1.0656 - tp: 9507.0000 - fp: 7370.0000 - tn: 58667.0000 - fn: 10985.0000 - accuracy: 0.7879 - precision: 0.5633 - recall: 0.4639 - auc: 0.8255 - val_loss: 0.5792 - val_tp: 66.0000 - val_fp: 13452.0000 - val_tn: 32030.0000 - val_fn: 21.0000 - val_accuracy: 0.7043 - val_precision: 0.0049 - val_recall: 0.7586 - val_auc: 0.7866
Epoch 2/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.6996 - tp: 13383.0000 - fp: 7208.0000 - tn: 13397.0000 - fn: 6972.0000 - accuracy: 0.6538 - precision: 0.6499 - recall: 0.6575 - auc: 0.7027 - val_loss: 0.5702 - val_tp: 76.0000 - val_fp: 12408.0000 - val_tn: 33074.0000 - val_fn: 11.0000 - val_accuracy: 0.7275 - val_precision: 0.0061 - val_recall: 0.8736 - val_auc: 0.9076
Epoch 3/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.5532 - tp: 15127.0000 - fp: 6665.0000 - tn: 14055.0000 - fn: 5113.0000 - accuracy: 0.7125 - precision: 0.6942 - recall: 0.7474 - auc: 0.7952 - val_loss: 0.5335 - val_tp: 79.0000 - val_fp: 9006.0000 - val_tn: 36476.0000 - val_fn: 8.0000 - val_accuracy: 0.8022 - val_precision: 0.0087 - val_recall: 0.9080 - val_auc: 0.9408
Epoch 4/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4738 - tp: 16061.0000 - fp: 5669.0000 - tn: 14890.0000 - fn: 4340.0000 - accuracy: 0.7556 - precision: 0.7391 - recall: 0.7873 - auc: 0.8495 - val_loss: 0.4883 - val_tp: 78.0000 - val_fp: 5756.0000 - val_tn: 39726.0000 - val_fn: 9.0000 - val_accuracy: 0.8735 - val_precision: 0.0134 - val_recall: 0.8966 - val_auc: 0.9489
Epoch 5/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.4266 - tp: 16612.0000 - fp: 4719.0000 - tn: 15715.0000 - fn: 3914.0000 - accuracy: 0.7892 - precision: 0.7788 - recall: 0.8093 - auc: 0.8786 - val_loss: 0.4435 - val_tp: 78.0000 - val_fp: 3758.0000 - val_tn: 41724.0000 - val_fn: 9.0000 - val_accuracy: 0.9173 - val_precision: 0.0203 - val_recall: 0.8966 - val_auc: 0.9539
Epoch 6/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.3908 - tp: 16911.0000 - fp: 3861.0000 - tn: 16514.0000 - fn: 3674.0000 - accuracy: 0.8160 - precision: 0.8141 - recall: 0.8215 - auc: 0.8976 - val_loss: 0.4032 - val_tp: 79.0000 - val_fp: 2770.0000 - val_tn: 42712.0000 - val_fn: 8.0000 - val_accuracy: 0.9390 - val_precision: 0.0277 - val_recall: 0.9080 - val_auc: 0.9590
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3664 - tp: 17049.0000 - fp: 3209.0000 - tn: 17179.0000 - fn: 3523.0000 - accuracy: 0.8356 - precision: 0.8416 - recall: 0.8287 - auc: 0.9108 - val_loss: 0.3682 - val_tp: 79.0000 - val_fp: 2119.0000 - val_tn: 43363.0000 - val_fn: 8.0000 - val_accuracy: 0.9533 - val_precision: 0.0359 - val_recall: 0.9080 - val_auc: 0.9634
Epoch 8/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3467 - tp: 17100.0000 - fp: 2699.0000 - tn: 17686.0000 - fn: 3475.0000 - accuracy: 0.8493 - precision: 0.8637 - recall: 0.8311 - auc: 0.9193 - val_loss: 0.3373 - val_tp: 79.0000 - val_fp: 1753.0000 - val_tn: 43729.0000 - val_fn: 8.0000 - val_accuracy: 0.9614 - val_precision: 0.0431 - val_recall: 0.9080 - val_auc: 0.9675
Epoch 9/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.3285 - tp: 17043.0000 - fp: 2345.0000 - tn: 18228.0000 - fn: 3344.0000 - accuracy: 0.8611 - precision: 0.8790 - recall: 0.8360 - auc: 0.9271 - val_loss: 0.3104 - val_tp: 79.0000 - val_fp: 1495.0000 - val_tn: 43987.0000 - val_fn: 8.0000 - val_accuracy: 0.9670 - val_precision: 0.0502 - val_recall: 0.9080 - val_auc: 0.9702
Epoch 10/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3094 - tp: 17322.0000 - fp: 2012.0000 - tn: 18405.0000 - fn: 3221.0000 - accuracy: 0.8722 - precision: 0.8959 - recall: 0.8432 - auc: 0.9361 - val_loss: 0.2865 - val_tp: 79.0000 - val_fp: 1332.0000 - val_tn: 44150.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0560 - val_recall: 0.9080 - val_auc: 0.9721
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2962 - tp: 17184.0000 - fp: 1757.0000 - tn: 18853.0000 - fn: 3166.0000 - accuracy: 0.8798 - precision: 0.9072 - recall: 0.8444 - auc: 0.9406 - val_loss: 0.2654 - val_tp: 79.0000 - val_fp: 1228.0000 - val_tn: 44254.0000 - val_fn: 8.0000 - val_accuracy: 0.9729 - val_precision: 0.0604 - val_recall: 0.9080 - val_auc: 0.9739
Epoch 12/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2835 - tp: 17373.0000 - fp: 1543.0000 - tn: 18909.0000 - fn: 3135.0000 - accuracy: 0.8858 - precision: 0.9184 - recall: 0.8471 - auc: 0.9458 - val_loss: 0.2469 - val_tp: 79.0000 - val_fp: 1155.0000 - val_tn: 44327.0000 - val_fn: 8.0000 - val_accuracy: 0.9745 - val_precision: 0.0640 - val_recall: 0.9080 - val_auc: 0.9759
Epoch 13/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2710 - tp: 17386.0000 - fp: 1395.0000 - tn: 19124.0000 - fn: 3055.0000 - accuracy: 0.8914 - precision: 0.9257 - recall: 0.8505 - auc: 0.9502 - val_loss: 0.2302 - val_tp: 79.0000 - val_fp: 1092.0000 - val_tn: 44390.0000 - val_fn: 8.0000 - val_accuracy: 0.9759 - val_precision: 0.0675 - val_recall: 0.9080 - val_auc: 0.9782
Epoch 14/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2618 - tp: 17336.0000 - fp: 1343.0000 - tn: 19296.0000 - fn: 2985.0000 - accuracy: 0.8943 - precision: 0.9281 - recall: 0.8531 - auc: 0.9541 - val_loss: 0.2156 - val_tp: 79.0000 - val_fp: 1053.0000 - val_tn: 44429.0000 - val_fn: 8.0000 - val_accuracy: 0.9767 - val_precision: 0.0698 - val_recall: 0.9080 - val_auc: 0.9797
Epoch 15/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2529 - tp: 17466.0000 - fp: 1154.0000 - tn: 19366.0000 - fn: 2974.0000 - accuracy: 0.8992 - precision: 0.9380 - recall: 0.8545 - auc: 0.9574 - val_loss: 0.2026 - val_tp: 79.0000 - val_fp: 1029.0000 - val_tn: 44453.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0713 - val_recall: 0.9080 - val_auc: 0.9806
Epoch 16/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2456 - tp: 17579.0000 - fp: 1075.0000 - tn: 19322.0000 - fn: 2984.0000 - accuracy: 0.9009 - precision: 0.9424 - recall: 0.8549 - auc: 0.9590 - val_loss: 0.1923 - val_tp: 79.0000 - val_fp: 1017.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0721 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 17/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2382 - tp: 17573.0000 - fp: 982.0000 - tn: 19540.0000 - fn: 2865.0000 - accuracy: 0.9061 - precision: 0.9471 - recall: 0.8598 - auc: 0.9620 - val_loss: 0.1828 - val_tp: 79.0000 - val_fp: 1005.0000 - val_tn: 44477.0000 - val_fn: 8.0000 - val_accuracy: 0.9778 - val_precision: 0.0729 - val_recall: 0.9080 - val_auc: 0.9819
Epoch 18/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2307 - tp: 17711.0000 - fp: 966.0000 - tn: 19448.0000 - fn: 2835.0000 - accuracy: 0.9072 - precision: 0.9483 - recall: 0.8620 - auc: 0.9644 - val_loss: 0.1736 - val_tp: 80.0000 - val_fp: 990.0000 - val_tn: 44492.0000 - val_fn: 7.0000 - val_accuracy: 0.9781 - val_precision: 0.0748 - val_recall: 0.9195 - val_auc: 0.9825
Epoch 19/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2280 - tp: 17732.0000 - fp: 952.0000 - tn: 19442.0000 - fn: 2834.0000 - accuracy: 0.9076 - precision: 0.9490 - recall: 0.8622 - auc: 0.9653 - val_loss: 0.1660 - val_tp: 80.0000 - val_fp: 974.0000 - val_tn: 44508.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0759 - val_recall: 0.9195 - val_auc: 0.9826
Epoch 20/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2224 - tp: 17725.0000 - fp: 939.0000 - tn: 19538.0000 - fn: 2758.0000 - accuracy: 0.9097 - precision: 0.9497 - recall: 0.8654 - auc: 0.9667 - val_loss: 0.1591 - val_tp: 80.0000 - val_fp: 962.0000 - val_tn: 44520.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0768 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 21/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2168 - tp: 17757.0000 - fp: 826.0000 - tn: 19618.0000 - fn: 2759.0000 - accuracy: 0.9125 - precision: 0.9556 - recall: 0.8655 - auc: 0.9689 - val_loss: 0.1531 - val_tp: 80.0000 - val_fp: 967.0000 - val_tn: 44515.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0764 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 22/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2112 - tp: 17833.0000 - fp: 883.0000 - tn: 19522.0000 - fn: 2722.0000 - accuracy: 0.9120 - precision: 0.9528 - recall: 0.8676 - auc: 0.9703 - val_loss: 0.1479 - val_tp: 80.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9832
Epoch 23/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2058 - tp: 17865.0000 - fp: 835.0000 - tn: 19580.0000 - fn: 2680.0000 - accuracy: 0.9142 - precision: 0.9553 - recall: 0.8696 - auc: 0.9723 - val_loss: 0.1427 - val_tp: 80.0000 - val_fp: 977.0000 - val_tn: 44505.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0757 - val_recall: 0.9195 - val_auc: 0.9834
Epoch 24/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2053 - tp: 17856.0000 - fp: 802.0000 - tn: 19599.0000 - fn: 2703.0000 - accuracy: 0.9144 - precision: 0.9570 - recall: 0.8685 - auc: 0.9727 - val_loss: 0.1375 - val_tp: 80.0000 - val_fp: 969.0000 - val_tn: 44513.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9833
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2004 - tp: 17854.0000 - fp: 809.0000 - tn: 19690.0000 - fn: 2607.0000 - accuracy: 0.9166 - precision: 0.9567 - recall: 0.8726 - auc: 0.9740 - val_loss: 0.1331 - val_tp: 80.0000 - val_fp: 976.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9837
Epoch 26/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1991 - tp: 17857.0000 - fp: 793.0000 - tn: 19690.0000 - fn: 2620.0000 - accuracy: 0.9167 - precision: 0.9575 - recall: 0.8721 - auc: 0.9747 - val_loss: 0.1291 - val_tp: 80.0000 - val_fp: 968.0000 - val_tn: 44514.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 27/1000
20/20 [==============================] - 1s 40ms/step - loss: 0.1929 - tp: 17836.0000 - fp: 750.0000 - tn: 19833.0000 - fn: 2541.0000 - accuracy: 0.9197 - precision: 0.9596 - recall: 0.8753 - auc: 0.9760 - val_loss: 0.1252 - val_tp: 80.0000 - val_fp: 960.0000 - val_tn: 44522.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0769 - val_recall: 0.9195 - val_auc: 0.9839
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1935 - tp: 17776.0000 - fp: 753.0000 - tn: 19827.0000 - fn: 2604.0000 - accuracy: 0.9180 - precision: 0.9594 - recall: 0.8722 - auc: 0.9763 - val_loss: 0.1215 - val_tp: 80.0000 - val_fp: 946.0000 - val_tn: 44536.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0780 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 29/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1892 - tp: 17877.0000 - fp: 746.0000 - tn: 19791.0000 - fn: 2546.0000 - accuracy: 0.9196 - precision: 0.9599 - recall: 0.8753 - auc: 0.9773 - val_loss: 0.1183 - val_tp: 80.0000 - val_fp: 944.0000 - val_tn: 44538.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0781 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 30/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1855 - tp: 18053.0000 - fp: 746.0000 - tn: 19673.0000 - fn: 2488.0000 - accuracy: 0.9210 - precision: 0.9603 - recall: 0.8789 - auc: 0.9779 - val_loss: 0.1157 - val_tp: 80.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0777 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 31/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1843 - tp: 18042.0000 - fp: 723.0000 - tn: 19656.0000 - fn: 2539.0000 - accuracy: 0.9204 - precision: 0.9615 - recall: 0.8766 - auc: 0.9783 - val_loss: 0.1137 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 32/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1831 - tp: 17974.0000 - fp: 743.0000 - tn: 19741.0000 - fn: 2502.0000 - accuracy: 0.9208 - precision: 0.9603 - recall: 0.8778 - auc: 0.9789 - val_loss: 0.1112 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 33/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1805 - tp: 18172.0000 - fp: 775.0000 - tn: 19591.0000 - fn: 2422.0000 - accuracy: 0.9219 - precision: 0.9591 - recall: 0.8824 - auc: 0.9796 - val_loss: 0.1088 - val_tp: 81.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 6.0000 - val_accuracy: 0.9789 - val_precision: 0.0781 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 34/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1749 - tp: 18125.0000 - fp: 715.0000 - tn: 19698.0000 - fn: 2422.0000 - accuracy: 0.9234 - precision: 0.9620 - recall: 0.8821 - auc: 0.9812 - val_loss: 0.1068 - val_tp: 81.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0775 - val_recall: 0.9310 - val_auc: 0.9836
Epoch 35/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1769 - tp: 18135.0000 - fp: 715.0000 - tn: 19694.0000 - fn: 2416.0000 - accuracy: 0.9236 - precision: 0.9621 - recall: 0.8824 - auc: 0.9809 - val_loss: 0.1048 - val_tp: 81.0000 - val_fp: 978.0000 - val_tn: 44504.0000 - val_fn: 6.0000 - val_accuracy: 0.9784 - val_precision: 0.0765 - val_recall: 0.9310 - val_auc: 0.9838
Epoch 36/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1739 - tp: 18006.0000 - fp: 704.0000 - tn: 19827.0000 - fn: 2423.0000 - accuracy: 0.9237 - precision: 0.9624 - recall: 0.8814 - auc: 0.9814 - val_loss: 0.1029 - val_tp: 81.0000 - val_fp: 986.0000 - val_tn: 44496.0000 - val_fn: 6.0000 - val_accuracy: 0.9782 - val_precision: 0.0759 - val_recall: 0.9310 - val_auc: 0.9839
Epoch 37/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1687 - tp: 18002.0000 - fp: 660.0000 - tn: 19879.0000 - fn: 2419.0000 - accuracy: 0.9248 - precision: 0.9646 - recall: 0.8815 - auc: 0.9826 - val_loss: 0.1011 - val_tp: 81.0000 - val_fp: 984.0000 - val_tn: 44498.0000 - val_fn: 6.0000 - val_accuracy: 0.9783 - val_precision: 0.0761 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 38/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1699 - tp: 17932.0000 - fp: 677.0000 - tn: 19986.0000 - fn: 2365.0000 - accuracy: 0.9257 - precision: 0.9636 - recall: 0.8835 - auc: 0.9825 - val_loss: 0.0995 - val_tp: 82.0000 - val_fp: 979.0000 - val_tn: 44503.0000 - val_fn: 5.0000 - val_accuracy: 0.9784 - val_precision: 0.0773 - val_recall: 0.9425 - val_auc: 0.9842
Epoch 39/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1676 - tp: 18086.0000 - fp: 736.0000 - tn: 19780.0000 - fn: 2358.0000 - accuracy: 0.9245 - precision: 0.9609 - recall: 0.8847 - auc: 0.9826 - val_loss: 0.0980 - val_tp: 82.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 5.0000 - val_accuracy: 0.9785 - val_precision: 0.0776 - val_recall: 0.9425 - val_auc: 0.9844
Epoch 40/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1670 - tp: 18066.0000 - fp: 685.0000 - tn: 19868.0000 - fn: 2341.0000 - accuracy: 0.9261 - precision: 0.9635 - recall: 0.8853 - auc: 0.9832 - val_loss: 0.0964 - val_tp: 82.0000 - val_fp: 965.0000 - val_tn: 44517.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0783 - val_recall: 0.9425 - val_auc: 0.9845
Epoch 41/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1640 - tp: 17950.0000 - fp: 645.0000 - tn: 19995.0000 - fn: 2370.0000 - accuracy: 0.9264 - precision: 0.9653 - recall: 0.8834 - auc: 0.9839 - val_loss: 0.0950 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9835
Epoch 42/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1641 - tp: 18083.0000 - fp: 665.0000 - tn: 19842.0000 - fn: 2370.0000 - accuracy: 0.9259 - precision: 0.9645 - recall: 0.8841 - auc: 0.9839 - val_loss: 0.0938 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 43/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1600 - tp: 18012.0000 - fp: 684.0000 - tn: 19970.0000 - fn: 2294.0000 - accuracy: 0.9273 - precision: 0.9634 - recall: 0.8870 - auc: 0.9845 - val_loss: 0.0925 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 44/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1597 - tp: 18346.0000 - fp: 657.0000 - tn: 19657.0000 - fn: 2300.0000 - accuracy: 0.9278 - precision: 0.9654 - recall: 0.8886 - auc: 0.9847 - val_loss: 0.0919 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9838
Epoch 45/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1607 - tp: 18109.0000 - fp: 726.0000 - tn: 19836.0000 - fn: 2289.0000 - accuracy: 0.9264 - precision: 0.9615 - recall: 0.8878 - auc: 0.9846 - val_loss: 0.0908 - val_tp: 82.0000 - val_fp: 948.0000 - val_tn: 44534.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0796 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 46/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1581 - tp: 18192.0000 - fp: 650.0000 - tn: 19833.0000 - fn: 2285.0000 - accuracy: 0.9283 - precision: 0.9655 - recall: 0.8884 - auc: 0.9849 - val_loss: 0.0902 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 47/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1579 - tp: 18301.0000 - fp: 676.0000 - tn: 19760.0000 - fn: 2223.0000 - accuracy: 0.9292 - precision: 0.9644 - recall: 0.8917 - auc: 0.9853 - val_loss: 0.0892 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 48/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1503 - tp: 18172.0000 - fp: 593.0000 - tn: 19959.0000 - fn: 2236.0000 - accuracy: 0.9309 - precision: 0.9684 - recall: 0.8904 - auc: 0.9867 - val_loss: 0.0887 - val_tp: 82.0000 - val_fp: 970.0000 - val_tn: 44512.0000 - val_fn: 5.0000 - val_accuracy: 0.9786 - val_precision: 0.0779 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 49/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1572 - tp: 18217.0000 - fp: 750.0000 - tn: 19709.0000 - fn: 2284.0000 - accuracy: 0.9259 - precision: 0.9605 - recall: 0.8886 - auc: 0.9852 - val_loss: 0.0876 - val_tp: 82.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0784 - val_recall: 0.9425 - val_auc: 0.9841
Epoch 50/1000
20/20 [==============================] - ETA: 0s - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860Restoring model weights from the end of the best epoch.
20/20 [==============================] - 0s 23ms/step - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860 - val_loss: 0.0860 - val_tp: 82.0000 - val_fp: 941.0000 - val_tn: 44541.0000 - val_fn: 5.0000 - val_accuracy: 0.9792 - val_precision: 0.0802 - val_recall: 0.9425 - val_auc: 0.9843
Epoch 00050: early stopping

Revérifier l'histoire de la formation

 plot_metrics(resampled_history)
 

.png

évaluer les mesures

 train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
 
 resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
 
loss :  0.09607589244842529
tp :  84.0
fp :  1195.0
tn :  55676.0
fn :  7.0
accuracy :  0.9788982272148132
precision :  0.06567630916833878
recall :  0.9230769276618958
auc :  0.9697299599647522

Legitimate Transactions Detected (True Negatives):  55676
Legitimate Transactions Incorrectly Detected (False Positives):  1195
Fraudulent Transactions Missed (False Negatives):  7
Fraudulent Transactions Detected (True Positives):  84
Total Fraudulent Transactions:  91

.png

Tracer la ROC

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa4bc66c9b0>

.png

L'application de ce tutoriel à votre problème

la classification des données déséquilibrées est une tâche difficile en soi car il y a si peu d'échantillons à apprendre. Vous devriez toujours commencer par les données d'abord et faire de votre mieux pour recueillir le plus d'échantillons possible et donner beaucoup de réflexion à quelles caractéristiques peuvent être pertinents, le modèle peut tirer le meilleur parti de votre classe minoritaire. À un certain moment votre modèle peut lutter pour améliorer et obtenir les résultats que vous voulez, il est donc important de garder à l'esprit le contexte de votre problème et les arbitrages entre les différents types d'erreurs.