This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class vastly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions from 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.
This tutorial contains complete code to:
- Load a CSV file using Pandas.
- Create train, validation, and test sets.
- Define and train a model using Keras (including setting class weights).
- Evaluate the model using various metrics (including precision and recall).
- Try common techniques for dealing with imbalanced data, such as:
- Class weighting
- Oversampling
Setup
import tensorflow as tf
from tensorflow import keras
import os
import tempfile
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Data processing and exploration
Download the Kaggle Credit Card Fraud dataset
Pandas is a Python library with many helpful utilities for loading and working with structured data, and can be used to download a CSV into a DataFrame.
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
Examine the class label imbalance
Let's look at the dataset imbalance:
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples: Total: 284807 Positive: 492 (0.17% of total)
This shows the small fraction of positive samples: only 0.17% of all transactions are fraudulent.
Clean, split, and normalize the data
The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.
cleaned_df = raw_df.copy()
# You don't want the `Time` column.
cleaned_df.pop('Time')
# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount') + eps)
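As a quick sanity check (a small added sketch, not part of the original notebook), you can compare the spread of the raw and log-transformed amounts:
# Compare the spread of the raw amounts and the log-transformed column.
print('Amount range:     {:.2f} to {:.2f}'.format(
    raw_df['Amount'].min(), raw_df['Amount'].max()))
print('Log Amount range: {:.2f} to {:.2f}'.format(
    cleaned_df['Log Amount'].min(), cleaned_df['Log Amount'].max()))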
Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics; however, the model is never fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.
# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)
# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))
train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and the standard deviation to 1.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
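If you are curious how aggressive the clipping is, a short added check shows the fraction of scaled training values sitting at the clip boundary (i.e., values that were clipped to +/- 5):
# Values equal to +/- 5 after np.clip are (almost entirely) values that were clipped.
clipped_fraction = np.mean(np.abs(train_features) >= 5)
print('Fraction of clipped feature values: {:.5f}'.format(clipped_fraction))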
print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)
print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,) Validation labels shape: (45569,) Test labels shape: (56962,) Training features shape: (182276, 29) Validation features shape: (45569, 29) Test features shape: (56962, 29)
Look at the data distribution
Next compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:
- Do these distributions make sense?
  - Yes. You've normalized the input, and these are mostly concentrated in the +/- 2 range.
- Can you see the difference between the distributions?
  - Yes, the positive examples contain a much higher rate of extreme values.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)
sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
Define the model and metrics
Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:
METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'),
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)

  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
Understanding useful metrics
Notice that there are a few metrics defined above that can be computed by the model and will be useful when evaluating performance.
- False negatives and false positives are examples that were incorrectly classified.
- True negatives and true positives are examples that were correctly classified.
- Accuracy is the percentage of examples correctly classified: $\frac{\text{true samples}}{\text{total samples}}$
- Precision is the percentage of predicted positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false positives}}$
- Recall is the percentage of actual positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false negatives}}$
- AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
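To make these definitions concrete, here is a small illustrative sketch (not part of the original tutorial) that computes accuracy, precision, and recall directly from the four base counts; the counts are hypothetical:
# Hypothetical confusion-matrix counts, for illustration only.
tp_c, fp_c, tn_c, fn_c = 56, 12, 56855, 39

accuracy = (tp_c + tn_c) / (tp_c + fp_c + tn_c + fn_c)
precision = tp_c / (tp_c + fp_c)
recall = tp_c / (tp_c + fn_c)

print('accuracy:  {:.4f}'.format(accuracy))
print('precision: {:.4f}'.format(precision))
print('recall:    {:.4f}'.format(recall))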
Baseline model
Build the model
Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, batches would likely have no fraudulent transactions to learn from, as the sketch below illustrates.
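Here is a rough back-of-the-envelope sketch (added for illustration) of how likely a batch is to contain no positive examples at various batch sizes, assuming examples are drawn independently at the overall positive rate:
# Probability that a randomly drawn batch contains zero positive examples.
p_pos = pos / total
for batch_size in [32, 256, 2048]:
  p_empty = (1 - p_pos) ** batch_size
  print('batch_size={:5d}: P(no positives) = {:.3f}'.format(batch_size, p_empty))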
EPOCHS = 100
BATCH_SIZE = 2048
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 480 _________________________________________________________________ dropout (Dropout) (None, 16) 0 _________________________________________________________________ dense_1 (Dense) (None, 1) 17 ================================================================= Total params: 497 Trainable params: 497 Non-trainable params: 0 _________________________________________________________________
Test run the model:
model.predict(train_features[:10])
array([[0.6983068 ], [0.7546284 ], [0.73785573], [0.7908986 ], [0.51232255], [0.752192 ], [0.7387281 ], [0.9410955 ], [0.809352 ], [0.6911539 ]], dtype=float32)
Optional: Set the correct initial bias.
These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.
With the default bias initialization the loss should be about math.log(2) = 0.69314:
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.5998
The correct bias to set can be derived from solving $p_0 = pos/(pos+neg) = 1/(1+e^{-b_0})$ for the bias, which gives $b_0 = \log_e(pos/neg)$:
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
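As a quick added check, passing this bias through a sigmoid recovers the positive rate:
# sigmoid(initial_bias) should equal pos / (pos + neg).
print('sigmoid(initial_bias):', 1 / (1 + np.exp(-initial_bias)))
print('pos/total:            ', pos / total)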
Set that as the initial bias, and the model will give much more reasonable initial guesses.
It should be near: pos/total = 0.0018
model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00168876], [0.00081124], [0.00087036], [0.00241473], [0.00133016], [0.00121771], [0.00079989], [0.00079692], [0.00257652], [0.00104385]], dtype=float32)
With this initialization the initial loss should be approximately $-p_0\log(p_0)-(1-p_0)\log(1-p_0)$, far smaller than $\log(2)$:
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0174
This initial loss is about 50 times less than it would have been with naive initialization.
This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes it easier to read plots of the loss during training.
Checkpoint the initial weights
To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training.
initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)
Confirm that the bias fix helps
Before moving on, confirm quickly that the careful bias initialization actually helped.
Train the model for 20 epochs, with and without this careful initialization, and compare the losses:
model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels),
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels),
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
The above figure makes it clear: in terms of validation loss, on this problem, this careful initialization gives a clear advantage.
Train the model
model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100 90/90 [==============================] - 3s 15ms/step - loss: 0.0186 - tp: 64.0000 - fp: 25.1978 - tn: 139431.9780 - fn: 188.3956 - accuracy: 0.9985 - precision: 0.7214 - recall: 0.3037 - auc: 0.6752 - val_loss: 0.0109 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45489.0000 - val_fn: 80.0000 - val_accuracy: 0.9982 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.6977 Epoch 2/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0085 - tp: 25.4505 - fp: 5.6703 - tn: 93973.1538 - fn: 136.2967 - accuracy: 0.9985 - precision: 0.6137 - recall: 0.1108 - auc: 0.8194 - val_loss: 0.0051 - val_tp: 27.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 53.0000 - val_accuracy: 0.9986 - val_precision: 0.7500 - val_recall: 0.3375 - val_auc: 0.9184 Epoch 3/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0057 - tp: 68.4505 - fp: 9.9231 - tn: 93970.2198 - fn: 91.9780 - accuracy: 0.9989 - precision: 0.8896 - recall: 0.4278 - auc: 0.9133 - val_loss: 0.0043 - val_tp: 44.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 36.0000 - val_accuracy: 0.9990 - val_precision: 0.8148 - val_recall: 0.5500 - val_auc: 0.9185 Epoch 4/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0051 - tp: 86.9451 - fp: 17.4505 - tn: 93959.1429 - fn: 77.0330 - accuracy: 0.9990 - precision: 0.8422 - recall: 0.5397 - auc: 0.9191 - val_loss: 0.0040 - val_tp: 51.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 29.0000 - val_accuracy: 0.9991 - val_precision: 0.8361 - val_recall: 0.6375 - val_auc: 0.9248 Epoch 5/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 83.8571 - fp: 13.0110 - tn: 93965.0879 - fn: 78.6154 - accuracy: 0.9990 - precision: 0.8890 - recall: 0.5114 - auc: 0.9212 - val_loss: 0.0039 - val_tp: 55.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 25.0000 - val_accuracy: 0.9992 - val_precision: 0.8462 - val_recall: 0.6875 - val_auc: 0.9247 Epoch 6/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0049 - tp: 83.9890 - fp: 11.9451 - tn: 93976.1648 - fn: 68.4725 - accuracy: 0.9992 - precision: 0.8793 - recall: 0.5693 - auc: 0.9022 - val_loss: 0.0038 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 7/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0053 - tp: 89.0549 - fp: 19.6484 - tn: 93958.4835 - fn: 73.3846 - accuracy: 0.9991 - precision: 0.8187 - recall: 0.5711 - auc: 0.8824 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 27.0000 - val_accuracy: 0.9992 - val_precision: 0.8413 - val_recall: 0.6625 - val_auc: 0.9309 Epoch 8/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 92.4396 - fp: 11.4835 - tn: 93970.6374 - fn: 66.0110 - accuracy: 0.9992 - precision: 0.9188 - recall: 0.5617 - auc: 0.9298 - val_loss: 0.0036 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9309 Epoch 9/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 100.4396 - fp: 14.7802 - tn: 93963.6044 - fn: 61.7473 - accuracy: 0.9992 - precision: 0.8725 - recall: 0.6228 - auc: 0.9167 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8406 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 10/100 90/90 
[==============================] - 1s 7ms/step - loss: 0.0046 - tp: 94.8132 - fp: 12.7253 - tn: 93958.3187 - fn: 74.7143 - accuracy: 0.9991 - precision: 0.8935 - recall: 0.5710 - auc: 0.9196 - val_loss: 0.0035 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247 Epoch 11/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0047 - tp: 97.0440 - fp: 16.3956 - tn: 93958.9231 - fn: 68.2088 - accuracy: 0.9991 - precision: 0.8591 - recall: 0.5946 - auc: 0.9160 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8529 - val_recall: 0.7250 - val_auc: 0.9247 Epoch 12/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0040 - tp: 101.0769 - fp: 13.1868 - tn: 93958.7143 - fn: 67.5934 - accuracy: 0.9992 - precision: 0.8799 - recall: 0.6127 - auc: 0.9202 - val_loss: 0.0034 - val_tp: 61.0000 - val_fp: 11.0000 - val_tn: 45478.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8472 - val_recall: 0.7625 - val_auc: 0.9247 Epoch 13/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0043 - tp: 98.0769 - fp: 16.0330 - tn: 93961.6703 - fn: 64.7912 - accuracy: 0.9992 - precision: 0.8536 - recall: 0.6112 - auc: 0.9154 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 9.0000 - val_tn: 45480.0000 - val_fn: 21.0000 - val_accuracy: 0.9993 - val_precision: 0.8676 - val_recall: 0.7375 - val_auc: 0.9247 Epoch 14/100 90/90 [==============================] - 1s 9ms/step - loss: 0.0050 - tp: 93.5495 - fp: 15.4286 - tn: 93961.4615 - fn: 70.1319 - accuracy: 0.9991 - precision: 0.8590 - recall: 0.5563 - auc: 0.8916 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Epoch 15/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0042 - tp: 90.2198 - fp: 15.4286 - tn: 93968.0989 - fn: 66.8242 - accuracy: 0.9992 - precision: 0.8524 - recall: 0.5813 - auc: 0.9270 - val_loss: 0.0033 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Epoch 16/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 96.4835 - fp: 14.7363 - tn: 93960.4396 - fn: 68.9121 - accuracy: 0.9991 - precision: 0.8754 - recall: 0.5727 - auc: 0.9218 - val_loss: 0.0033 - val_tp: 62.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8611 - val_recall: 0.7750 - val_auc: 0.9247 Epoch 17/100 90/90 [==============================] - 1s 7ms/step - loss: 0.0044 - tp: 98.7033 - fp: 16.4835 - tn: 93958.9231 - fn: 66.4615 - accuracy: 0.9991 - precision: 0.8674 - recall: 0.5985 - auc: 0.9108 - val_loss: 0.0032 - val_tp: 60.0000 - val_fp: 10.0000 - val_tn: 45479.0000 - val_fn: 20.0000 - val_accuracy: 0.9993 - val_precision: 0.8571 - val_recall: 0.7500 - val_auc: 0.9247 Restoring model weights from the end of the best epoch. Epoch 00017: early stopping
Check training history
In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.
Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.
def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)
Evaluate metrics
You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label.
train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
Evaluate your model on the test dataset and display the results for the metrics you created above:
baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss : 0.0034731989726424217 tp : 56.0 fp : 12.0 tn : 56855.0 fn : 39.0 accuracy : 0.9991046786308289 precision : 0.8235294222831726 recall : 0.5894736647605896 auc : 0.9418253898620605 Legitimate Transactions Detected (True Negatives): 56855 Legitimate Transactions Incorrectly Detected (False Positives): 12 Fraudulent Transactions Missed (False Negatives): 39 Fraudulent Transactions Detected (True Positives): 56 Total Fraudulent Transactions: 95
If the model had predicted everything perfectly, this would be a diagonal matrix, where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged. However, you would likely want even fewer false negatives despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may merely cause an email to be sent to a customer asking them to verify their card activity.
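One way to make that trade-off, shown here as a small added example, is simply to lower the classification threshold passed to the plot_cm helper defined above; this converts some false negatives into true positives at the cost of more false positives:
# Re-plot the confusion matrix at a lower threshold to favor recall.
# 0.1 is an arbitrary illustrative value; choose it based on your error costs.
plot_cm(test_labels, test_predictions_baseline, p=0.1)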
Plot the ROC
Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.
def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4786ca748>
It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction incorrectly flagged as fraudulent) may decrease user happiness.
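Alongside the ROC, a precision-recall curve can make this trade-off easier to read on imbalanced data. This is a small added sketch using sklearn, not part of the original tutorial:
from sklearn.metrics import precision_recall_curve

# Precision vs. recall across all thresholds for the baseline model.
precision_vals, recall_vals, _ = precision_recall_curve(
    test_labels, test_predictions_baseline)
plt.plot(recall_vals, precision_vals, linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True)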
Class weights
Calculate class weights
The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class.
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0
weight_for_1 = (1 / pos)*(total)/2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50 Weight for class 1: 289.44
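As a sanity check (added for illustration), with this scaling each class contributes the same total weight, and the overall weight budget still sums to the number of examples:
# Each class contributes total/2 to the weighted example count.
print('Weighted negatives:', weight_for_0 * neg)
print('Weighted positives:', weight_for_1 * pos)
print('Sum of all weights:', weight_for_0 * neg + weight_for_1 * pos)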
Train a model with class weights
Now try re-training and evaluating the model with class weights to see how that affects the predictions.
weighted_model = make_model()
weighted_model.load_weights(initial_weights)
weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100 90/90 [==============================] - 3s 15ms/step - loss: 3.7596 - tp: 56.9890 - fp: 76.4286 - tn: 150769.4286 - fn: 199.7253 - accuracy: 0.9983 - precision: 0.4816 - recall: 0.2595 - auc: 0.7156 - val_loss: 0.0096 - val_tp: 2.0000 - val_fp: 1.0000 - val_tn: 45488.0000 - val_fn: 78.0000 - val_accuracy: 0.9983 - val_precision: 0.6667 - val_recall: 0.0250 - val_auc: 0.8999 Epoch 2/100 90/90 [==============================] - 1s 7ms/step - loss: 1.5426 - tp: 57.2308 - fp: 247.2088 - tn: 93722.7582 - fn: 113.3736 - accuracy: 0.9964 - precision: 0.1813 - recall: 0.2976 - auc: 0.8189 - val_loss: 0.0089 - val_tp: 54.0000 - val_fp: 18.0000 - val_tn: 45471.0000 - val_fn: 26.0000 - val_accuracy: 0.9990 - val_precision: 0.7500 - val_recall: 0.6750 - val_auc: 0.9306 Epoch 3/100 90/90 [==============================] - 1s 7ms/step - loss: 0.8711 - tp: 95.8352 - fp: 494.1209 - tn: 93479.8681 - fn: 70.7473 - accuracy: 0.9943 - precision: 0.1692 - recall: 0.6059 - auc: 0.8912 - val_loss: 0.0121 - val_tp: 66.0000 - val_fp: 32.0000 - val_tn: 45457.0000 - val_fn: 14.0000 - val_accuracy: 0.9990 - val_precision: 0.6735 - val_recall: 0.8250 - val_auc: 0.9426 Epoch 4/100 90/90 [==============================] - 1s 7ms/step - loss: 0.6835 - tp: 108.5165 - fp: 794.1648 - tn: 93183.5055 - fn: 54.3846 - accuracy: 0.9912 - precision: 0.1191 - recall: 0.6530 - auc: 0.8987 - val_loss: 0.0163 - val_tp: 67.0000 - val_fp: 54.0000 - val_tn: 45435.0000 - val_fn: 13.0000 - val_accuracy: 0.9985 - val_precision: 0.5537 - val_recall: 0.8375 - val_auc: 0.9556 Epoch 5/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4713 - tp: 126.3626 - fp: 1149.3407 - tn: 92827.5275 - fn: 37.3407 - accuracy: 0.9878 - precision: 0.0992 - recall: 0.7828 - auc: 0.9329 - val_loss: 0.0214 - val_tp: 67.0000 - val_fp: 95.0000 - val_tn: 45394.0000 - val_fn: 13.0000 - val_accuracy: 0.9976 - val_precision: 0.4136 - val_recall: 0.8375 - val_auc: 0.9588 Epoch 6/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4194 - tp: 125.7912 - fp: 1550.7253 - tn: 92430.8791 - fn: 33.1758 - accuracy: 0.9837 - precision: 0.0769 - recall: 0.7990 - auc: 0.9373 - val_loss: 0.0270 - val_tp: 67.0000 - val_fp: 147.0000 - val_tn: 45342.0000 - val_fn: 13.0000 - val_accuracy: 0.9965 - val_precision: 0.3131 - val_recall: 0.8375 - val_auc: 0.9626 Epoch 7/100 90/90 [==============================] - 1s 7ms/step - loss: 0.4226 - tp: 127.0659 - fp: 2000.6374 - tn: 91978.2857 - fn: 34.5824 - accuracy: 0.9788 - precision: 0.0567 - recall: 0.7672 - auc: 0.9351 - val_loss: 0.0348 - val_tp: 67.0000 - val_fp: 224.0000 - val_tn: 45265.0000 - val_fn: 13.0000 - val_accuracy: 0.9948 - val_precision: 0.2302 - val_recall: 0.8375 - val_auc: 0.9656 Epoch 8/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2990 - tp: 137.7473 - fp: 2564.6154 - tn: 91411.7363 - fn: 26.4725 - accuracy: 0.9729 - precision: 0.0528 - recall: 0.8483 - auc: 0.9609 - val_loss: 0.0457 - val_tp: 69.0000 - val_fp: 406.0000 - val_tn: 45083.0000 - val_fn: 11.0000 - val_accuracy: 0.9908 - val_precision: 0.1453 - val_recall: 0.8625 - val_auc: 0.9691 Epoch 9/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3165 - tp: 125.0330 - fp: 3192.7473 - tn: 90795.8462 - fn: 26.9451 - accuracy: 0.9662 - precision: 0.0375 - recall: 0.8237 - auc: 0.9518 - val_loss: 0.0568 - val_tp: 69.0000 - val_fp: 654.0000 - val_tn: 44835.0000 - val_fn: 11.0000 - val_accuracy: 0.9854 - val_precision: 0.0954 - val_recall: 0.8625 - val_auc: 0.9703 Epoch 10/100 90/90 
[==============================] - 1s 7ms/step - loss: 0.2917 - tp: 139.2527 - fp: 3709.6154 - tn: 90270.3626 - fn: 21.3407 - accuracy: 0.9606 - precision: 0.0356 - recall: 0.8714 - auc: 0.9546 - val_loss: 0.0640 - val_tp: 70.0000 - val_fp: 755.0000 - val_tn: 44734.0000 - val_fn: 10.0000 - val_accuracy: 0.9832 - val_precision: 0.0848 - val_recall: 0.8750 - val_auc: 0.9713 Epoch 11/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3109 - tp: 155.7363 - fp: 3838.3736 - tn: 90123.9670 - fn: 22.4945 - accuracy: 0.9590 - precision: 0.0415 - recall: 0.8730 - auc: 0.9555 - val_loss: 0.0703 - val_tp: 71.0000 - val_fp: 835.0000 - val_tn: 44654.0000 - val_fn: 9.0000 - val_accuracy: 0.9815 - val_precision: 0.0784 - val_recall: 0.8875 - val_auc: 0.9719 Epoch 12/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3007 - tp: 134.9121 - fp: 4045.0549 - tn: 89938.1758 - fn: 22.4286 - accuracy: 0.9566 - precision: 0.0320 - recall: 0.8712 - auc: 0.9494 - val_loss: 0.0756 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44579.0000 - val_fn: 8.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9000 - val_auc: 0.9716 Epoch 13/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2083 - tp: 143.2308 - fp: 4041.6593 - tn: 89938.9670 - fn: 16.7143 - accuracy: 0.9567 - precision: 0.0360 - recall: 0.9154 - auc: 0.9734 - val_loss: 0.0765 - val_tp: 72.0000 - val_fp: 916.0000 - val_tn: 44573.0000 - val_fn: 8.0000 - val_accuracy: 0.9797 - val_precision: 0.0729 - val_recall: 0.9000 - val_auc: 0.9726 Epoch 14/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2448 - tp: 143.6703 - fp: 4119.2967 - tn: 89857.6154 - fn: 19.9890 - accuracy: 0.9562 - precision: 0.0341 - recall: 0.8944 - auc: 0.9655 - val_loss: 0.0811 - val_tp: 72.0000 - val_fp: 992.0000 - val_tn: 44497.0000 - val_fn: 8.0000 - val_accuracy: 0.9781 - val_precision: 0.0677 - val_recall: 0.9000 - val_auc: 0.9724 Epoch 15/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2362 - tp: 141.0769 - fp: 4205.1429 - tn: 89776.9670 - fn: 17.3846 - accuracy: 0.9545 - precision: 0.0316 - recall: 0.8889 - auc: 0.9665 - val_loss: 0.0835 - val_tp: 72.0000 - val_fp: 1019.0000 - val_tn: 44470.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0660 - val_recall: 0.9000 - val_auc: 0.9729 Epoch 16/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2100 - tp: 135.4945 - fp: 4278.8242 - tn: 89707.5495 - fn: 18.7033 - accuracy: 0.9542 - precision: 0.0289 - recall: 0.8980 - auc: 0.9717 - val_loss: 0.0922 - val_tp: 72.0000 - val_fp: 1117.0000 - val_tn: 44372.0000 - val_fn: 8.0000 - val_accuracy: 0.9753 - val_precision: 0.0606 - val_recall: 0.9000 - val_auc: 0.9728 Epoch 17/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2061 - tp: 147.9451 - fp: 4488.5604 - tn: 89486.7912 - fn: 17.2747 - accuracy: 0.9519 - precision: 0.0328 - recall: 0.9026 - auc: 0.9754 - val_loss: 0.0911 - val_tp: 72.0000 - val_fp: 1104.0000 - val_tn: 44385.0000 - val_fn: 8.0000 - val_accuracy: 0.9756 - val_precision: 0.0612 - val_recall: 0.9000 - val_auc: 0.9730 Epoch 18/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3032 - tp: 143.0989 - fp: 4367.9890 - tn: 89610.3956 - fn: 19.0879 - accuracy: 0.9529 - precision: 0.0312 - recall: 0.8782 - auc: 0.9486 - val_loss: 0.0878 - val_tp: 72.0000 - val_fp: 1037.0000 - val_tn: 44452.0000 - val_fn: 8.0000 - val_accuracy: 0.9771 - val_precision: 0.0649 - val_recall: 0.9000 - val_auc: 0.9761 Epoch 19/100 90/90 
[==============================] - 1s 7ms/step - loss: 0.2713 - tp: 148.5165 - fp: 4079.8462 - tn: 89894.0110 - fn: 18.1978 - accuracy: 0.9565 - precision: 0.0361 - recall: 0.8861 - auc: 0.9635 - val_loss: 0.0868 - val_tp: 72.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 8.0000 - val_accuracy: 0.9776 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9762 Epoch 20/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2623 - tp: 144.5055 - fp: 3991.8681 - tn: 89984.6484 - fn: 19.5495 - accuracy: 0.9575 - precision: 0.0332 - recall: 0.8837 - auc: 0.9587 - val_loss: 0.0888 - val_tp: 72.0000 - val_fp: 1030.0000 - val_tn: 44459.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0653 - val_recall: 0.9000 - val_auc: 0.9761 Epoch 21/100 90/90 [==============================] - 1s 7ms/step - loss: 0.3103 - tp: 136.6154 - fp: 3966.2527 - tn: 90015.5934 - fn: 22.1099 - accuracy: 0.9577 - precision: 0.0309 - recall: 0.8330 - auc: 0.9373 - val_loss: 0.0886 - val_tp: 72.0000 - val_fp: 1010.0000 - val_tn: 44479.0000 - val_fn: 8.0000 - val_accuracy: 0.9777 - val_precision: 0.0665 - val_recall: 0.9000 - val_auc: 0.9764 Epoch 22/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2215 - tp: 141.6484 - fp: 3872.5055 - tn: 90108.2527 - fn: 18.1648 - accuracy: 0.9584 - precision: 0.0347 - recall: 0.8935 - auc: 0.9698 - val_loss: 0.0862 - val_tp: 72.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 8.0000 - val_accuracy: 0.9785 - val_precision: 0.0690 - val_recall: 0.9000 - val_auc: 0.9776 Epoch 23/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2214 - tp: 141.6593 - fp: 3877.3626 - tn: 90102.6593 - fn: 18.8901 - accuracy: 0.9585 - precision: 0.0354 - recall: 0.8872 - auc: 0.9685 - val_loss: 0.0842 - val_tp: 72.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 8.0000 - val_accuracy: 0.9792 - val_precision: 0.0713 - val_recall: 0.9000 - val_auc: 0.9777 Epoch 24/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1873 - tp: 138.3956 - fp: 3647.7473 - tn: 90337.8681 - fn: 16.5604 - accuracy: 0.9611 - precision: 0.0340 - recall: 0.9088 - auc: 0.9743 - val_loss: 0.0843 - val_tp: 73.0000 - val_fp: 938.0000 - val_tn: 44551.0000 - val_fn: 7.0000 - val_accuracy: 0.9793 - val_precision: 0.0722 - val_recall: 0.9125 - val_auc: 0.9779 Epoch 25/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2038 - tp: 140.3407 - fp: 3553.5824 - tn: 90428.6813 - fn: 17.9670 - accuracy: 0.9620 - precision: 0.0383 - recall: 0.8946 - auc: 0.9739 - val_loss: 0.0854 - val_tp: 73.0000 - val_fp: 942.0000 - val_tn: 44547.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0719 - val_recall: 0.9125 - val_auc: 0.9777 Epoch 26/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2003 - tp: 151.4725 - fp: 3495.4725 - tn: 90479.3516 - fn: 14.2747 - accuracy: 0.9625 - precision: 0.0406 - recall: 0.9196 - auc: 0.9732 - val_loss: 0.0819 - val_tp: 73.0000 - val_fp: 895.0000 - val_tn: 44594.0000 - val_fn: 7.0000 - val_accuracy: 0.9802 - val_precision: 0.0754 - val_recall: 0.9125 - val_auc: 0.9778 Epoch 27/100 90/90 [==============================] - 1s 9ms/step - loss: 0.2111 - tp: 138.2857 - fp: 3438.2088 - tn: 90547.7253 - fn: 16.3516 - accuracy: 0.9635 - precision: 0.0382 - recall: 0.9021 - auc: 0.9690 - val_loss: 0.0865 - val_tp: 73.0000 - val_fp: 940.0000 - val_tn: 44549.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0721 - val_recall: 0.9125 - val_auc: 0.9775 Epoch 28/100 90/90 
[==============================] - 1s 7ms/step - loss: 0.2115 - tp: 146.6813 - fp: 3796.8022 - tn: 90178.5604 - fn: 18.5275 - accuracy: 0.9595 - precision: 0.0357 - recall: 0.8799 - auc: 0.9740 - val_loss: 0.0914 - val_tp: 73.0000 - val_fp: 996.0000 - val_tn: 44493.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0683 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 29/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2085 - tp: 142.0220 - fp: 3757.9341 - tn: 90222.3956 - fn: 18.2198 - accuracy: 0.9597 - precision: 0.0368 - recall: 0.8941 - auc: 0.9738 - val_loss: 0.0890 - val_tp: 73.0000 - val_fp: 962.0000 - val_tn: 44527.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0705 - val_recall: 0.9125 - val_auc: 0.9775 Epoch 30/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1559 - tp: 153.3077 - fp: 3571.6044 - tn: 90404.6264 - fn: 11.0330 - accuracy: 0.9617 - precision: 0.0420 - recall: 0.9426 - auc: 0.9811 - val_loss: 0.0852 - val_tp: 73.0000 - val_fp: 901.0000 - val_tn: 44588.0000 - val_fn: 7.0000 - val_accuracy: 0.9801 - val_precision: 0.0749 - val_recall: 0.9125 - val_auc: 0.9778 Epoch 31/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2635 - tp: 142.1209 - fp: 3672.6264 - tn: 90304.4725 - fn: 21.3516 - accuracy: 0.9613 - precision: 0.0376 - recall: 0.8706 - auc: 0.9641 - val_loss: 0.0930 - val_tp: 73.0000 - val_fp: 1011.0000 - val_tn: 44478.0000 - val_fn: 7.0000 - val_accuracy: 0.9777 - val_precision: 0.0673 - val_recall: 0.9125 - val_auc: 0.9773 Epoch 32/100 90/90 [==============================] - 1s 7ms/step - loss: 0.1858 - tp: 155.9121 - fp: 3632.0440 - tn: 90338.0769 - fn: 14.5385 - accuracy: 0.9609 - precision: 0.0410 - recall: 0.9196 - auc: 0.9791 - val_loss: 0.0897 - val_tp: 73.0000 - val_fp: 972.0000 - val_tn: 44517.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0699 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 33/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2525 - tp: 152.4945 - fp: 3743.6264 - tn: 90226.1209 - fn: 18.3297 - accuracy: 0.9599 - precision: 0.0398 - recall: 0.8872 - auc: 0.9646 - val_loss: 0.0878 - val_tp: 73.0000 - val_fp: 930.0000 - val_tn: 44559.0000 - val_fn: 7.0000 - val_accuracy: 0.9794 - val_precision: 0.0728 - val_recall: 0.9125 - val_auc: 0.9774 Epoch 34/100 90/90 [==============================] - 1s 7ms/step - loss: 0.2286 - tp: 142.4615 - fp: 3510.0659 - tn: 90469.1209 - fn: 18.9231 - accuracy: 0.9626 - precision: 0.0379 - recall: 0.8608 - auc: 0.9694 - val_loss: 0.0839 - val_tp: 73.0000 - val_fp: 889.0000 - val_tn: 44600.0000 - val_fn: 7.0000 - val_accuracy: 0.9803 - val_precision: 0.0759 - val_recall: 0.9125 - val_auc: 0.9773 Restoring model weights from the end of the best epoch. Epoch 00034: early stopping
Check training history
plot_metrics(weighted_history)
Evaluate metrics
train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss : 0.08238842338323593 tp : 89.0 fp : 1186.0 tn : 55681.0 fn : 6.0 accuracy : 0.9790737628936768 precision : 0.06980392336845398 recall : 0.9368420839309692 auc : 0.98465895652771 Legitimate Transactions Detected (True Negatives): 55681 Legitimate Transactions Incorrectly Detected (False Positives): 1186 Fraudulent Transactions Missed (False Negatives): 6 Fraudulent Transactions Detected (True Positives): 89 Total Fraudulent Transactions: 95
Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.
Plot the ROC
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa4fc124e48>
Oversampling
Oversample the minority class
A related approach would be to resample the dataset by oversampling the minority class.
pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]
pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
Using NumPy
You can balance the dataset manually by choosing the right number of random indices from the positive examples:
ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
res_pos_features.shape
(181959, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]
resampled_features.shape
(363918, 29)
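As a quick added check, the resampled labels should now be balanced:
# About half of the resampled labels should be positive.
print('Positive fraction after resampling: {:.4f}'.format(resampled_labels.mean()))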
Using tf.data
If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.
BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds
pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
Each dataset provides (feature, label) pairs:
for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features: [-2.19450702e+00 1.14157653e+00 -1.53868568e+00 -3.35790031e-01 -8.45956971e-01 -1.57966490e+00 -1.68436379e+00 2.25259298e-01 2.62655847e-01 -3.51815449e+00 2.45653756e+00 -3.62123216e+00 -1.02812838e+00 -5.00000000e+00 1.79960408e+00 -3.72675038e+00 -5.00000000e+00 -2.22257320e+00 -1.24126989e-02 -9.14798839e-01 7.33686754e-01 -9.34731229e-02 -1.74683429e+00 4.44394133e-01 -3.98093307e-02 -1.99985035e+00 -2.27011783e+00 1.83260825e-03 5.69563522e-01] Label: 1
Merge the two together using experimental.sample_from_datasets:
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.5
To use this dataset, you'll need the number of steps per epoch.
The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
Train on the oversampled data
Now try training the model with the resampled dataset instead of using class weights to see how these methods compare.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)
resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100 278/278 [==============================] - 8s 24ms/step - loss: 0.8118 - tp: 98159.0036 - fp: 34959.3333 - tn: 165642.9498 - fn: 44913.3728 - accuracy: 0.7582 - precision: 0.6573 - recall: 0.5935 - auc: 0.7814 - val_loss: 0.1785 - val_tp: 70.0000 - val_fp: 1177.0000 - val_tn: 44312.0000 - val_fn: 10.0000 - val_accuracy: 0.9740 - val_precision: 0.0561 - val_recall: 0.8750 - val_auc: 0.9722 Epoch 2/100 278/278 [==============================] - 6s 21ms/step - loss: 0.2142 - tp: 126477.2616 - fp: 8422.5305 - tn: 134479.1362 - fn: 17333.7312 - accuracy: 0.9075 - precision: 0.9341 - recall: 0.8774 - auc: 0.9705 - val_loss: 0.1007 - val_tp: 71.0000 - val_fp: 919.0000 - val_tn: 44570.0000 - val_fn: 9.0000 - val_accuracy: 0.9796 - val_precision: 0.0717 - val_recall: 0.8875 - val_auc: 0.9718 Epoch 3/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1652 - tp: 128160.7455 - fp: 6029.4875 - tn: 137362.8602 - fn: 15159.5663 - accuracy: 0.9255 - precision: 0.9544 - recall: 0.8936 - auc: 0.9831 - val_loss: 0.0775 - val_tp: 71.0000 - val_fp: 802.0000 - val_tn: 44687.0000 - val_fn: 9.0000 - val_accuracy: 0.9822 - val_precision: 0.0813 - val_recall: 0.8875 - val_auc: 0.9727 Epoch 4/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1451 - tp: 129069.5376 - fp: 5223.8925 - tn: 138326.3763 - fn: 14092.8530 - accuracy: 0.9322 - precision: 0.9611 - recall: 0.9007 - auc: 0.9872 - val_loss: 0.0676 - val_tp: 71.0000 - val_fp: 756.0000 - val_tn: 44733.0000 - val_fn: 9.0000 - val_accuracy: 0.9832 - val_precision: 0.0859 - val_recall: 0.8875 - val_auc: 0.9766 Epoch 5/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1328 - tp: 130130.0502 - fp: 5189.8996 - tn: 138364.5878 - fn: 13028.1219 - accuracy: 0.9360 - precision: 0.9611 - recall: 0.9085 - auc: 0.9895 - val_loss: 0.0616 - val_tp: 72.0000 - val_fp: 763.0000 - val_tn: 44726.0000 - val_fn: 8.0000 - val_accuracy: 0.9831 - val_precision: 0.0862 - val_recall: 0.9000 - val_auc: 0.9748 Epoch 6/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1230 - tp: 131300.7706 - fp: 4925.0179 - tn: 138645.5125 - fn: 11841.3584 - accuracy: 0.9413 - precision: 0.9635 - recall: 0.9170 - auc: 0.9912 - val_loss: 0.0566 - val_tp: 72.0000 - val_fp: 754.0000 - val_tn: 44735.0000 - val_fn: 8.0000 - val_accuracy: 0.9833 - val_precision: 0.0872 - val_recall: 0.9000 - val_auc: 0.9759 Epoch 7/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1150 - tp: 132477.4875 - fp: 4822.2509 - tn: 138449.6989 - fn: 10963.2222 - accuracy: 0.9445 - precision: 0.9647 - recall: 0.9228 - auc: 0.9925 - val_loss: 0.0516 - val_tp: 72.0000 - val_fp: 711.0000 - val_tn: 44778.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0920 - val_recall: 0.9000 - val_auc: 0.9777 Epoch 8/100 278/278 [==============================] - 6s 23ms/step - loss: 0.1069 - tp: 132886.1828 - fp: 4656.9570 - tn: 138848.8029 - fn: 10320.7168 - accuracy: 0.9474 - precision: 0.9660 - recall: 0.9275 - auc: 0.9936 - val_loss: 0.0462 - val_tp: 71.0000 - val_fp: 667.0000 - val_tn: 44822.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.0962 - val_recall: 0.8875 - val_auc: 0.9695 Epoch 9/100 278/278 [==============================] - 6s 22ms/step - loss: 0.1006 - tp: 133579.5950 - fp: 4598.8602 - tn: 138524.1039 - fn: 10010.1004 - accuracy: 0.9489 - precision: 0.9668 - recall: 0.9298 - auc: 0.9944 - val_loss: 0.0419 - val_tp: 71.0000 - val_fp: 639.0000 - val_tn: 44850.0000 - val_fn: 9.0000 - val_accuracy: 
0.9858 - val_precision: 0.1000 - val_recall: 0.8875 - val_auc: 0.9703 Epoch 10/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0955 - tp: 134632.2903 - fp: 4540.9355 - tn: 138577.5520 - fn: 8961.8817 - accuracy: 0.9521 - precision: 0.9676 - recall: 0.9358 - auc: 0.9950 - val_loss: 0.0396 - val_tp: 71.0000 - val_fp: 634.0000 - val_tn: 44855.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1007 - val_recall: 0.8875 - val_auc: 0.9664 Epoch 11/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0908 - tp: 140213.1900 - fp: 5993.7849 - tn: 136973.6380 - fn: 3532.0466 - accuracy: 0.9663 - precision: 0.9588 - recall: 0.9747 - auc: 0.9955 - val_loss: 0.0359 - val_tp: 71.0000 - val_fp: 596.0000 - val_tn: 44893.0000 - val_fn: 9.0000 - val_accuracy: 0.9867 - val_precision: 0.1064 - val_recall: 0.8875 - val_auc: 0.9669 Epoch 12/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0857 - tp: 140261.6918 - fp: 5979.1470 - tn: 137657.2294 - fn: 2814.5914 - accuracy: 0.9691 - precision: 0.9589 - recall: 0.9800 - auc: 0.9959 - val_loss: 0.0337 - val_tp: 71.0000 - val_fp: 580.0000 - val_tn: 44909.0000 - val_fn: 9.0000 - val_accuracy: 0.9871 - val_precision: 0.1091 - val_recall: 0.8875 - val_auc: 0.9677 Epoch 13/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0818 - tp: 140796.7133 - fp: 5945.5305 - tn: 137458.5520 - fn: 2511.8638 - accuracy: 0.9704 - precision: 0.9596 - recall: 0.9822 - auc: 0.9962 - val_loss: 0.0318 - val_tp: 71.0000 - val_fp: 558.0000 - val_tn: 44931.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1129 - val_recall: 0.8875 - val_auc: 0.9682 Epoch 14/100 278/278 [==============================] - 7s 24ms/step - loss: 0.0793 - tp: 140997.9176 - fp: 6076.8746 - tn: 137562.6846 - fn: 2075.1828 - accuracy: 0.9714 - precision: 0.9586 - recall: 0.9853 - auc: 0.9964 - val_loss: 0.0303 - val_tp: 71.0000 - val_fp: 555.0000 - val_tn: 44934.0000 - val_fn: 9.0000 - val_accuracy: 0.9876 - val_precision: 0.1134 - val_recall: 0.8875 - val_auc: 0.9687 Epoch 15/100 278/278 [==============================] - 6s 22ms/step - loss: 0.0759 - tp: 141966.7312 - fp: 6100.1147 - tn: 136957.0108 - fn: 1688.8029 - accuracy: 0.9729 - precision: 0.9589 - recall: 0.9883 - auc: 0.9966 - val_loss: 0.0292 - val_tp: 71.0000 - val_fp: 541.0000 - val_tn: 44948.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.1160 - val_recall: 0.8875 - val_auc: 0.9692 Epoch 16/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0727 - tp: 141879.4731 - fp: 6020.2007 - tn: 137272.9140 - fn: 1540.0717 - accuracy: 0.9736 - precision: 0.9594 - recall: 0.9891 - auc: 0.9968 - val_loss: 0.0276 - val_tp: 71.0000 - val_fp: 504.0000 - val_tn: 44985.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1235 - val_recall: 0.8875 - val_auc: 0.9697 Epoch 17/100 278/278 [==============================] - 6s 21ms/step - loss: 0.0704 - tp: 142140.7563 - fp: 6019.2258 - tn: 137225.3978 - fn: 1327.2796 - accuracy: 0.9745 - precision: 0.9597 - recall: 0.9907 - auc: 0.9969 - val_loss: 0.0269 - val_tp: 71.0000 - val_fp: 495.0000 - val_tn: 44994.0000 - val_fn: 9.0000 - val_accuracy: 0.9889 - val_precision: 0.1254 - val_recall: 0.8875 - val_auc: 0.9699 Restoring model weights from the end of the best epoch. Epoch 00017: early stopping
If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.
But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.
This smoother gradient signal makes it easier to train the model.
Check training history
Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.
plot_metrics(resampled_history)
Re-train
Because training is easier on the balanced data, the above training procedure may overfit quickly.
So break up the epochs to give the callbacks.EarlyStopping finer control over when to stop training.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 59ms/step - loss: 1.6583 - tp: 2292.0000 - fp: 4139.7619 - tn: 52605.9524 - fn: 8961.7619 - accuracy: 0.8197 - precision: 0.3317 - recall: 0.1958 - auc: 0.7857 - val_loss: 0.5873 - val_tp: 9.0000 - val_fp: 13094.0000 - val_tn: 32395.0000 - val_fn: 71.0000 - val_accuracy: 0.7111 - val_precision: 6.8687e-04 - val_recall: 0.1125 - val_auc: 0.2616
Epoch 2/1000
20/20 [==============================] - 1s 27ms/step - loss: 1.0305 - tp: 4655.6667 - fp: 3775.2857 - tn: 7528.4286 - fn: 6471.0952 - accuracy: 0.5318 - precision: 0.5347 - recall: 0.3960 - auc: 0.4700 - val_loss: 0.5752 - val_tp: 45.0000 - val_fp: 12493.0000 - val_tn: 32996.0000 - val_fn: 35.0000 - val_accuracy: 0.7251 - val_precision: 0.0036 - val_recall: 0.5625 - val_auc: 0.7427
...
Epoch 24/1000
20/20 [==============================] - 1s 33ms/step - loss: 0.1828 - tp: 9880.4286 - fp: 540.2381 - tn: 10753.0000 - fn: 1256.8095 - accuracy: 0.9207 - precision: 0.9483 - recall: 0.8883 - auc: 0.9787 - val_loss: 0.1116 - val_tp: 71.0000 - val_fp: 961.0000 - val_tn: 44528.0000 - val_fn: 9.0000 - val_accuracy: 0.9787 - val_precision: 0.0688 - val_recall: 0.8875 - val_auc: 0.9712
Epoch 25/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1829 - tp: 9871.3333 - fp: 555.4286 - tn: 10747.9048 - fn: 1255.8095 - accuracy: 0.9200 - precision: 0.9470 - recall: 0.8879 - auc: 0.9791 - val_loss: 0.1082 - val_tp: 71.0000 - val_fp: 953.0000 - val_tn: 44536.0000 - val_fn: 9.0000 - val_accuracy: 0.9789 - val_precision: 0.0693 - val_recall: 0.8875 - val_auc: 0.9713
Restoring model weights from the end of the best epoch.
Epoch 00025: early stopping
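The "Restoring model weights" and "early stopping" messages at the end of the log come from a Keras EarlyStopping callback passed to model.fit. As a point of reference, here is a minimal sketch of a comparable configuration; the exact arguments used earlier in this tutorial may differ, and monitoring val_auc is an assumption based on the metrics shown in the log above:

# A minimal sketch, not necessarily the tutorial's exact callback: stop when
# validation AUC stops improving, then roll back to the best weights seen,
# which produces the "Restoring model weights..." message above.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',   # assumed; any logged validation metric can be monitored
    mode='max',          # AUC should be maximized, so 'max' instead of the default 'auto'
    patience=10,         # epochs to wait for an improvement before stopping
    restore_best_weights=True)

# The callback is then supplied via: model.fit(..., callbacks=[early_stopping])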
Re-check training history
plot_metrics(resampled_history)
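plot_metrics is the plotting helper defined earlier in the tutorial. If you are reading this section on its own, the following is a minimal sketch of an equivalent helper, assuming a Keras History object and the metric names logged above:

# A minimal sketch of a plot_metrics-style helper: plot training vs.
# validation curves for a few key metrics, one subplot per metric.
def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_", " ").capitalize()
    plt.subplot(2, 2, n + 1)
    plt.plot(history.epoch, history.history[metric],
             color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_' + metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.legend()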
Evaluate metrics
# Score the resampled model on the training and test sets.
train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)

# Evaluate on the held-out test set and print each named metric.
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.16524264216423035
tp :  91.0
fp :  1376.0
tn :  55491.0
fn :  4.0
accuracy :  0.9757733345031738
precision :  0.06203135475516319
recall :  0.9578947424888611
auc :  0.9829339385032654

Legitimate Transactions Detected (True Negatives): 55491
Legitimate Transactions Incorrectly Detected (False Positives): 1376
Fraudulent Transactions Missed (False Negatives): 4
Fraudulent Transactions Detected (True Positives): 91
Total Fraudulent Transactions: 95
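plot_cm, used above, is also defined earlier in the tutorial. A minimal sketch of an equivalent confusion-matrix helper, using the confusion_matrix import from the setup; the default 0.5 decision threshold is an assumption:

# A minimal sketch of a plot_cm-style helper: threshold the predicted
# probabilities at p, draw the confusion matrix, and print the four counts.
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))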
Plot ROC
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fa0682848d0>
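plot_roc, defined earlier in the tutorial, wraps sklearn's roc_curve. Under that assumption, a minimal sketch of an equivalent helper:

# A minimal sketch of a plot_roc-style helper: compute false/true positive
# rates across all thresholds and plot them as percentages.
from sklearn.metrics import roc_curve

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = roc_curve(labels, predictions)
  plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')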
Applying this tutorial to your problem
Imbalanced data classification is an inherently difficult task, since there are so few samples to learn from. You should always start with the data: do your best to collect as many samples as possible, and give substantial thought to which features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and still not yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.
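For example, one concrete way to act on that trade-off is to move the decision threshold away from the default 0.5: lowering it catches more fraud (higher recall) at the cost of more false alarms (lower precision). A minimal sketch reusing plot_cm and the test predictions from above; the threshold values are illustrative assumptions, not tuned recommendations:

# Compare confusion matrices at two illustrative thresholds. A lower
# threshold flags more transactions as fraud, trading precision for recall.
for p in [0.5, 0.1]:
  plot_cm(test_labels, test_predictions_resampled, p=p)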