
Classification on imbalanced data


This tutorial demonstrates how to classify a highly imbalanced dataset, in which the examples in one class greatly outnumber the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data. You will display metrics for precision, recall, true positives, false positives, true negatives, false negatives, and AUC while training the model. These are more informative than accuracy when working on classification with imbalanced datasets.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).

Import TensorFlow and other libraries

from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
!pip install -q imbalanced-learn
import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE

Use Pandas to get the Kaggle Credit Card Fraud data set

Pandas is a Python library with many helpful utilities for loading and working with structured data, and it can read a CSV directly from a URL into a DataFrame.

raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()

Split the dataframe into train, validation, and test

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics; however, the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern because there are so few positive examples to learn from.

# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(raw_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

# Normalize the input features using the sklearn StandardScaler.
# This will set the mean to 0 and standard deviation to 1.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 30)
Validation features shape: (45569, 30)
Test features shape: (56962, 30)

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(train_labels)
total = neg + pos
print('{} positive samples out of {} training samples ({:.2f}% of total)'.format(
    pos, total, 100 * pos / total))
324 positive samples out of 182276 training samples (0.18% of total)

Positive (fraudulent) samples make up only 0.18% of the training set, so the dataset is highly imbalanced.

Define the model and metrics

Define a function that creates a simple neural network with three densely connected hidden layers, an output sigmoid layer that returns the probability of a transaction being fraudulent, and two dropout layers to reduce overfitting.

def make_model():
  model = keras.Sequential([
      keras.layers.Dense(256, activation='relu',
                         input_shape=(train_features.shape[-1],)),
      keras.layers.Dense(256, activation='relu'),
      keras.layers.Dropout(0.3),
      keras.layers.Dense(256, activation='relu'),
      keras.layers.Dropout(0.3),
      keras.layers.Dense(1, activation='sigmoid'),
  ])

  metrics = [
      # BinaryAccuracy thresholds the sigmoid output at 0.5; the plain
      # Accuracy metric compares raw probabilities to labels for exact
      # equality, which is not meaningful here.
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc')
  ]

  model.compile(
      optimizer='adam',
      loss='binary_crossentropy',
      metrics=metrics)
  
  return model

Understanding useful metrics

Notice that a few metrics are defined above. The model computes these during training, and they will be helpful when evaluating its performance.

  • False negatives and false positives are samples that were incorrectly classified
  • True negatives and true positives are samples that were correctly classified
  • Accuracy is the percentage of examples correctly classified > $\frac{\text{true positives + true negatives}}{\text{total samples}}$
  • Precision is the percentage of predicted positives that were correctly classified > $\frac{\text{true positives}}{\text{true positives + false positives}}$
  • Recall is the percentage of actual positives that were correctly classified > $\frac{\text{true positives}}{\text{true positives + false negatives}}$
  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.


Read more:

  • True vs. False and Positive vs. Negative
  • Accuracy
  • Precision and Recall
  • ROC-AUC
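
To make the formulas above concrete, here is a minimal sketch that computes accuracy, precision, and recall directly from the four counts. The numbers are hypothetical, chosen for illustration only:

# Hypothetical confusion-matrix counts, for illustration only.
tp, fp, tn, fn = 80, 20, 890, 10

accuracy = (tp + tn) / (tp + tn + fp + fn)  # fraction classified correctly
precision = tp / (tp + fp)                  # fraction of predicted positives that are real
recall = tp / (tp + fn)                     # fraction of actual positives that are found

print('accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(
    accuracy, precision, recall))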

Train a baseline model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048; this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would have no fraudulent transactions to learn from.

model = make_model()

EPOCHS = 10
BATCH_SIZE = 2048

history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_features, val_labels))

Train on 182276 samples, validate on 45569 samples
Epoch 1/10
182276/182276 [==============================] - 3s 14us/sample - loss: 0.0447 - accuracy: 0.3768 - tp: 1.0000 - fp: 622.0000 - tn: 181330.0000 - fn: 323.0000 - precision: 0.0016 - recall: 0.0031 - auc: 0.4725 - val_loss: 0.0250 - val_accuracy: 0.5031 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45495.0000 - val_fn: 74.0000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 2/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0273 - accuracy: 0.3979 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 181952.0000 - fn: 324.0000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5000 - val_loss: 0.0247 - val_accuracy: 0.2355 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45495.0000 - val_fn: 74.0000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 3/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0175 - accuracy: 0.0895 - tp: 2.0000 - fp: 0.0000e+00 - tn: 181952.0000 - fn: 322.0000 - precision: 1.0000 - recall: 0.0062 - auc: 0.6654 - val_loss: 0.0036 - val_accuracy: 0.0253 - val_tp: 28.0000 - val_fp: 11.0000 - val_tn: 45484.0000 - val_fn: 46.0000 - val_precision: 0.7179 - val_recall: 0.3784 - val_auc: 0.9446
Epoch 4/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0039 - accuracy: 0.0349 - tp: 221.0000 - fp: 38.0000 - tn: 181914.0000 - fn: 103.0000 - precision: 0.8533 - recall: 0.6821 - auc: 0.9446 - val_loss: 0.0027 - val_accuracy: 0.0330 - val_tp: 63.0000 - val_fp: 20.0000 - val_tn: 45475.0000 - val_fn: 11.0000 - val_precision: 0.7590 - val_recall: 0.8514 - val_auc: 0.9521
Epoch 5/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0031 - accuracy: 0.0509 - tp: 248.0000 - fp: 41.0000 - tn: 181911.0000 - fn: 76.0000 - precision: 0.8581 - recall: 0.7654 - auc: 0.9529 - val_loss: 0.0027 - val_accuracy: 0.0478 - val_tp: 63.0000 - val_fp: 20.0000 - val_tn: 45475.0000 - val_fn: 11.0000 - val_precision: 0.7590 - val_recall: 0.8514 - val_auc: 0.9588
Epoch 6/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0028 - accuracy: 0.0631 - tp: 254.0000 - fp: 43.0000 - tn: 181909.0000 - fn: 70.0000 - precision: 0.8552 - recall: 0.7840 - auc: 0.9638 - val_loss: 0.0027 - val_accuracy: 0.0768 - val_tp: 63.0000 - val_fp: 18.0000 - val_tn: 45477.0000 - val_fn: 11.0000 - val_precision: 0.7778 - val_recall: 0.8514 - val_auc: 0.9589
Epoch 7/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0026 - accuracy: 0.1093 - tp: 253.0000 - fp: 38.0000 - tn: 181914.0000 - fn: 71.0000 - precision: 0.8694 - recall: 0.7809 - auc: 0.9655 - val_loss: 0.0029 - val_accuracy: 0.0861 - val_tp: 63.0000 - val_fp: 21.0000 - val_tn: 45474.0000 - val_fn: 11.0000 - val_precision: 0.7500 - val_recall: 0.8514 - val_auc: 0.9588
Epoch 8/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0024 - accuracy: 0.1176 - tp: 258.0000 - fp: 38.0000 - tn: 181914.0000 - fn: 66.0000 - precision: 0.8716 - recall: 0.7963 - auc: 0.9686 - val_loss: 0.0030 - val_accuracy: 0.1358 - val_tp: 63.0000 - val_fp: 22.0000 - val_tn: 45473.0000 - val_fn: 11.0000 - val_precision: 0.7412 - val_recall: 0.8514 - val_auc: 0.9588
Epoch 9/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0022 - accuracy: 0.1672 - tp: 259.0000 - fp: 42.0000 - tn: 181910.0000 - fn: 65.0000 - precision: 0.8605 - recall: 0.7994 - auc: 0.9779 - val_loss: 0.0028 - val_accuracy: 0.1543 - val_tp: 63.0000 - val_fp: 18.0000 - val_tn: 45477.0000 - val_fn: 11.0000 - val_precision: 0.7778 - val_recall: 0.8514 - val_auc: 0.9589
Epoch 10/10
182276/182276 [==============================] - 1s 4us/sample - loss: 0.0020 - accuracy: 0.2022 - tp: 259.0000 - fp: 34.0000 - tn: 181918.0000 - fn: 65.0000 - precision: 0.8840 - recall: 0.7994 - auc: 0.9826 - val_loss: 0.0028 - val_accuracy: 0.2005 - val_tp: 58.0000 - val_fp: 16.0000 - val_tn: 45479.0000 - val_fn: 16.0000 - val_precision: 0.7838 - val_recall: 0.7838 - val_auc: 0.9455

Plot metrics on the training and validation sets

In this section, you will produce plots of your model's accuracy and loss on the training and validation set. These are useful to check for overfitting, which you can learn more about in the TensorFlow overfit and underfit tutorial.

Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.

epochs = range(EPOCHS)

plt.title('Accuracy')
plt.plot(epochs,  history.history['accuracy'], color='blue', label='Train')
plt.plot(epochs, history.history['val_accuracy'], color='orange', label='Val')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

_ = plt.figure()
plt.title('Loss')
plt.plot(epochs, history.history['loss'], color='blue', label='Train')
plt.plot(epochs, history.history['val_loss'], color='orange', label='Val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

_ = plt.figure()
plt.title('False Negatives')
plt.plot(epochs, history.history['fn'], color='blue', label='Train')
plt.plot(epochs, history.history['val_fn'], color='orange', label='Val')
plt.xlabel('Epoch')
plt.ylabel('False Negatives')
plt.legend()

[Plots: accuracy, loss, and false negatives per epoch on the training and validation sets]

Evaluate the baseline model

Evaluate your model on the test dataset and display results for the metrics you created above.

results = model.evaluate(test_features, test_labels)
for name, value in zip(model.metrics_names, results):
  print(name, ': ', value)
56962/56962 [==============================] - 5s 95us/sample - loss: 0.0033 - accuracy: 0.1993 - tp: 77.0000 - fp: 11.0000 - tn: 56857.0000 - fn: 17.0000 - precision: 0.8750 - recall: 0.8191 - auc: 0.9357
loss :  0.0032539447209801215
accuracy :  0.19929075
tp :  77.0
fp :  11.0
tn :  56857.0
fn :  17.0
precision :  0.875
recall :  0.81914896
auc :  0.9357286

It looks like the precision is relatively high, but the recall and AUC aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. However, because missing fraudulent transactions (false negatives) may have significantly worse business consequences than incorrectly flagging fraudulent transactions (false positives), recall may be more important than precision in this case.
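
Because the sigmoid output is a probability, one simple lever for this trade-off is the decision threshold: lowering it below the default 0.5 catches more fraud (higher recall) at the cost of more false alarms (lower precision). A minimal sketch, where 0.1 is an arbitrary illustrative value that you would tune on the validation set:

# Lower the 0.5 decision threshold to trade precision for recall.
predicted_probs = model.predict(test_features)[:, 0]
low_threshold_labels = (predicted_probs > 0.1).astype(int)

cm_low = confusion_matrix(test_labels, low_threshold_labels)
print('False negatives at threshold 0.1: ', cm_low[1][0])
print('False positives at threshold 0.1: ', cm_low[0][1])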

Examine the confusion matrix

You can use a confusion matrix to summarize the actual vs. predicted labels where the X axis is the predicted label and the Y axis is the actual label.

predicted_labels = model.predict(test_features)
cm = confusion_matrix(test_labels, np.round(predicted_labels))

plt.matshow(cm, alpha=0)
plt.title('Confusion matrix')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')

for (i, j), z in np.ndenumerate(cm):
    plt.text(j, i, str(z), ha='center', va='center')
    
plt.show()

print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
print('Total Fraudulent Transactions: ', np.sum(cm[1]))

[Plot: confusion matrix of actual vs. predicted labels on the test set]

Legitimate Transactions Detected (True Negatives):  56857
Legitimate Transactions Incorrectly Detected (False Positives):  11
Fraudulent Transactions Missed (False Negatives):  17
Fraudulent Transactions Detected (True Positives):  77
Total Fraudulent Transactions:  94

If the model had predicted everything perfectly, this would be a diagonal matrix, where all values off the main diagonal, indicating incorrect predictions, would be zero. In this case the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged. However, you would likely want even fewer false negatives, despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may merely cause an email to be sent to a customer asking them to verify their card activity.

Using class weights for the loss function

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras a weight for each class through the class_weight parameter of Model.fit. These weights cause the model to "pay more attention" to examples from the under-represented class.


weight_for_0 = 1 / neg
weight_for_1 = 1 / pos

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2e}'.format(weight_for_0))
print('Weight for class 1: {:.2e}'.format(weight_for_1))
Weight for class 0: 5.50e-06
Weight for class 1: 3.09e-03
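
Note that with these raw 1/count weights the overall loss becomes very small, which effectively scales down the gradients. One common variant, shown here as an optional sketch (not used in the run below), rescales the weights so the mean weight across all samples stays close to 1:

# Optional variant (not used below): scale by total / 2 so the average
# sample weight stays near 1, keeping the loss magnitude comparable to
# the unweighted model.
scaled_weight_for_0 = (1 / neg) * (total / 2.0)
scaled_weight_for_1 = (1 / pos) * (total / 2.0)
scaled_class_weight = {0: scaled_weight_for_0, 1: scaled_weight_for_1}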

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

weighted_model = make_model()

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_features, val_labels),
    class_weight=class_weight)
Train on 182276 samples, validate on 45569 samples
Epoch 1/10
182276/182276 [==============================] - 2s 10us/sample - loss: 4.8680e-06 - accuracy: 0.0000e+00 - tp: 301.0000 - fp: 84053.0000 - tn: 97899.0000 - fn: 23.0000 - precision: 0.0036 - recall: 0.9290 - auc: 0.9190 - val_loss: 0.4029 - val_accuracy: 0.0000e+00 - val_tp: 65.0000 - val_fp: 584.0000 - val_tn: 44911.0000 - val_fn: 9.0000 - val_precision: 0.1002 - val_recall: 0.8784 - val_auc: 0.9527
Epoch 2/10
182276/182276 [==============================] - 1s 4us/sample - loss: 2.9306e-06 - accuracy: 1.8104e-04 - tp: 281.0000 - fp: 3709.0000 - tn: 178243.0000 - fn: 43.0000 - precision: 0.0704 - recall: 0.8673 - auc: 0.9478 - val_loss: 0.1793 - val_accuracy: 3.0723e-04 - val_tp: 66.0000 - val_fp: 720.0000 - val_tn: 44775.0000 - val_fn: 8.0000 - val_precision: 0.0840 - val_recall: 0.8919 - val_auc: 0.9614
Epoch 3/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.9709e-06 - accuracy: 4.3341e-04 - tp: 287.0000 - fp: 4288.0000 - tn: 177664.0000 - fn: 37.0000 - precision: 0.0627 - recall: 0.8858 - auc: 0.9776 - val_loss: 0.1259 - val_accuracy: 6.1445e-04 - val_tp: 67.0000 - val_fp: 930.0000 - val_tn: 44565.0000 - val_fn: 7.0000 - val_precision: 0.0672 - val_recall: 0.9054 - val_auc: 0.9744
Epoch 4/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.6452e-06 - accuracy: 6.3091e-04 - tp: 297.0000 - fp: 4622.0000 - tn: 177330.0000 - fn: 27.0000 - precision: 0.0604 - recall: 0.9167 - auc: 0.9848 - val_loss: 0.1077 - val_accuracy: 7.0223e-04 - val_tp: 67.0000 - val_fp: 1033.0000 - val_tn: 44462.0000 - val_fn: 7.0000 - val_precision: 0.0609 - val_recall: 0.9054 - val_auc: 0.9793
Epoch 5/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.5432e-06 - accuracy: 7.0772e-04 - tp: 295.0000 - fp: 4379.0000 - tn: 177573.0000 - fn: 29.0000 - precision: 0.0631 - recall: 0.9105 - auc: 0.9857 - val_loss: 0.0875 - val_accuracy: 7.2418e-04 - val_tp: 67.0000 - val_fp: 874.0000 - val_tn: 44621.0000 - val_fn: 7.0000 - val_precision: 0.0712 - val_recall: 0.9054 - val_auc: 0.9817
Epoch 6/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.3898e-06 - accuracy: 7.4064e-04 - tp: 302.0000 - fp: 4198.0000 - tn: 177754.0000 - fn: 22.0000 - precision: 0.0671 - recall: 0.9321 - auc: 0.9873 - val_loss: 0.0654 - val_accuracy: 9.2168e-04 - val_tp: 67.0000 - val_fp: 605.0000 - val_tn: 44890.0000 - val_fn: 7.0000 - val_precision: 0.0997 - val_recall: 0.9054 - val_auc: 0.9850
Epoch 7/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.3700e-06 - accuracy: 9.9849e-04 - tp: 302.0000 - fp: 4141.0000 - tn: 177811.0000 - fn: 22.0000 - precision: 0.0680 - recall: 0.9321 - auc: 0.9884 - val_loss: 0.1083 - val_accuracy: 9.2168e-04 - val_tp: 67.0000 - val_fp: 1255.0000 - val_tn: 44240.0000 - val_fn: 7.0000 - val_precision: 0.0507 - val_recall: 0.9054 - val_auc: 0.9856
Epoch 8/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.0712e-06 - accuracy: 0.0012 - tp: 303.0000 - fp: 3368.0000 - tn: 178584.0000 - fn: 21.0000 - precision: 0.0825 - recall: 0.9352 - auc: 0.9949 - val_loss: 0.1142 - val_accuracy: 0.0012 - val_tp: 68.0000 - val_fp: 1531.0000 - val_tn: 43964.0000 - val_fn: 6.0000 - val_precision: 0.0425 - val_recall: 0.9189 - val_auc: 0.9869
Epoch 9/10
182276/182276 [==============================] - 1s 4us/sample - loss: 1.0397e-06 - accuracy: 0.0018 - tp: 309.0000 - fp: 4053.0000 - tn: 177899.0000 - fn: 15.0000 - precision: 0.0708 - recall: 0.9537 - auc: 0.9931 - val_loss: 0.0778 - val_accuracy: 0.0015 - val_tp: 68.0000 - val_fp: 932.0000 - val_tn: 44563.0000 - val_fn: 6.0000 - val_precision: 0.0680 - val_recall: 0.9189 - val_auc: 0.9869
Epoch 10/10
182276/182276 [==============================] - 1s 4us/sample - loss: 9.6902e-07 - accuracy: 0.0020 - tp: 312.0000 - fp: 3426.0000 - tn: 178526.0000 - fn: 12.0000 - precision: 0.0835 - recall: 0.9630 - auc: 0.9940 - val_loss: 0.1075 - val_accuracy: 0.0017 - val_tp: 68.0000 - val_fp: 1477.0000 - val_tn: 44018.0000 - val_fn: 6.0000 - val_precision: 0.0440 - val_recall: 0.9189 - val_auc: 0.9827
weighted_results = weighted_model.evaluate(test_features, test_labels)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
56962/56962 [==============================] - 5s 94us/sample - loss: 0.1060 - accuracy: 0.0017 - tp: 85.0000 - fp: 1879.0000 - tn: 54989.0000 - fn: 9.0000 - precision: 0.0433 - recall: 0.9043 - auc: 0.9527
loss :  0.1059855450376241
accuracy :  0.0016853341
tp :  85.0
fp :  1879.0
tn :  54989.0
fn :  9.0
precision :  0.043279022
recall :  0.90425533
auc :  0.95272493

Here you can see that with class weights the accuracy and precision are lower, because there are more false positives, but conversely the recall and AUC are higher, because the model also found more true positives. Despite having lower overall accuracy, this approach may be better, given that the consequences of failing to identify fraudulent transactions drive the prioritization of recall. Depending on how bad false negatives are, you might use even more exaggerated weights to further improve recall at the cost of precision, as in the brief sketch below.
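
For example, a hypothetical sketch that multiplies the positive-class weight by an arbitrary factor of 10:

# Hypothetical: exaggerate the positive-class weight by an arbitrary
# factor of 10 to push recall higher at the expense of precision.
exaggerated_class_weight = {0: weight_for_0, 1: weight_for_1 * 10}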

Oversampling the minority class

A related approach is to resample the dataset by oversampling the minority class, that is, creating more positive samples. The imbalanced-learn library (a scikit-learn-contrib project) provides methods to create new positive samples, either by simply duplicating random existing samples or by interpolating between them to generate synthetic samples using variations of SMOTE. TensorFlow also provides a way to do random oversampling.

# With default arguments, SMOTE oversamples the minority class until both
# classes have an equal number of observations.
smote = SMOTE()
res_features, res_labels = smote.fit_resample(train_features, train_labels)

res_neg, res_pos = np.bincount(res_labels)
res_total = res_neg + res_pos
print('{} positive samples out of {} training samples ({:.2f}% of total)'.format(
    res_pos, res_total, 100 * res_pos / res_total))
181952 positive samples out of 363904 training samples (50.00% of total)
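
For comparison, the simpler duplicate-existing-samples approach mentioned above can be done with imbalanced-learn's RandomOverSampler, or by hand. A minimal NumPy sketch, assuming the train_features and train_labels arrays from earlier:

# Naive random oversampling: duplicate minority-class rows (with
# replacement) until the two classes are the same size.
pos_idx = np.where(train_labels == 1)[0]
neg_idx = np.where(train_labels == 0)[0]
dup_pos_idx = np.random.choice(pos_idx, size=len(neg_idx), replace=True)

naive_idx = np.concatenate([neg_idx, dup_pos_idx])
np.random.shuffle(naive_idx)
naive_features = train_features[naive_idx]
naive_labels = train_labels[naive_idx]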

Train and evaluate a model on the resampled data

Now try training the model with the resampled data set instead of using class weights to see how these methods compare.

resampled_model = make_model()

resampled_history = resampled_model.fit(
    res_features,
    res_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_features, val_labels))
Train on 363904 samples, validate on 45569 samples
Epoch 1/10
363904/363904 [==============================] - 2s 6us/sample - loss: 0.0625 - accuracy: 0.3698 - tp: 177708.0000 - fp: 5145.0000 - tn: 176807.0000 - fn: 4244.0000 - precision: 0.9719 - recall: 0.9767 - auc: 0.9976 - val_loss: 0.0231 - val_accuracy: 0.4311 - val_tp: 67.0000 - val_fp: 241.0000 - val_tn: 45254.0000 - val_fn: 7.0000 - val_precision: 0.2175 - val_recall: 0.9054 - val_auc: 0.9508
Epoch 2/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0058 - accuracy: 0.7162 - tp: 181922.0000 - fp: 341.0000 - tn: 181611.0000 - fn: 30.0000 - precision: 0.9981 - recall: 0.9998 - auc: 0.9998 - val_loss: 0.0118 - val_accuracy: 0.7915 - val_tp: 65.0000 - val_fp: 58.0000 - val_tn: 45437.0000 - val_fn: 9.0000 - val_precision: 0.5285 - val_recall: 0.8784 - val_auc: 0.9385
Epoch 3/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0038 - accuracy: 0.7982 - tp: 181939.0000 - fp: 183.0000 - tn: 181769.0000 - fn: 13.0000 - precision: 0.9990 - recall: 0.9999 - auc: 0.9998 - val_loss: 0.0121 - val_accuracy: 0.8538 - val_tp: 65.0000 - val_fp: 70.0000 - val_tn: 45425.0000 - val_fn: 9.0000 - val_precision: 0.4815 - val_recall: 0.8784 - val_auc: 0.9454
Epoch 4/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0030 - accuracy: 0.8285 - tp: 181935.0000 - fp: 128.0000 - tn: 181824.0000 - fn: 17.0000 - precision: 0.9993 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0108 - val_accuracy: 0.9174 - val_tp: 65.0000 - val_fp: 48.0000 - val_tn: 45447.0000 - val_fn: 9.0000 - val_precision: 0.5752 - val_recall: 0.8784 - val_auc: 0.9388
Epoch 5/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0028 - accuracy: 0.8472 - tp: 181938.0000 - fp: 121.0000 - tn: 181831.0000 - fn: 14.0000 - precision: 0.9993 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0109 - val_accuracy: 0.9348 - val_tp: 65.0000 - val_fp: 50.0000 - val_tn: 45445.0000 - val_fn: 9.0000 - val_precision: 0.5652 - val_recall: 0.8784 - val_auc: 0.9455
Epoch 6/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0026 - accuracy: 0.8487 - tp: 181934.0000 - fp: 106.0000 - tn: 181846.0000 - fn: 18.0000 - precision: 0.9994 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0104 - val_accuracy: 0.9466 - val_tp: 65.0000 - val_fp: 43.0000 - val_tn: 45452.0000 - val_fn: 9.0000 - val_precision: 0.6019 - val_recall: 0.8784 - val_auc: 0.9456
Epoch 7/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0024 - accuracy: 0.8553 - tp: 181934.0000 - fp: 98.0000 - tn: 181854.0000 - fn: 18.0000 - precision: 0.9995 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0095 - val_accuracy: 0.9564 - val_tp: 65.0000 - val_fp: 33.0000 - val_tn: 45462.0000 - val_fn: 9.0000 - val_precision: 0.6633 - val_recall: 0.8784 - val_auc: 0.9389
Epoch 8/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0020 - accuracy: 0.8544 - tp: 181936.0000 - fp: 82.0000 - tn: 181870.0000 - fn: 16.0000 - precision: 0.9995 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0116 - val_accuracy: 0.9575 - val_tp: 65.0000 - val_fp: 48.0000 - val_tn: 45447.0000 - val_fn: 9.0000 - val_precision: 0.5752 - val_recall: 0.8784 - val_auc: 0.9388
Epoch 9/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0017 - accuracy: 0.8495 - tp: 181942.0000 - fp: 58.0000 - tn: 181894.0000 - fn: 10.0000 - precision: 0.9997 - recall: 0.9999 - auc: 0.9999 - val_loss: 0.0091 - val_accuracy: 0.9658 - val_tp: 65.0000 - val_fp: 30.0000 - val_tn: 45465.0000 - val_fn: 9.0000 - val_precision: 0.6842 - val_recall: 0.8784 - val_auc: 0.9389
Epoch 10/10
363904/363904 [==============================] - 1s 3us/sample - loss: 0.0015 - accuracy: 0.7874 - tp: 181948.0000 - fp: 53.0000 - tn: 181899.0000 - fn: 4.0000 - precision: 0.9997 - recall: 1.0000 - auc: 0.9999 - val_loss: 0.0097 - val_accuracy: 0.9668 - val_tp: 65.0000 - val_fp: 42.0000 - val_tn: 45453.0000 - val_fn: 9.0000 - val_precision: 0.6075 - val_recall: 0.8784 - val_auc: 0.9456
resampled_results = resampled_model.evaluate(test_features, test_labels)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
56962/56962 [==============================] - 5s 96us/sample - loss: 0.0079 - accuracy: 0.9684 - tp: 80.0000 - fp: 34.0000 - tn: 56834.0000 - fn: 14.0000 - precision: 0.7018 - recall: 0.8511 - auc: 0.9306
loss :  0.007876014656974763
accuracy :  0.9683649
tp :  80.0
fp :  34.0
tn :  56834.0
fn :  14.0
precision :  0.7017544
recall :  0.85106385
auc :  0.9305925

This approach can be worth trying, but may not provide better results than using class weights because the synthetic examples may not accurately represent the underlying data.

Applying this tutorial to your problem

Classification on imbalanced data is an inherently difficult task, since there are so few samples to learn from. You should always start with the data: do your best to collect as many samples as possible, and give substantial thought to which features may be relevant, so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep the context of the problem in mind and to evaluate how costly your false positives and false negatives really are.