This page was translated by the Cloud Translation API.

Classification on imbalanced data


This tutorial demonstrates how to classify a highly imbalanced dataset, in which the number of examples in one class greatly outnumbers the examples in another. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Try common techniques for dealing with imbalanced data like:
    • Class weighting
    • Oversampling

Setup

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Data processing and exploration

Download the Kaggle Credit Card Fraud dataset

Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download a CSV file directly into a Pandas DataFrame.

# Load the CSV from the copy hosted at download.tensorflow.org.
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


This shows the small fraction of positive samples.
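Because positives are so rare, plain accuracy says little here. A quick calculation with the counts printed above (a sketch; nothing here is specific to TensorFlow) shows that a trivial rule which labels every transaction as legitimate would already be more than 99.8% accurate while catching no fraud at all:

```python
# Counts printed above for the full dataset.
total = 284807
pos = 492
neg = total - pos

# A "classifier" that always predicts the negative (legitimate) class.
always_negative_accuracy = neg / total
always_negative_recall = 0 / pos  # it never flags a fraudulent transaction

print(f"Negatives per positive: {neg / pos:.1f}")
print(f"Accuracy of always predicting 'legitimate': {always_negative_accuracy:.4%}")
print(f"Recall of always predicting 'legitimate': {always_negative_recall:.0%}")
```

This is why the metrics defined later include precision, recall, and AUC rather than accuracy alone.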

Clean, split, and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # Offset of 0.1¢ so that a zero amount does not produce log(0).
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
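To get a feel for how much the log transform compresses the range, here is a sketch on a few made-up amounts (illustrative values, not taken from the dataset). The `eps` offset is what keeps a zero amount from turning into negative infinity:

```python
import numpy as np

eps = 0.001  # same offset as above: 0.1 cent
amounts = np.array([0.0, 1.0, 100.0, 10000.0])  # hypothetical transaction amounts

log_amounts = np.log(amounts + eps)
for a, la in zip(amounts, log_amounts):
    print(f"{a:>9.2f} -> {la:7.3f}")

# Without eps, a zero amount would map to -inf.
with np.errstate(divide="ignore"):
    print(np.log(0.0))
```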

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics, but the model is not fit on this data. The test set is not used at all during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.

# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
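With only about 0.17% positives, a purely random split can leave the validation or test set with noticeably more or fewer fraud cases than the overall rate would suggest. The code above does not guard against this, but sklearn's `train_test_split` accepts a `stratify` argument that preserves the class ratio in each split. A minimal sketch of that alternative on synthetic labels (1,000 examples, 10 positives, chosen only for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, heavily imbalanced data: 990 negatives and 10 positives.
labels = np.array([0] * 990 + [1] * 10)
features = np.arange(len(labels)).reshape(-1, 1).astype(float)

# Stratify both splits on the labels so each keeps roughly 1% positives.
x_train, x_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, stratify=labels, random_state=0)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.2, stratify=y_train, random_state=0)

for name, y in [("train", y_train), ("val", y_val), ("test", y_test)]:
    print(f"{name}: {len(y)} examples, {y.sum()} positives")
```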

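Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and the standard deviation to 1. The scaler is fit only on train_features, so the model does not peek at the validation or test sets; the scaled features are then clipped to the range [-5, 5] to limit the influence of extreme values.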

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
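The scaling above amounts to subtracting the training-set mean and dividing by the training-set standard deviation, reusing the same statistics for the other splits, followed by clipping. A sketch on synthetic data (the arrays and seed are made up) that checks this manual computation against `StandardScaler`:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical data: a training split and a validation split with 3 features.
train = rng.normal(loc=50.0, scale=10.0, size=(1000, 3))
val = rng.normal(loc=50.0, scale=10.0, size=(200, 3))

# Standardize with statistics computed on the training split only.
mean = train.mean(axis=0)
std = train.std(axis=0)  # population std (ddof=0), matching StandardScaler
manual_train = np.clip((train - mean) / std, -5, 5)
manual_val = np.clip((val - mean) / std, -5, 5)

# The same transformation via sklearn.
scaler = StandardScaler()
sk_train = np.clip(scaler.fit_transform(train), -5, 5)
sk_val = np.clip(scaler.transform(val), -5, 5)

print(np.allclose(manual_train, sk_train), np.allclose(manual_val, sk_val))
print(manual_train.mean(axis=0).round(6), manual_train.std(axis=0).round(6))
```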


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input and these are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes, the positive examples contain a much higher rate of extreme values.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

# Pass x and y as keyword arguments; newer seaborn versions reject positional ones.
sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")

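[Figure: hexbin joint plot of V5 vs. V6 for the positive examples, titled "Positive distribution"]

[Figure: hexbin joint plot of V5 vs. V6 for the negative examples, titled "Negative distribution"]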
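One way to put a number on "a much higher rate of extreme values" is to count, per class, how often standardized feature values land far from zero. This is a sketch under assumptions: the helper `extreme_rate` and the cutoff of 3 are hypothetical choices, and the demo uses synthetic arrays; on the real data you would pass `train_features` and `bool_train_labels` instead.

```python
import numpy as np

def extreme_rate(features, is_positive, cutoff=3.0):
    """Fraction of feature values with |z| > cutoff, per class (hypothetical helper)."""
    extreme = np.abs(features) > cutoff
    return extreme[is_positive].mean(), extreme[~is_positive].mean()

# Synthetic demo: positives are drawn with a wider spread than negatives.
rng = np.random.default_rng(1)
neg_feats = rng.normal(0.0, 1.0, size=(5000, 4))
pos_feats = rng.normal(0.0, 3.0, size=(50, 4))
features = np.vstack([neg_feats, pos_feats])
is_positive = np.array([False] * 5000 + [True] * 50)

pos_rate, neg_rate = extreme_rate(features, is_positive)
print(f"positives: {pos_rate:.3f}, negatives: {neg_rate:.3f}")
```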

Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and a sigmoid output layer that returns the probability of a transaction being fraudulent:

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

Understanding useful metrics

Notice that a few of the metrics defined above can be computed by the model and will be helpful when evaluating its performance.

  • False negatives and false positives are samples that were incorrectly classified
  • True negatives and true positives are samples that were correctly classified
  • Accuracy is the percentage of examples correctly classified: $\frac{\text{correctly classified samples}}{\text{total samples}}$
  • Precision is the percentage of predicted positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives} + \text{false positives}}$
  • Recall is the percentage of actual positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives} + \text{false negatives}}$
  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
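The definitions above can be checked with plain arithmetic. The sketch below plugs in confusion-matrix counts (the same values as the baseline test-set results reported later in this tutorial), then computes ROC-AUC as a ranking probability on a small made-up example and compares it with sklearn's `roc_auc_score`:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Confusion-matrix counts (the baseline test-set results shown later).
tp, fp, tn, fn = 56, 12, 56855, 39

accuracy = (tp + tn) / (tp + fp + tn + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")

# ROC-AUC as a ranking probability: the fraction of (positive, negative) pairs
# in which the positive example gets the higher score (ties count as half).
labels = np.array([0, 0, 1, 0, 1, 1, 0, 0])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.6, 0.2, 0.3])
pos_scores = scores[labels == 1]
neg_scores = scores[labels == 0]
diff = pos_scores[:, None] - neg_scores[None, :]
pairwise_auc = (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size
print(pairwise_auc, roc_auc_score(labels, scores))
```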


Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048; this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would likely contain no fraudulent transactions to learn from.
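A rough way to see why the batch size matters: if examples are assumed to be drawn independently, the chance that a batch of size $B$ contains no positives is $(1 - p)^B$, where $p$ is the positive fraction. This sketch compares Keras's default `fit` batch size of 32 with the 2048 used here:

```python
# Fraction of positives in the full dataset (from the counts above).
p = 492 / 284807

def prob_no_positive(batch_size, p=p):
    """Chance a batch has zero positives, assuming examples are drawn independently."""
    return (1.0 - p) ** batch_size

for batch_size in (32, 2048):  # 32 is the Keras default batch size for fit()
    expected = batch_size * p
    print(f"batch_size={batch_size:5d}: expected positives={expected:.2f}, "
          f"P(no positives)={prob_no_positive(batch_size):.3f}")
```

Under this independence assumption, most batches of 32 contain no fraud at all, while a batch of 2048 holds a few positives on average.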

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________
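The parameter counts in the summary follow from the rule that a `Dense` layer has one weight per input per unit plus one bias per unit, while `Dropout` has no parameters. A small arithmetic check, assuming 29 input features as in the shapes printed earlier:

```python
# Dense layer parameters = inputs * units (weights) + units (biases).
n_features = 29  # train_features.shape[-1] in this tutorial

def dense_params(n_inputs, n_units):
    return n_inputs * n_units + n_units

hidden = dense_params(n_features, 16)  # first Dense layer
output = dense_params(16, 1)           # sigmoid output layer
print(hidden, output, hidden + output)  # Dropout adds no parameters
```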

Test run the model:

model.predict(train_features[:10])
array([[0.6983068 ],
       [0.7546284 ],
       [0.73785573],
       [0.7908986 ],
       [0.51232255],
       [0.752192  ],
       [0.7387281 ],
       [0.9410955 ],
       [0.809352  ],
       [0.6911539 ]], dtype=float32)

Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

With the default bias initialization, the loss would be about math.log(2) = 0.69314 if the model predicted 0.5 for every example. Here the randomly initialized weights push the predictions above 0.5 (see the output above), so the measured loss is higher:

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.5998

The correct bias to set can be derived from:

$$ p_0 = \frac{pos}{pos + neg} = \frac{1}{1+e^{-b_0}} $$
$$ b_0 = -\log_e\left(\frac{1}{p_0} - 1\right) $$
$$ b_0 = \log_e\left(\frac{pos}{neg}\right) $$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
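A quick numeric check of the derivation, using the full-dataset counts: applying the sigmoid to $b_0 = \log_e(pos/neg)$ should give back $p_0 = pos/(pos + neg)$.

```python
import numpy as np

pos, neg = 492, 284315  # class counts from the full dataset, as above

b0 = np.log(pos / neg)              # b_0 = log_e(pos/neg)
p0 = 1.0 / (1.0 + np.exp(-b0))      # sigmoid(b_0)

print(b0)                      # initial bias
print(p0, pos / (pos + neg))   # sigmoid(b_0) recovers pos/(pos+neg)
```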

Set that as the initial bias, and the model will give much more reasonable initial guesses.

It should be near: pos/total = 0.0018

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00168876],
       [0.00081124],
       [0.00087036],
       [0.00241473],
       [0.00133016],
       [0.00121771],
       [0.00079989],
       [0.00079692],
       [0.00257652],
       [0.00104385]], dtype=float32)

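With this initialization, the initial loss should be approximately the following (using $p_0 = 492/284807$; rounding $p_0$ to 0.0018 gives 0.01317):

$$-p_0\log(p_0)-(1-p_0)\log(1-p_0) \approx 0.0127$$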
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0174

This expected initial loss is roughly 50 times smaller than the math.log(2) ≈ 0.693 expected with naive initialization; the measured losses above (0.0174 vs. 1.5998) differ by an even larger factor.

This way, the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes the loss plots easier to read during training.
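Both reference losses come from the binary cross-entropy of a model that predicts the same probability $q$ for every example. This sketch (the helper name is hypothetical) evaluates it for $q = 0.5$, which should give math.log(2), and for $q = p_0$:

```python
import numpy as np

p0 = 492 / 284807  # fraction of positives

def constant_prediction_bce(p_true, q):
    """Binary cross-entropy when every prediction equals q and a fraction p_true is positive."""
    return -p_true * np.log(q) - (1 - p_true) * np.log(1 - q)

naive = constant_prediction_bce(p0, 0.5)   # outputs of 0.5 -> log(2)
careful = constant_prediction_bce(p0, p0)  # outputs equal to the base rate
print(f"naive={naive:.5f}  careful={careful:.5f}  ratio={naive / careful:.1f}")
```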

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training.

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

Confirm that the bias fix helps

Before moving on, quickly confirm that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
_ = plt.legend()

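[Figure: training (solid) and validation (dashed) loss per epoch on a log scale, for "Zero Bias" and "Careful Bias"]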

The above figure makes it clear: in terms of validation loss, on this problem, this careful initialization gives a clear advantage.

Train the model

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 15ms/step - loss: 0.0186 - tp: 64.0000 - fp: 25.1978 - tn: 139431.9780 - fn: 188.3956 - accuracy: 0.9985 - precision: 0.7214 - recall: 0.3037 - auc: 0.6752 - val_loss: 0.0109 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45489.0000 - val_fn: 80.0000 - val_accuracy: 0.9982 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.6977
Epoch 2/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0085 - tp: 25.4505 - fp: 5.6703 - tn: 93973.1538 - fn: 136.2967 - accuracy: 0.9985 - precision: 0.6137 - recall: 0.1108 - auc: 0.8194 - val_loss: 0.0051 - val_tp: 27.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 53.0000 - val_accuracy: 0.9986 - val_precision: 0.7500 - val_recall: 0.3375 - val_auc: 0.9184
Epoch 3/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0057 - tp: 68.4505 - fp: 9.9231 - tn: 93970.2198 - fn: 91.9780 - accuracy: 0.9989 - precision: 0.8896 - recall: 0.4278 - auc: 0.9133 - val_loss: 0.0043 - val_tp: 44.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 36.0000 - val_accuracy: 0.9990 - val_precision: 0.8148 - val_recall: 0.5500 - val_auc: 0.9185
Epoch 4/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0051 - tp: 86.9451 - fp: 17.4505 - tn: 93959.1429 - fn: 77.0330 - accuracy: 0.9990 - precision: 0.8422 - recall: 0.5397 - auc: 0.9191 - val_loss: 0.0040 - val_tp: 51.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8361 - val_recall: 0.6375 - val_auc: 0.9248
Epoch 5/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 83.8571 - fp: 13.0110 - tn: 93965.0879 - fn: 78.6154 - accuracy: 0.9990 - precision: 0.8890 - recall: 0.5114 - auc: 0.9212 - val_loss: 0.0039 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 25.0000 - val_accuracy: 0.9992 - val_precision: 0.8462 - val_recall: 0.6875 - val_auc: 0.9247
Epoch 6/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0049 - tp: 83.9890 - fp: 11.9451 - tn: 93976.1648 - fn: 68.4725 - accuracy: 0.9992 - precision: 0.8793 - recall: 0.5693 - auc: 0.9022 - val_loss: 0.0038 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 7/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 89.0549 - fp: 19.6484 - tn: 93958.4835 - fn: 73.3846 - accuracy: 0.9991 - precision: 0.8187 - recall: 0.5711 - auc: 0.8824 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 27.0000 - val_accuracy: 0.9992 - val_precision: 0.8413 - val_recall: 0.6625 - val_auc: 0.9309
Epoch 8/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 92.4396 - fp: 11.4835 - tn: 93970.6374 - fn: 66.0110 - accuracy: 0.9992 - precision: 0.9188 - recall: 0.5617 - auc: 0.9298 - val_loss: 0.0036 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9309
Epoch 9/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 100.4396 - fp: 14.7802 - tn: 93963.6044 - fn: 61.7473 - accuracy: 0.9992 - precision: 0.8725 - recall: 0.6228 - auc: 0.9167 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 10/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 94.8132 - fp: 12.7253 - tn: 93958.3187 - fn: 74.7143 - accuracy: 0.9991 - precision: 0.8935 - recall: 0.5710 - auc: 0.9196 - val_loss: 0.0035 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247
Epoch 11/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0047 - tp: 97.0440 - fp: 16.3956 - tn: 93958.9231 - fn: 68.2088 - accuracy: 0.9991 - precision: 0.8591 - recall: 0.5946 - auc: 0.9160 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9247
Epoch 12/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 101.0769 - fp: 13.1868 - tn: 93958.7143 - fn: 67.5934 - accuracy: 0.9992 - precision: 0.8799 - recall: 0.6127 - auc: 0.9202 - val_loss: 0.0034 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247
Epoch 13/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0043 - tp: 98.0769 - fp: 16.0330 - tn: 93961.6703 - fn: 64.7912 - accuracy: 0.9992 - precision: 0.8536 - recall: 0.6112 - auc: 0.9154 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 21.0000 - val_accuracy: 0.9993 - val_precision: 0.8676 - val_recall: 0.7375 - val_auc: 0.9247
Epoch 14/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0050 - tp: 93.5495 - fp: 15.4286 - tn: 93961.4615 - fn: 70.1319 - accuracy: 0.9991 - precision: 0.8590 - recall: 0.5563 - auc: 0.8916 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Epoch 15/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0042 - tp: 90.2198 - fp: 15.4286 - tn: 93968.0989 - fn: 66.8242 - accuracy: 0.9992 - precision: 0.8524 - recall: 0.5813 - auc: 0.9270 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 96.4835 - fp: 14.7363 - tn: 93960.4396 - fn: 68.9121 - accuracy: 0.9991 - precision: 0.8754 - recall: 0.5727 - auc: 0.9218 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8611 - val_recall: 0.7750 - val_auc: 0.9247
Epoch 17/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0044 - tp: 98.7033 - fp: 16.4835 - tn: 93958.9231 - fn: 66.4615 - accuracy: 0.9991 - precision: 0.8674 - recall: 0.5985 - auc: 0.9108 - val_loss: 0.0032 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping

Check training history

In this section, you will produce plots of your model's loss, AUC, precision, and recall on the training and validation sets. These are useful for checking for overfitting, which is covered in more depth in a separate tutorial.

The function below plots four of the metrics created above. You can produce the same plots for any of the others, such as false negatives ("fn"), by adding them to the metrics list (and adjusting the y-axis limits, since counts such as "fn" are not bounded by 1).

def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

[Figure: training (solid) and validation (dashed) loss, AUC, precision, and recall per epoch]

Evaluate metrics

You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label.

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evaluate your model on the test dataset and display the results for the metrics you created above.

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0034731989726424217
tp :  56.0
fp :  12.0
tn :  56855.0
fn :  39.0
accuracy :  0.9991046786308289
precision :  0.8235294222831726
recall :  0.5894736647605896
auc :  0.9418253898620605

Legitimate Transactions Detected (True Negatives):  56855
Legitimate Transactions Incorrectly Detected (False Positives):  12
Fraudulent Transactions Missed (False Negatives):  39
Fraudulent Transactions Detected (True Positives):  56
Total Fraudulent Transactions:  95

[Figure: confusion matrix heatmap, titled "Confusion matrix @0.50"]

If the model had predicted everything perfectly, this would be a diagonal matrix, where the values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged. However, you would likely want even fewer false negatives, even at the cost of more false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may only cause an email to be sent to a customer asking them to verify their card activity.
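The trade-off between false negatives and false positives is controlled by the decision threshold `p` passed to `plot_cm`. This sketch uses synthetic scores (drawn from Beta distributions purely for illustration, not the model's actual outputs) to show that lowering the threshold can only reduce false negatives and can only increase false positives:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

rng = np.random.default_rng(42)
# Hypothetical scores for 990 legitimate and 10 fraudulent transactions.
labels = np.array([0] * 990 + [1] * 10)
scores = np.concatenate([rng.beta(1, 20, size=990),  # mostly near 0
                         rng.beta(5, 2, size=10)])   # mostly near 1

counts = {}
for threshold in (0.5, 0.3, 0.1):
    predicted = (scores > threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(labels, predicted, labels=[0, 1]).ravel()
    counts[threshold] = (int(fp), int(fn))
    print(f"threshold={threshold:.1f}: false positives={fp}, false negatives={fn}")
```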

Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.
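Each point on the ROC curve corresponds to one threshold. The sketch below, on a small made-up example, recomputes the points returned by `sklearn.metrics.roc_curve` by thresholding the scores by hand (sklearn counts a score as positive when it is greater than or equal to the threshold):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Small hypothetical example.
labels = np.array([0, 0, 1, 0, 1, 1, 0, 0])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.6, 0.2, 0.3])

fpr, tpr, thresholds = roc_curve(labels, scores)

# Recompute each ROC point by thresholding the scores by hand.
manual_tpr = np.array([np.mean(scores[labels == 1] >= t) for t in thresholds])
manual_fpr = np.array([np.mean(scores[labels == 0] >= t) for t in thresholds])

for t, f, r in zip(thresholds, fpr, tpr):
    print(f"threshold={t:.2f}  FPR={f:.2f}  TPR={r:.2f}")
```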

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')

[Figure: ROC curves for "Train Baseline" (solid) and "Test Baseline" (dashed)]

It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.

Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras a weight for each class through the class_weight parameter. These weights cause the model to "pay more attention" to examples from an under-represented class.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
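A quick check of the claim in the comments, using the full-dataset counts: conceptually, each example of a class is weighted by that class's weight, so the total weight over all examples stays equal to the number of examples, split evenly between the two classes. This sketch makes that explicit with per-example weights built by `np.where`:

```python
import numpy as np

neg, pos = 284315, 492  # class counts from the full dataset
total = neg + pos

weight_for_0 = (1 / neg) * total / 2.0
weight_for_1 = (1 / pos) * total / 2.0

# Per-example weights: each example gets its class's weight.
labels = np.array([0] * neg + [1] * pos)
sample_weight = np.where(labels == 1, weight_for_1, weight_for_0)

print(f"{weight_for_0:.2f} {weight_for_1:.2f}")
print(sample_weight.sum(), total)  # total weight is unchanged
print(sample_weight[labels == 0].sum(), sample_weight[labels == 1].sum())  # each class: total/2
```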

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 3s 15ms/step - loss: 3.7596 - tp: 56.9890 - fp: 76.4286 - tn: 150769.4286 - fn: 199.7253 - accuracy: 0.9983 - precision: 0.4816 - recall: 0.2595 - auc: 0.7156 - val_loss: 0.0096 - val_tp: 2.0000 - val_fp: 1.0000 - val_tn: 45488.0000 - val_fn: 78.0000 - val_accuracy: 0.9983 - val_precision: 0.6667 - val_recall: 0.0250 - val_auc: 0.8999
Epoch 2/100
90/90 [==============================] - 1s 7ms/step - loss: 1.5426 - tp: 57.2308 - fp: 247.2088 - tn: 93722.7582 - fn: 113.3736 - accuracy: 0.9964 - precision: 0.1813 - recall: 0.2976 - auc: 0.8189 - val_loss: 0.0089 - val_tp: 54.0000 - val_fp: 18.0000 - val_tn: 45471.0000 - val_fn: 26.0000 - val_accuracy: 0.9990 - val_precision: 0.7500 - val_recall: 0.6750 - val_auc: 0.9306
Epoch 3/100
90/90 [==============================] - 1s 7ms/step - loss: 0.8711 - tp: 95.8352 - fp: 494.1209 - tn: 93479.8681 - fn: 70.7473 - accuracy: 0.9943 - precision: 0.1692 - recall: 0.6059 - auc: 0.8912 - val_loss: 0.0121 - val_tp: 66.0000 - val_fp: 32.0000 - val_tn: 45457.0000 - val_fn: 14.0000 - val_accuracy: 0.9990 - val_precision: 0.6735 - val_recall: 0.8250 - val_auc: 0.9426
Epoch 4/100
90/90 [==============================] - 1s 7ms/step - loss: 0.6835 - tp: 108.5165 - fp: 794.1648 - tn: 93183.5055 - fn: 54.3846 - accuracy: 0.9912 - precision: 0.1191 - recall: 0.6530 - auc: 0.8987 - val_loss: 0.0163 - val_tp: 67.0000 - val_fp: 54.0000 - val_tn: 45435.0000 - val_fn: 13.0000 - val_accuracy: 0.9985 - val_precision: 0.5537 - val_recall: 0.8375 - val_auc: 0.9556
Epoch 5/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4713 - tp: 126.3626 - fp: 1149.3407 - tn: 92827.5275 - fn: 37.3407 - accuracy: 0.9878 - precision: 0.0992 - recall: 0.7828 - auc: 0.9329 - val_loss: 0.0214 - val_tp: 67.0000 - val_fp: 95.0000 - val_tn: 45394.0000 - val_fn: 13.0000 - val_accuracy: 0.9976 - val_precision: 0.4136 - val_recall: 0.8375 - val_auc: 0.9588
Epoch 6/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4194 - tp: 125.7912 - fp: 1550.7253 - tn: 92430.8791 - fn: 33.1758 - accuracy: 0.9837 - precision: 0.0769 - recall: 0.7990 - auc: 0.9373 - val_loss: 0.0270 - val_tp: 67.0000 - val_fp: 147.0000 - val_tn: 45342.0000 - val_fn: 13.0000 - val_accuracy: 0.9965 - val_precision: 0.3131 - val_recall: 0.8375 - val_auc: 0.9626
Epoch 7/100
90/90 [==============================] - 1s 7ms/step - loss: 0.4226 - tp: 127.0659 - fp: 2000.6374 - tn: 91978.2857 - fn: 34.5824 - accuracy: 0.9788 - precision: 0.0567 - recall: 0.7672 - auc: 0.9351 - val_loss: 0.0348 - val_tp: 67.0000 - val_fp: 224.0000 - val_tn: 45265.0000 - val_fn: 13.0000 - val_accuracy: 0.9948 - val_precision: 0.2302 - val_recall: 0.8375 - val_auc: 0.9656
Epoch 8/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2990 - tp: 137.7473 - fp: 2564.6154 - tn: 91411.7363 - fn: 26.4725 - accuracy: 0.9729 - precision: 0.0528 - recall: 0.8483 - auc: 0.9609 - val_loss: 0.0457 - val_tp: 69.0000 - val_fp: 406.0000 - val_tn: 45083.0000 - val_fn: 11.0000 - val_accuracy: 0.9908 - val_precision: 0.1453 - val_recall: 0.8625 - val_auc: 0.9691
Epoch 9/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3165 - tp: 125.0330 - fp: 3192.7473 - tn: 90795.8462 - fn: 26.9451 - accuracy: 0.9662 - precision: 0.0375 - recall: 0.8237 - auc: 0.9518 - val_loss: 0.0568 - val_tp: 69.0000 - val_fp: 654.0000 - val_tn: 44835.0000 - val_fn: 11.0000 - val_accuracy: 0.9854 - val_precision: 0.0954 - val_recall: 0.8625 - val_auc: 0.9703
Epoch 10/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2917 - tp: 139.2527 - fp: 3709.6154 - tn: 90270.3626 - fn: 21.3407 - accuracy: 0.9606 - precision: 0.0356 - recall: 0.8714 - auc: 0.9546 - val_loss: 0.0640 - val_tp: 70.0000 - val_fp: 755.0000 - val_tn: 44734.0000 - val_fn: 10.0000 - val_accuracy: 0.9832 - val_precision: 0.0848 - val_recall: 0.8750 - val_auc: 0.9713
Epoch 11/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3109 - tp: 155.7363 - fp: 3838.3736 - tn: 90123.9670 - fn: 22.4945 - accuracy: 0.9590 - precision: 0.0415 - recall: 0.8730 - auc: 0.9555 - val_loss: 0.0703 - val_tp: 71.0000 - val_fp: 835.0000 - val_tn: 44654.0000 - val_fn: 9.0000 - val_accuracy: 0.9815 - val_precision: 0.0784 - val_recall: 0.8875 - val_auc: 0.9719
Epoch 12/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3007 - tp: 134.9121 - fp: 4045.0549 - tn: 89938.1758 - fn: 22.4286 - accuracy: 0.9566 - precision: 0.0320 - recall: 0.8712 - auc: 0.9494 - val_loss: 0.0756 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9000 - val_auc: 0.9716
Epoch 13/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2083 - tp: 143.2308 - fp: 4041.6593 - tn: 89938.9670 - fn: 16.7143 - accuracy: 0.9567 - precision: 0.0360 - recall: 0.9154 - auc: 0.9734 - val_loss: 0.0765 - val_tp: 72.0000 - val_fp: 916.0000 - val_tn: 44573.0000 - val_fn: 8.0000 - val_accuracy: 0.9797 - val_precision: 0.0729 - val_recall: 0.9000 - val_auc: 0.9726
Epoch 14/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2448 - tp: 143.6703 - fp: 4119.2967 - tn: 89857.6154 - fn: 19.9890 - accuracy: 0.9562 - precision: 0.0341 - recall: 0.8944 - auc: 0.9655 - val_loss: 0.0811 - val_tp: 72.0000 - val_fp: 992.0000 - val_tn: 44497.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0677 - val_recall: 0.9000 - val_auc: 0.9724
Epoch 15/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2362 - tp: 141.0769 - fp: 4205.1429 - tn: 89776.9670 - fn: 17.3846 - accuracy: 0.9545 - precision: 0.0316 - recall: 0.8889 - auc: 0.9665 - val_loss: 0.0835 - val_tp: 72.0000 - val_fp: 1019.0000 - val_tn: 44470.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0660 - val_recall: 0.9000 - val_auc: 0.9729
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2100 - tp: 135.4945 - fp: 4278.8242 - tn: 89707.5495 - fn: 18.7033 - accuracy: 0.9542 - precision: 0.0289 - recall: 0.8980 - auc: 0.9717 - val_loss: 0.0922 - val_tp: 72.0000 - val_fp: 1117.0000 - val_tn: 44372.0000 - val_fn: 8.0000 - val_accuracy: 0.9753 - val_precision: 0.0606 - val_recall: 0.9000 - val_auc: 0.9728
Epoch 17/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2061 - tp: 147.9451 - fp: 4488.5604 - tn: 89486.7912 - fn: 17.2747 - accuracy: 0.9519 - precision: 0.0328 - recall: 0.9026 - auc: 0.9754 - val_loss: 0.0911 - val_tp: 72.0000 - val_fp: 1104.0000 - val_tn: 44385.0000 - val_fn: 8.0000 - val_accuracy: 0.9756 - val_precision: 0.0612 - val_recall: 0.9000 - val_auc: 0.9730
Epoch 18/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3032 - tp: 143.0989 - fp: 4367.9890 - tn: 89610.3956 - fn: 19.0879 - accuracy: 0.9529 - precision: 0.0312 - recall: 0.8782 - auc: 0.9486 - val_loss: 0.0878 - val_tp: 72.0000 - val_fp: 1037.0000 - val_tn: 44452.0000 - val_fn: 8.0000 - val_accuracy: 0.9771 - val_precision: 0.0649 - val_recall: 0.9000 - val_auc: 0.9761
Epoch 19/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2713 - tp: 148.5165 - fp: 4079.8462 - tn: 89894.0110 - fn: 18.1978 - accuracy: 0.9565 - precision: 0.0361 - recall: 0.8861 - auc: 0.9635 - val_loss: 0.0868 - val_tp: 72.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9762
Epoch 20/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2623 - tp: 144.5055 - fp: 3991.8681 - tn: 89984.6484 - fn: 19.5495 - accuracy: 0.9575 - precision: 0.0332 - recall: 0.8837 - auc: 0.9587 - val_loss: 0.0888 - val_tp: 72.0000 - val_fp: 1030.0000 - val_tn: 44459.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0653 - val_recall: 0.9000 - val_auc: 0.9761
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.3103 - tp: 136.6154 - fp: 3966.2527 - tn: 90015.5934 - fn: 22.1099 - accuracy: 0.9577 - precision: 0.0309 - recall: 0.8330 - auc: 0.9373 - val_loss: 0.0886 - val_tp: 72.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9764
Epoch 22/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2215 - tp: 141.6484 - fp: 3872.5055 - tn: 90108.2527 - fn: 18.1648 - accuracy: 0.9584 - precision: 0.0347 - recall: 0.8935 - auc: 0.9698 - val_loss: 0.0862 - val_tp: 72.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 8.0000 - val_accuracy: 0.9785 - val_precision: 0.0690 - val_recall: 0.9000 - val_auc: 0.9776
Epoch 23/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2214 - tp: 141.6593 - fp: 3877.3626 - tn: 90102.6593 - fn: 18.8901 - accuracy: 0.9585 - precision: 0.0354 - recall: 0.8872 - auc: 0.9685 - val_loss: 0.0842 - val_tp: 72.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9792 - val_precision: 0.0713 - val_recall: 0.9000 - val_auc: 0.9777
Epoch 24/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1873 - tp: 138.3956 - fp: 3647.7473 - tn: 90337.8681 - fn: 16.5604 - accuracy: 0.9611 - precision: 0.0340 - recall: 0.9088 - auc: 0.9743 - val_loss: 0.0843 - val_tp: 73.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 7.0000 - val_accuracy: 0.9793 - val_precision: 0.0722 - val_recall: 0.9125 - val_auc: 0.9779
Epoch 25/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2038 - tp: 140.3407 - fp: 3553.5824 - tn: 90428.6813 - fn: 17.9670 - accuracy: 0.9620 - precision: 0.0383 - recall: 0.8946 - auc: 0.9739 - val_loss: 0.0854 - val_tp: 73.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0719 - val_recall: 0.9125 - val_auc: 0.9777
Epoch 26/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2003 - tp: 151.4725 - fp: 3495.4725 - tn: 90479.3516 - fn: 14.2747 - accuracy: 0.9625 - precision: 0.0406 - recall: 0.9196 - auc: 0.9732 - val_loss: 0.0819 - val_tp: 73.0000 - val_fp: 895.0000 - val_tn: 44594.0000 - val_fn: 7.0000 - val_accuracy: 0.9802 - val_precision: 0.0754 - val_recall: 0.9125 - val_auc: 0.9778
Epoch 27/100
90/90 [==============================] - 1s 9ms/step - loss: 0.2111 - tp: 138.2857 - fp: 3438.2088 - tn: 90547.7253 - fn: 16.3516 - accuracy: 0.9635 - precision: 0.0382 - recall: 0.9021 - auc: 0.9690 - val_loss: 0.0865 - val_tp: 73.0000 - val_fp: 940.0000 - val_tn: 44549.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0721 - val_recall: 0.9125 - val_auc: 0.9775
Epoch 28/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2115 - tp: 146.6813 - fp: 3796.8022 - tn: 90178.5604 - fn: 18.5275 - accuracy: 0.9595 - precision: 0.0357 - recall: 0.8799 - auc: 0.9740 - val_loss: 0.0914 - val_tp: 73.0000 - val_fp: 996.0000 - val_tn: 44493.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0683 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 29/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2085 - tp: 142.0220 - fp: 3757.9341 - tn: 90222.3956 - fn: 18.2198 - accuracy: 0.9597 - precision: 0.0368 - recall: 0.8941 - auc: 0.9738 - val_loss: 0.0890 - val_tp: 73.0000 - val_fp: 962.0000 - val_tn: 44527.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0705 - val_recall: 0.9125 - val_auc: 0.9775
Epoch 30/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1559 - tp: 153.3077 - fp: 3571.6044 - tn: 90404.6264 - fn: 11.0330 - accuracy: 0.9617 - precision: 0.0420 - recall: 0.9426 - auc: 0.9811 - val_loss: 0.0852 - val_tp: 73.0000 - val_fp: 901.0000 - val_tn: 44588.0000 - val_fn: 7.0000 - val_accuracy: 0.9801 - val_precision: 0.0749 - val_recall: 0.9125 - val_auc: 0.9778
Epoch 31/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2635 - tp: 142.1209 - fp: 3672.6264 - tn: 90304.4725 - fn: 21.3516 - accuracy: 0.9613 - precision: 0.0376 - recall: 0.8706 - auc: 0.9641 - val_loss: 0.0930 - val_tp: 73.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 7.0000 - val_accuracy: 0.9777 - val_precision: 0.0673 - val_recall: 0.9125 - val_auc: 0.9773
Epoch 32/100
90/90 [==============================] - 1s 7ms/step - loss: 0.1858 - tp: 155.9121 - fp: 3632.0440 - tn: 90338.0769 - fn: 14.5385 - accuracy: 0.9609 - precision: 0.0410 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.0897 - val_tp: 73.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0699 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 33/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2525 - tp: 152.4945 - fp: 3743.6264 - tn: 90226.1209 - fn: 18.3297 - accuracy: 0.9599 - precision: 0.0398 - recall: 0.8872 - auc: 0.9646 - val_loss: 0.0878 - val_tp: 73.0000 - val_fp: 930.0000 - val_tn: 44559.0000 - val_fn: 7.0000 - val_accuracy: 0.9794 - val_precision: 0.0728 - val_recall: 0.9125 - val_auc: 0.9774
Epoch 34/100
90/90 [==============================] - 1s 7ms/step - loss: 0.2286 - tp: 142.4615 - fp: 3510.0659 - tn: 90469.1209 - fn: 18.9231 - accuracy: 0.9626 - precision: 0.0379 - recall: 0.8608 - auc: 0.9694 - val_loss: 0.0839 - val_tp: 73.0000 - val_fp: 889.0000 - val_tn: 44600.0000 - val_fn: 7.0000 - val_accuracy: 0.9803 - val_precision: 0.0759 - val_recall: 0.9125 - val_auc: 0.9773
Restoring model weights from the end of the best epoch.
Epoch 00034: early stopping

Check training history

plot_metrics(weighted_history)

Evaluate metrics

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.08238842338323593
tp :  89.0
fp :  1186.0
tn :  55681.0
fn :  6.0
accuracy :  0.9790737628936768
precision :  0.06980392336845398
recall :  0.9368420839309692
auc :  0.98465895652771

Legitimate Transactions Detected (True Negatives):  55681
Legitimate Transactions Incorrectly Detected (False Positives):  1186
Fraudulent Transactions Missed (False Negatives):  6
Fraudulent Transactions Detected (True Positives):  89
Total Fraudulent Transactions:  95

Here you can see that with class weights, accuracy and precision are lower because there are more false positives. Conversely, recall and AUC are higher because the model also finds more true positives. Despite its lower accuracy, this model has higher recall and identifies more fraudulent transactions. Both types of error have a cost, of course: you also don't want to annoy users by flagging too many legitimate transactions as fraudulent. Consider the trade-offs between these types of error carefully for your application.
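This quick check shows how these metrics follow from the confusion matrix. It recomputes accuracy, precision, and recall from the weighted model's test counts printed above (tp = 89, fp = 1186, tn = 55681, fn = 6). The helper name `summarize` is hypothetical and is not part of Keras.

```python
def summarize(tp, fp, tn, fn):
    """Return accuracy, precision and recall computed from confusion-matrix counts."""
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp)  # share of flagged transactions that really are fraud
    recall = tp / (tp + fn)     # share of fraudulent transactions that get flagged
    return accuracy, precision, recall

# Test-set counts reported for the class-weighted model above
accuracy, precision, recall = summarize(tp=89, fp=1186, tn=55681, fn=6)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")
# accuracy=0.9791 precision=0.0698 recall=0.9368
```

The true negatives dwarf every other cell. That is why accuracy stays near 98% even though only about 7% of flagged transactions are actually fraud, and why precision and recall are more informative here.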

Plot ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
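The `plot_roc` helper is defined earlier in the tutorial and is not shown in this section. As a rough, NumPy-only illustration of what an ROC curve contains, the sketch below lowers the decision threshold one score at a time. At each step it records the false positive rate and true positive rate, then integrates the curve with the trapezoid rule. The names `roc_points` and `auc` and the toy labels and scores are made up for illustration. In practice, `sklearn.metrics.roc_curve` and `sklearn.metrics.roc_auc_score` do this for you and also handle tied scores.

```python
import numpy as np

def roc_points(labels, scores):
    """Return (fpr, tpr) arrays, lowering the threshold one example at a time.

    Assumes no tied scores; ties would need to be grouped into a single step.
    """
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind="stable")   # highest score first
    sorted_labels = labels[order]
    # Cumulative true/false positives when the top i examples are flagged
    tps = np.cumsum(sorted_labels)
    fps = np.cumsum(~sorted_labels)
    tpr = np.concatenate([[0.0], tps / labels.sum()])
    fpr = np.concatenate([[0.0], fps / (~labels).sum()])
    return fpr, tpr

def auc(fpr, tpr):
    """Area under the curve by the trapezoid rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_points(labels, scores)
print(auc(fpr, tpr))  # 0.75
```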

Oversampling

Oversample the minority class

A related approach is to resample the dataset by oversampling the minority class.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Using NumPy

You can balance the dataset manually by choosing the right number of random indices from the positive examples:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181959, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363918, 29)
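A small, self-contained toy version of the same NumPy recipe makes it easy to check the two properties that matter. After oversampling, both classes have the same count, and every resampled positive row is a copy of an original positive row, because sampling is done with replacement (the default for `choice`). The toy arrays and the seeded `np.random.default_rng` generator are assumptions for illustration; they stand in for `train_features` and `train_labels`.

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy stand-ins for train_features/train_labels: 5 positives, 95 negatives, 3 features
features = rng.normal(size=(100, 3))
labels = np.array([1] * 5 + [0] * 95)
bool_labels = labels != 0

pos_features, neg_features = features[bool_labels], features[~bool_labels]
pos_labels, neg_labels = labels[bool_labels], labels[~bool_labels]

# Draw as many positive indices (with replacement) as there are negatives
choices = rng.choice(len(pos_features), size=len(neg_features))
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

# Shuffle features and labels with the same permutation
order = rng.permutation(len(resampled_labels))
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

print(resampled_features.shape, resampled_labels.mean())  # (190, 3) 0.5
```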

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive dataset and a negative dataset, and then merge them. See the tf.data guide for more examples.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Each dataset provides (feature, label) pairs:

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-2.19450702e+00  1.14157653e+00 -1.53868568e+00 -3.35790031e-01
 -8.45956971e-01 -1.57966490e+00 -1.68436379e+00  2.25259298e-01
  2.62655847e-01 -3.51815449e+00  2.45653756e+00 -3.62123216e+00
 -1.02812838e+00 -5.00000000e+00  1.79960408e+00 -3.72675038e+00
 -5.00000000e+00 -2.22257320e+00 -1.24126989e-02 -9.14798839e-01
  7.33686754e-01 -9.34731229e-02 -1.74683429e+00  4.44394133e-01
 -3.98093307e-02 -1.99985035e+00 -2.27011783e+00  1.83260825e-03
  5.69563522e-01]

Label:  1

Merge the two together using experimental.sample_from_datasets (newer TensorFlow releases also provide it as tf.data.Dataset.sample_from_datasets):

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.5
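Conceptually, `sample_from_datasets` draws each output element from one of the input datasets, chosen at random with the probability given by `weights`. The plain-Python generator below imitates that behavior on two toy infinite streams. It is an illustrative sketch of the idea, not TensorFlow's implementation, and the names `endless` and `sample_from` are hypothetical.

```python
import itertools
import random

def endless(items):
    """Infinite stream over items, similar to Dataset.repeat() (without shuffling)."""
    return itertools.cycle(items)

def sample_from(streams, weights, seed=0):
    """Yield elements forever, picking the source stream at random according to weights."""
    rng = random.Random(seed)
    while True:
        (stream,) = rng.choices(streams, weights=weights, k=1)
        yield next(stream)

pos_stream = endless([("pos_example", 1)])
neg_stream = endless([("neg_example", 0)])
mixed = sample_from([pos_stream, neg_stream], weights=[0.5, 0.5])

# Take one "batch" of 2048 labels and measure the positive fraction
batch = [label for _, label in itertools.islice(mixed, 2048)]
print(sum(batch) / len(batch))  # close to 0.5; the exact value depends on the seed
```

Because each element is drawn independently in this sketch, the positive fraction of any single batch is only approximately 0.5.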

To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" is less clear in this case. Say it's the number of batches required to see each negative example once:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
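Here is the arithmetic behind the 278 steps. With a 50/50 mix, only about half of each batch is negative, so seeing every negative once takes roughly 2 * neg / BATCH_SIZE batches. The sketch assumes BATCH_SIZE = 2048 and that `neg` is the negative count of the full dataset (284,807 transactions minus 492 frauds). Both values are set earlier in the tutorial and are not shown in this section; together they reproduce the 278 printed above. Counting only the training negatives (181,959, matching `res_pos_features.shape` above) would give a shorter "epoch".

```python
import math

BATCH_SIZE = 2048            # assumed: batch size defined earlier in the tutorial
neg_total = 284_807 - 492    # negatives in the whole dataset (assumed to be `neg`)
neg_train = 181_959          # negatives in the training split (shape printed above)

# About half of each balanced batch is negative, hence the factor of 2
steps_full = math.ceil(2.0 * neg_total / BATCH_SIZE)
steps_train = math.ceil(2.0 * neg_train / BATCH_SIZE)
print(steps_full, steps_train)  # 278 178
```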

Train on the oversampled data

Now try training the model on the resampled dataset instead of using class weights, and see how the two methods compare.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 8s 24ms/step - loss: 0.8118 - tp: 98159.0036 - fp: 34959.3333 - tn: 165642.9498 - fn: 44913.3728 - accuracy: 0.7582 - precision: 0.6573 - recall: 0.5935 - auc: 0.7814 - val_loss: 0.1785 - val_tp: 70.0000 - val_fp: 1177.0000 - val_tn: 44312.0000 - val_fn: 10.0000 - val_accuracy: 0.9740 - val_precision: 0.0561 - val_recall: 0.8750 - val_auc: 0.9722
Epoch 2/100
278/278 [==============================] - 6s 21ms/step - loss: 0.2142 - tp: 126477.2616 - fp: 8422.5305 - tn: 134479.1362 - fn: 17333.7312 - accuracy: 0.9075 - precision: 0.9341 - recall: 0.8774 - auc: 0.9705 - val_loss: 0.1007 - val_tp: 71.0000 - val_fp: 919.0000 - val_tn: 44570.0000 - val_fn: 9.0000 - val_accuracy: 0.9796 - val_precision: 0.0717 - val_recall: 0.8875 - val_auc: 0.9718
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1652 - tp: 128160.7455 - fp: 6029.4875 - tn: 137362.8602 - fn: 15159.5663 - accuracy: 0.9255 - precision: 0.9544 - recall: 0.8936 - auc: 0.9831 - val_loss: 0.0775 - val_tp: 71.0000 - val_fp: 802.0000 - val_tn: 44687.0000 - val_fn: 9.0000 - val_accuracy: 0.9822 - val_precision: 0.0813 - val_recall: 0.8875 - val_auc: 0.9727
Epoch 4/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1451 - tp: 129069.5376 - fp: 5223.8925 - tn: 138326.3763 - fn: 14092.8530 - accuracy: 0.9322 - precision: 0.9611 - recall: 0.9007 - auc: 0.9872 - val_loss: 0.0676 - val_tp: 71.0000 - val_fp: 756.0000 - val_tn: 44733.0000 - val_fn: 9.0000 - val_accuracy: 0.9832 - val_precision: 0.0859 - val_recall: 0.8875 - val_auc: 0.9766
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1328 - tp: 130130.0502 - fp: 5189.8996 - tn: 138364.5878 - fn: 13028.1219 - accuracy: 0.9360 - precision: 0.9611 - recall: 0.9085 - auc: 0.9895 - val_loss: 0.0616 - val_tp: 72.0000 - val_fp: 763.0000 - val_tn: 44726.0000 - val_fn: 8.0000 - val_accuracy: 0.9831 - val_precision: 0.0862 - val_recall: 0.9000 - val_auc: 0.9748
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1230 - tp: 131300.7706 - fp: 4925.0179 - tn: 138645.5125 - fn: 11841.3584 - accuracy: 0.9413 - precision: 0.9635 - recall: 0.9170 - auc: 0.9912 - val_loss: 0.0566 - val_tp: 72.0000 - val_fp: 754.0000 - val_tn: 44735.0000 - val_fn: 8.0000 - val_accuracy: 0.9833 - val_precision: 0.0872 - val_recall: 0.9000 - val_auc: 0.9759
Epoch 7/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1150 - tp: 132477.4875 - fp: 4822.2509 - tn: 138449.6989 - fn: 10963.2222 - accuracy: 0.9445 - precision: 0.9647 - recall: 0.9228 - auc: 0.9925 - val_loss: 0.0516 - val_tp: 72.0000 - val_fp: 711.0000 - val_tn: 44778.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0920 - val_recall: 0.9000 - val_auc: 0.9777
Epoch 8/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1069 - tp: 132886.1828 - fp: 4656.9570 - tn: 138848.8029 - fn: 10320.7168 - accuracy: 0.9474 - precision: 0.9660 - recall: 0.9275 - auc: 0.9936 - val_loss: 0.0462 - val_tp: 71.0000 - val_fp: 667.0000 - val_tn: 44822.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.0962 - val_recall: 0.8875 - val_auc: 0.9695
Epoch 9/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1006 - tp: 133579.5950 - fp: 4598.8602 - tn: 138524.1039 - fn: 10010.1004 - accuracy: 0.9489 - precision: 0.9668 - recall: 0.9298 - auc: 0.9944 - val_loss: 0.0419 - val_tp: 71.0000 - val_fp: 639.0000 - val_tn: 44850.0000 - val_fn: 9.0000 - val_accuracy: 0.9858 - val_precision: 0.1000 - val_recall: 0.8875 - val_auc: 0.9703
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0955 - tp: 134632.2903 - fp: 4540.9355 - tn: 138577.5520 - fn: 8961.8817 - accuracy: 0.9521 - precision: 0.9676 - recall: 0.9358 - auc: 0.9950 - val_loss: 0.0396 - val_tp: 71.0000 - val_fp: 634.0000 - val_tn: 44855.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1007 - val_recall: 0.8875 - val_auc: 0.9664
Epoch 11/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0908 - tp: 140213.1900 - fp: 5993.7849 - tn: 136973.6380 - fn: 3532.0466 - accuracy: 0.9663 - precision: 0.9588 - recall: 0.9747 - auc: 0.9955 - val_loss: 0.0359 - val_tp: 71.0000 - val_fp: 596.0000 - val_tn: 44893.0000 - val_fn: 9.0000 - val_accuracy: 0.9867 - val_precision: 0.1064 - val_recall: 0.8875 - val_auc: 0.9669
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0857 - tp: 140261.6918 - fp: 5979.1470 - tn: 137657.2294 - fn: 2814.5914 - accuracy: 0.9691 - precision: 0.9589 - recall: 0.9800 - auc: 0.9959 - val_loss: 0.0337 - val_tp: 71.0000 - val_fp: 580.0000 - val_tn: 44909.0000 - val_fn: 9.0000 - val_accuracy: 0.9871 - val_precision: 0.1091 - val_recall: 0.8875 - val_auc: 0.9677
Epoch 13/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0818 - tp: 140796.7133 - fp: 5945.5305 - tn: 137458.5520 - fn: 2511.8638 - accuracy: 0.9704 - precision: 0.9596 - recall: 0.9822 - auc: 0.9962 - val_loss: 0.0318 - val_tp: 71.0000 - val_fp: 558.0000 - val_tn: 44931.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1129 - val_recall: 0.8875 - val_auc: 0.9682
Epoch 14/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0793 - tp: 140997.9176 - fp: 6076.8746 - tn: 137562.6846 - fn: 2075.1828 - accuracy: 0.9714 - precision: 0.9586 - recall: 0.9853 - auc: 0.9964 - val_loss: 0.0303 - val_tp: 71.0000 - val_fp: 555.0000 - val_tn: 44934.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1134 - val_recall: 0.8875 - val_auc: 0.9687
Epoch 15/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0759 - tp: 141966.7312 - fp: 6100.1147 - tn: 136957.0108 - fn: 1688.8029 - accuracy: 0.9729 - precision: 0.9589 - recall: 0.9883 - auc: 0.9966 - val_loss: 0.0292 - val_tp: 71.0000 - val_fp: 541.0000 - val_tn: 44948.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.1160 - val_recall: 0.8875 - val_auc: 0.9692
Epoch 16/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0727 - tp: 141879.4731 - fp: 6020.2007 - tn: 137272.9140 - fn: 1540.0717 - accuracy: 0.9736 - precision: 0.9594 - recall: 0.9891 - auc: 0.9968 - val_loss: 0.0276 - val_tp: 71.0000 - val_fp: 504.0000 - val_tn: 44985.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1235 - val_recall: 0.8875 - val_auc: 0.9697
Epoch 17/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0704 - tp: 142140.7563 - fp: 6019.2258 - tn: 137225.3978 - fn: 1327.2796 - accuracy: 0.9745 - precision: 0.9597 - recall: 0.9907 - auc: 0.9969 - val_loss: 0.0269 - val_tp: 71.0000 - val_fp: 495.0000 - val_tn: 44994.0000 - val_fn: 9.0000 - val_accuracy: 0.9889 - val_precision: 0.1254 - val_recall: 0.8875 - val_auc: 0.9699
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping

If the training process considered the whole dataset on each gradient update, this oversampling would be essentially identical to class weighting.

But when you train the model batch by batch, as you did here, the oversampled data provides a smoother gradient signal. Instead of each positive example appearing in one batch with a large weight, it appears in many different batches, each time with a small weight.

This smoother gradient signal makes the model easier to train.
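To see why the two methods coincide when the whole dataset is used, compare the losses directly. A sum of per-example losses in which each positive is weighted by an integer k equals the plain sum over a dataset in which each positive is repeated k times. Over the full dataset, the two objectives (and therefore their gradients) therefore agree up to a constant normalization factor. The toy labels, the random predicted probabilities, the factor k = 10, and the helper `bce` are all assumptions for illustration.

```python
import numpy as np

def bce(y, p):
    """Per-example binary cross-entropy."""
    p = np.clip(p, 1e-7, 1 - 1e-7)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

rng = np.random.default_rng(0)
# Toy imbalanced data: 3 positives, 30 negatives, arbitrary predicted probabilities
y = np.array([1] * 3 + [0] * 30, dtype=float)
p = rng.uniform(0.05, 0.95, size=y.shape)

k = 10  # class weight for the positives, or equivalently how often each is repeated

# (a) Class weighting: each positive's loss is multiplied by k
weights = np.where(y == 1, k, 1.0)
weighted_sum = np.sum(weights * bce(y, p))

# (b) Oversampling: each positive example is repeated k times
rep = np.where(y == 1, k, 1).astype(int)
y_over = np.repeat(y, rep)
p_over = np.repeat(p, rep)
oversampled_sum = np.sum(bce(y_over, p_over))

print(np.isclose(weighted_sum, oversampled_sum))  # True
```

The difference appears only with mini-batches. Under weighting, a batch containing a positive receives one large contribution from it. Under oversampling, the copies of that positive are spread across many batches.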

Check training history

Note that the metric curves will look different here, because the training data now has a very different class distribution from the validation and test data.
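One concrete reason the curves differ is that precision depends on how common positives are, even when the model's true positive rate and false positive rate stay fixed. The sketch below takes the rates from the last epoch of the first resampled run above (val_tp = 71, val_fn = 9, val_fp = 495, val_tn = 44994). It then computes the precision those same rates would give at the validation prevalence and at the 50% prevalence of the balanced training stream. The helper `precision_at` is hypothetical.

```python
def precision_at(tpr, fpr, prevalence):
    """Precision implied by fixed TPR/FPR when a fraction `prevalence` of examples is positive."""
    true_pos = tpr * prevalence
    false_pos = fpr * (1.0 - prevalence)
    return true_pos / (true_pos + false_pos)

# Validation counts from the last epoch of the first resampled run above
tp, fn, fp, tn = 71, 9, 495, 44994
tpr = tp / (tp + fn)          # recall: 0.8875
fpr = fp / (fp + tn)
val_prevalence = (tp + fn) / (tp + fn + fp + tn)

print(round(precision_at(tpr, fpr, val_prevalence), 4))  # 0.1254, the reported val_precision
print(round(precision_at(tpr, fpr, 0.5), 4))             # 0.9879 on a 50/50 stream
```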

plot_metrics(resampled_history)

Retrain

Because training is easier on the balanced data, the training procedure above may overfit quickly.

So break the epochs up into shorter ones. This gives the callbacks.EarlyStopping callback finer control over when to stop training.
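Shorter epochs help because EarlyStopping only checks the monitored value at the end of each epoch. The sketch below reimplements the patience rule in plain Python and replays the val_auc values logged in the first resampled run above. It assumes the `early_stopping` callback defined earlier in the tutorial monitors val_auc in max mode with patience=10; those settings are not shown in this section. Under that assumption, the sketch reproduces the stop at epoch 17. The function name `early_stop_epoch` is hypothetical.

```python
def early_stop_epoch(values, patience):
    """Return (stop_epoch, best_epoch), 1-based, for a metric where higher is better."""
    best, best_epoch, wait = float("-inf"), 0, 0
    for epoch, value in enumerate(values, start=1):
        if value > best:
            best, best_epoch, wait = value, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(values), best_epoch  # patience never ran out

# val_auc per epoch from the first resampled run (278 steps per epoch)
val_auc = [0.9722, 0.9718, 0.9727, 0.9766, 0.9748, 0.9759, 0.9777, 0.9695, 0.9703,
           0.9664, 0.9669, 0.9677, 0.9682, 0.9687, 0.9692, 0.9697, 0.9699]
stop, best = early_stop_epoch(val_auc, patience=10)
print(stop, best)  # 17 7
```

At 278 steps per epoch, a patience of 10 epochs allows about 2,780 batches without improvement before training stops. With steps_per_epoch=20, as in the retraining code, the same patience covers only 200 batches.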

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 59ms/step - loss: 1.6583 - tp: 2292.0000 - fp: 4139.7619 - tn: 52605.9524 - fn: 8961.7619 - accuracy: 0.8197 - precision: 0.3317 - recall: 0.1958 - auc: 0.7857 - val_loss: 0.5873 - val_tp: 9.0000 - val_fp: 13094.0000 - val_tn: 32395.0000 - val_fn: 71.0000 - val_accuracy: 0.7111 - val_precision: 6.8687e-04 - val_recall: 0.1125 - val_auc: 0.2616
Epoch 2/1000
20/20 [==============================] - 1s 27ms/step - loss: 1.0305 - tp: 4655.6667 - fp: 3775.2857 - tn: 7528.4286 - fn: 6471.0952 - accuracy: 0.5318 - precision: 0.5347 - recall: 0.3960 - auc: 0.4700 - val_loss: 0.5752 - val_tp: 45.0000 - val_fp: 12493.0000 - val_tn: 32996.0000 - val_fn: 35.0000 - val_accuracy: 0.7251 - val_precision: 0.0036 - val_recall: 0.5625 - val_auc: 0.7427
Epoch 3/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.6911 - tp: 7407.9048 - fp: 3609.9048 - tn: 7705.3333 - fn: 3707.3333 - accuracy: 0.6637 - precision: 0.6637 - recall: 0.6487 - auc: 0.6971 - val_loss: 0.5367 - val_tp: 73.0000 - val_fp: 10392.0000 - val_tn: 35097.0000 - val_fn: 7.0000 - val_accuracy: 0.7718 - val_precision: 0.0070 - val_recall: 0.9125 - val_auc: 0.9205
Epoch 4/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.5257 - tp: 8758.5238 - fp: 3223.3333 - tn: 8031.5714 - fn: 2417.0476 - accuracy: 0.7432 - precision: 0.7273 - recall: 0.7793 - auc: 0.8200 - val_loss: 0.4801 - val_tp: 73.0000 - val_fp: 7752.0000 - val_tn: 37737.0000 - val_fn: 7.0000 - val_accuracy: 0.8297 - val_precision: 0.0093 - val_recall: 0.9125 - val_auc: 0.9436
Epoch 5/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.4499 - tp: 9241.0952 - fp: 2761.1905 - tn: 8463.3333 - fn: 1964.8571 - accuracy: 0.7862 - precision: 0.7667 - recall: 0.8219 - auc: 0.8704 - val_loss: 0.4242 - val_tp: 72.0000 - val_fp: 5662.0000 - val_tn: 39827.0000 - val_fn: 8.0000 - val_accuracy: 0.8756 - val_precision: 0.0126 - val_recall: 0.9000 - val_auc: 0.9524
Epoch 6/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4029 - tp: 9361.0000 - fp: 2311.8571 - tn: 8961.6190 - fn: 1796.0000 - accuracy: 0.8140 - precision: 0.7965 - recall: 0.8389 - auc: 0.8968 - val_loss: 0.3750 - val_tp: 71.0000 - val_fp: 4105.0000 - val_tn: 41384.0000 - val_fn: 9.0000 - val_accuracy: 0.9097 - val_precision: 0.0170 - val_recall: 0.8875 - val_auc: 0.9583
Epoch 7/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3601 - tp: 9456.8571 - fp: 1812.6667 - tn: 9376.6190 - fn: 1784.3333 - accuracy: 0.8386 - precision: 0.8376 - recall: 0.8413 - auc: 0.9181 - val_loss: 0.3324 - val_tp: 70.0000 - val_fp: 3047.0000 - val_tn: 42442.0000 - val_fn: 10.0000 - val_accuracy: 0.9329 - val_precision: 0.0225 - val_recall: 0.8750 - val_auc: 0.9626
Epoch 8/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3373 - tp: 9506.3810 - fp: 1623.5714 - tn: 9568.3333 - fn: 1732.1905 - accuracy: 0.8490 - precision: 0.8520 - recall: 0.8453 - auc: 0.9266 - val_loss: 0.2959 - val_tp: 69.0000 - val_fp: 2358.0000 - val_tn: 43131.0000 - val_fn: 11.0000 - val_accuracy: 0.9480 - val_precision: 0.0284 - val_recall: 0.8625 - val_auc: 0.9660
Epoch 9/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3120 - tp: 9635.0000 - fp: 1347.7619 - tn: 9797.4286 - fn: 1650.2857 - accuracy: 0.8662 - precision: 0.8776 - recall: 0.8534 - auc: 0.9366 - val_loss: 0.2670 - val_tp: 69.0000 - val_fp: 1943.0000 - val_tn: 43546.0000 - val_fn: 11.0000 - val_accuracy: 0.9571 - val_precision: 0.0343 - val_recall: 0.8625 - val_auc: 0.9689
Epoch 10/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2923 - tp: 9632.5714 - fp: 1221.3333 - tn: 9978.5714 - fn: 1598.0000 - accuracy: 0.8725 - precision: 0.8864 - recall: 0.8559 - auc: 0.9446 - val_loss: 0.2417 - val_tp: 69.0000 - val_fp: 1609.0000 - val_tn: 43880.0000 - val_fn: 11.0000 - val_accuracy: 0.9644 - val_precision: 0.0411 - val_recall: 0.8625 - val_auc: 0.9706
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2765 - tp: 9632.7619 - fp: 1069.8095 - tn: 10155.4762 - fn: 1572.4286 - accuracy: 0.8806 - precision: 0.8976 - recall: 0.8584 - auc: 0.9505 - val_loss: 0.2211 - val_tp: 69.0000 - val_fp: 1424.0000 - val_tn: 44065.0000 - val_fn: 11.0000 - val_accuracy: 0.9685 - val_precision: 0.0462 - val_recall: 0.8625 - val_auc: 0.9716
Epoch 12/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2624 - tp: 9721.6190 - fp: 949.2857 - tn: 10243.6190 - fn: 1515.9524 - accuracy: 0.8901 - precision: 0.9102 - recall: 0.8658 - auc: 0.9556 - val_loss: 0.2038 - val_tp: 69.0000 - val_fp: 1299.0000 - val_tn: 44190.0000 - val_fn: 11.0000 - val_accuracy: 0.9713 - val_precision: 0.0504 - val_recall: 0.8625 - val_auc: 0.9722
Epoch 13/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2491 - tp: 9862.2381 - fp: 851.1905 - tn: 10195.6667 - fn: 1521.3810 - accuracy: 0.8933 - precision: 0.9207 - recall: 0.8656 - auc: 0.9597 - val_loss: 0.1904 - val_tp: 70.0000 - val_fp: 1230.0000 - val_tn: 44259.0000 - val_fn: 10.0000 - val_accuracy: 0.9728 - val_precision: 0.0538 - val_recall: 0.8750 - val_auc: 0.9726
Epoch 14/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2384 - tp: 9752.6190 - fp: 806.5714 - tn: 10438.3333 - fn: 1432.9524 - accuracy: 0.9006 - precision: 0.9238 - recall: 0.8724 - auc: 0.9637 - val_loss: 0.1781 - val_tp: 70.0000 - val_fp: 1186.0000 - val_tn: 44303.0000 - val_fn: 10.0000 - val_accuracy: 0.9738 - val_precision: 0.0557 - val_recall: 0.8750 - val_auc: 0.9724
Epoch 15/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2332 - tp: 9727.1905 - fp: 787.1905 - tn: 10508.9524 - fn: 1407.1429 - accuracy: 0.9024 - precision: 0.9255 - recall: 0.8737 - auc: 0.9651 - val_loss: 0.1664 - val_tp: 70.0000 - val_fp: 1130.0000 - val_tn: 44359.0000 - val_fn: 10.0000 - val_accuracy: 0.9750 - val_precision: 0.0583 - val_recall: 0.8750 - val_auc: 0.9728
Epoch 16/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2242 - tp: 9812.1905 - fp: 731.9524 - tn: 10481.3333 - fn: 1405.0000 - accuracy: 0.9048 - precision: 0.9310 - recall: 0.8745 - auc: 0.9677 - val_loss: 0.1561 - val_tp: 70.0000 - val_fp: 1085.0000 - val_tn: 44404.0000 - val_fn: 10.0000 - val_accuracy: 0.9760 - val_precision: 0.0606 - val_recall: 0.8750 - val_auc: 0.9725
Epoch 17/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2197 - tp: 9864.8571 - fp: 691.0000 - tn: 10487.2381 - fn: 1387.3810 - accuracy: 0.9072 - precision: 0.9350 - recall: 0.8766 - auc: 0.9690 - val_loss: 0.1475 - val_tp: 70.0000 - val_fp: 1042.0000 - val_tn: 44447.0000 - val_fn: 10.0000 - val_accuracy: 0.9769 - val_precision: 0.0629 - val_recall: 0.8750 - val_auc: 0.9724
Epoch 18/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2110 - tp: 9788.9524 - fp: 651.8571 - tn: 10585.9524 - fn: 1403.7143 - accuracy: 0.9087 - precision: 0.9392 - recall: 0.8742 - auc: 0.9716 - val_loss: 0.1401 - val_tp: 70.0000 - val_fp: 1020.0000 - val_tn: 44469.0000 - val_fn: 10.0000 - val_accuracy: 0.9774 - val_precision: 0.0642 - val_recall: 0.8750 - val_auc: 0.9718
Epoch 19/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2035 - tp: 9779.1905 - fp: 619.4286 - tn: 10699.0952 - fn: 1332.7619 - accuracy: 0.9117 - precision: 0.9402 - recall: 0.8784 - auc: 0.9733 - val_loss: 0.1340 - val_tp: 70.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 10.0000 - val_accuracy: 0.9776 - val_precision: 0.0648 - val_recall: 0.8750 - val_auc: 0.9721
Epoch 20/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1993 - tp: 9782.3333 - fp: 592.7619 - tn: 10738.1905 - fn: 1317.1905 - accuracy: 0.9156 - precision: 0.9430 - recall: 0.8823 - auc: 0.9748 - val_loss: 0.1282 - val_tp: 70.0000 - val_fp: 999.0000 - val_tn: 44490.0000 - val_fn: 10.0000 - val_accuracy: 0.9779 - val_precision: 0.0655 - val_recall: 0.8750 - val_auc: 0.9721
Epoch 21/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1995 - tp: 9862.0476 - fp: 625.4762 - tn: 10605.0952 - fn: 1337.8571 - accuracy: 0.9119 - precision: 0.9394 - recall: 0.8805 - auc: 0.9747 - val_loss: 0.1228 - val_tp: 70.0000 - val_fp: 965.0000 - val_tn: 44524.0000 - val_fn: 10.0000 - val_accuracy: 0.9786 - val_precision: 0.0676 - val_recall: 0.8750 - val_auc: 0.9718
Epoch 22/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1936 - tp: 9804.0476 - fp: 582.0952 - tn: 10755.0000 - fn: 1289.3333 - accuracy: 0.9168 - precision: 0.9433 - recall: 0.8842 - auc: 0.9764 - val_loss: 0.1179 - val_tp: 70.0000 - val_fp: 944.0000 - val_tn: 44545.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0690 - val_recall: 0.8750 - val_auc: 0.9713
Epoch 23/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1889 - tp: 9952.6667 - fp: 552.3333 - tn: 10589.3810 - fn: 1336.0952 - accuracy: 0.9155 - precision: 0.9472 - recall: 0.8812 - auc: 0.9774 - val_loss: 0.1146 - val_tp: 70.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 10.0000 - val_accuracy: 0.9791 - val_precision: 0.0692 - val_recall: 0.8750 - val_auc: 0.9715
Epoch 24/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1828 - tp: 9880.4286 - fp: 540.2381 - tn: 10753.0000 - fn: 1256.8095 - accuracy: 0.9207 - precision: 0.9483 - recall: 0.8883 - auc: 0.9787 - val_loss: 0.1116 - val_tp: 71.0000 - val_fp: 961.0000 - val_tn: 44528.0000 - val_fn: 9.0000 - val_accuracy: 0.9787 - val_precision: 0.0688 - val_recall: 0.8875 - val_auc: 0.9712
Epoch 25/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1829 - tp: 9871.3333 - fp: 555.4286 - tn: 10747.9048 - fn: 1255.8095 - accuracy: 0.9200 - precision: 0.9470 - recall: 0.8879 - auc: 0.9791 - val_loss: 0.1082 - val_tp: 71.0000 - val_fp: 953.0000 - val_tn: 44536.0000 - val_fn: 9.0000 - val_accuracy: 0.9789 - val_precision: 0.0693 - val_recall: 0.8875 - val_auc: 0.9713
Restoring model weights from the end of the best epoch.
Epoch 00025: early stopping
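Although 1000 epochs were requested, training stopped at epoch 25. The early-stopping callback configured earlier in the tutorial saw no improvement in its monitored validation metric for a set number of epochs (its patience), and then restored the weights from the best epoch. The sketch below is a simplified pure-Python illustration of that patience logic, not Keras's actual implementation (`tf.keras.callbacks.EarlyStopping` has further options such as `min_delta`). The helper name `early_stopping_epoch` and the scores are hypothetical.

```python
def early_stopping_epoch(val_scores, patience):
    """Simplified patience logic for a metric where higher is better (hypothetical helper).

    Returns (stop_epoch, best_epoch), both 1-based. best_epoch is the epoch whose
    weights a restore-best-weights policy would keep.
    """
    best_score = float("-inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # patience exhausted: stop here
    return len(val_scores), best_epoch  # ran out of epochs without stopping early


# Hypothetical validation scores: the best is epoch 3, followed by three epochs without improvement.
print(early_stopping_epoch([0.90, 0.95, 0.97, 0.96, 0.965, 0.968], patience=3))  # → (6, 3)
```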

Re-check the training history

plot_metrics(resampled_history)


Evaluate metrics

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.16524264216423035
tp :  91.0
fp :  1376.0
tn :  55491.0
fn :  4.0
accuracy :  0.9757733345031738
precision :  0.06203135475516319
recall :  0.9578947424888611
auc :  0.9829339385032654

Legitimate Transactions Detected (True Negatives):  55491
Legitimate Transactions Incorrectly Detected (False Positives):  1376
Fraudulent Transactions Missed (False Negatives):  4
Fraudulent Transactions Detected (True Positives):  91
Total Fraudulent Transactions:  95

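As a check, the precision, recall, and accuracy reported above can be recomputed directly from the four confusion-matrix counts. Here is a minimal sketch using the counts printed for the resampled model. Small differences in the last digits are expected because the Keras metrics are stored as float32.

```python
# Confusion-matrix counts copied from the resampled model's test-set output above.
tp, fp, tn, fn = 91, 1376, 55491, 4

precision = tp / (tp + fp)                  # share of flagged transactions that are fraudulent
recall = tp / (tp + fn)                     # share of fraudulent transactions that were flagged
accuracy = (tp + tn) / (tp + fp + tn + fn)  # share of all transactions classified correctly

print(f"precision={precision:.4f} recall={recall:.4f} accuracy={accuracy:.4f}")
# → precision=0.0620 recall=0.9579 accuracy=0.9758
```

The low precision follows from the class imbalance. There are 56,867 legitimate test transactions, so even a small false positive rate produces 1,376 false positives, which outnumber the 91 true positives roughly 15 to 1. Recall stays high because only 4 of the 95 fraudulent transactions are missed.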

Plot ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
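The `plot_roc` helper is defined earlier in the tutorial and is not shown in this section. An ROC curve plots the true positive rate against the false positive rate as the decision threshold varies. The following minimal sketch shows those underlying quantities. It assumes scikit-learn's `roc_curve` and `roc_auc_score` and uses a tiny hypothetical set of labels and scores rather than the tutorial's data.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

# Hypothetical labels (1 = fraud) and model scores, for illustration only.
labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

# Each threshold yields one (false positive rate, true positive rate) point on the curve.
fpr, tpr, thresholds = roc_curve(labels, scores)
auc = roc_auc_score(labels, scores)  # area under that curve

for f, t, th in zip(fpr, tpr, thresholds):
    print(f"threshold={th:.2f}  FPR={f:.2f}  TPR={t:.2f}")
print("AUC:", auc)
```

In the `plot_roc` calls above, solid lines show training predictions and dashed lines show test predictions for each model. A large gap between a model's two curves can hint at overfitting.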

Applying this tutorial to your problem

Classifying imbalanced data is inherently difficult because the minority class provides very few examples to learn from. Always start with the data. Collect as many samples as you can, and think carefully about which features might be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and deliver the results you want. For this reason, keep the context of your problem in mind, along with the trade-offs between different types of errors. In this example, those are missed fraudulent transactions (false negatives) and legitimate transactions incorrectly flagged as fraud (false positives).
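One concrete lever for that trade-off is the decision threshold applied to the model's predicted probabilities. The following minimal sketch uses a hypothetical helper, `confusion_at_threshold`, with made-up labels and probabilities that are not outputs from the models above. It shows that lowering the threshold trades fewer missed fraud cases (false negatives) for more false alarms (false positives).

```python
import numpy as np


def confusion_at_threshold(labels, probs, threshold):
    """Hypothetical helper: count (tp, fp, tn, fn) when probs >= threshold are predicted as fraud."""
    preds = probs >= threshold
    labels = labels.astype(bool)
    tp = int(np.sum(preds & labels))
    fp = int(np.sum(preds & ~labels))
    tn = int(np.sum(~preds & ~labels))
    fn = int(np.sum(~preds & labels))
    return tp, fp, tn, fn


# Made-up labels (1 = fraud) and predicted probabilities, for illustration only.
labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])
probs = np.array([0.05, 0.10, 0.20, 0.30, 0.45, 0.60, 0.40, 0.90])

for threshold in (0.25, 0.5, 0.75):
    tp, fp, tn, fn = confusion_at_threshold(labels, probs, threshold)
    print(f"threshold={threshold}: tp={tp} fp={fp} tn={tn} fn={fn}")
```

Which threshold is appropriate depends on the relative cost of each type of error in your application.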