불균형 데이터 분류

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 Download notebook

이 튜토리얼에서는 한 클래스의 예시의 수가 다른 클래스보다 훨씬 많은 불균형 데이터세트를 분류하는 방법을 소개합니다. Kaggle에서 호스팅 되는 신용 카드 부정 행위 탐지 데이터세트를 사용하여 작업할 것입니다. 총 284,807건의 거래에서 492건의 부정거래만 적발하는 것이 목적입니다. Keras를 사용하여 모델 및 클래스 가중치를 정의하여 불균형 데이터로부터 모델을 학습할 수 있도록 할 것입니다.

이 튜토리얼에서는 다음의 완전한 코드가 포함되어있습니다.:

  • Pandas를 사용하여 CSV 파일 로드.
  • 학습, 검증 및 테스트세트 작성.
  • Keras를 사용하여 모델을 정의하고 학습시키기(클래스 가중치 설정 포함).
  • 다양한 측정 기준(정밀도 및 재현 율 포함)을 사용하여 모델을 평가한다.
  • 불균형 데이터를 처리하기 위한 다음과 같은 기술을 사용해보십시오:
    • 클래스 가중치
    • 오버샘플링

설정

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

데이터 처리 및 탐색

Kaggle 신용 카드 부정 행위 데이터 세트

Pandas는 구조화된 데이터를 로드하고 작업하는데 유용한 유틸리티가 많이 있는 Python 라이브러리로서 CSV를 데이터 프레임으로 다운로드 하는데 사용할 수 있다.

참고: 이 데이터세트는 큰 데이터 마이닝 및 부정 행위 감지에 대한 Worldline과 ULB의 Machine Learning Group (Université Libre de Bruxelles)의 연구 협업을 통해 수집 및 분석 되었다. 관련 주제에 대한 현재 및 과거의 프로젝트에 대한 자세한 내용은 여기 and the page of the DefeatFraud에서 확인할 수 있으며 DefeatFraud 프로젝트 페이지도 참조하십시오

file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

클래스 레이블 불균형 조사

데이터세트 불균형을 살펴보겠습니다.:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


이것은 양성 샘플의 작은 부분을 보여줍니다.

데이터 정리, 분할 및 정규화

원시 데이터에는 몇 가지 문제가 있습니다. TimeAmount 열은 너무 가변적이기 때문에 직접적으로 사용할 수 없습니다. 우선 Time 열을 삭제한 뒤에 (의미가 명확하지 않아서) amount 열의 로그를 가져와서 범위를 줄입니다.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

데이터 세트를 학습, 검증 및 테스트 세트로 분할합니다. 검증 세트는 모델 피팅 중에 손실 및 메트릭을 평가하는 데 사용되지만 모델이 이 데이터에 적합하지 않습니다. 테스트 세트는 훈련 단계에서 완전히 사용되지 않으며 모델이 새 데이터로 얼마나 잘 일반화되는지 평가하기 위해 마지막에만 사용됩니다. 이는 훈련 데이터 부족으로 인하여 오버피팅 이 중요한 문제인 데이터 세트에서 특히 더 중요합니다.

# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

sklearn StandardScaler를 사용하여 입력 기능을 정규화 합니다. 이것은 평균은 0으로, 표준 편차는 1로 설정합니다.

참고: StandardScaler는 오직 모델이 validation 또는 test set를 peeking 하지는 않았는지 확인하기 위해 train_feature를 사용할 때 적합합니다.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

주의: 모델을 배포하려면 전처리 계산을 유지하는 것이 중요합니다. 레이어로 구현하고 내보내기 전에 모델에 연결하는 것이 가장 쉬운 방법입니다.

데이터 분포 살펴보기

다음으로 몇 가지 기능에 대한 긍정 및 부정 예제의 분포를 비교하십시오. 이 때 스스로에게 물어볼 좋은 질문은 다음과 같습니다.:

  • 이러한 분포가 의미가 있습니까?
    • 예. 이미 입력을 정규화했으며 대부분 +/- 2 범위에 밀집되어 있습니다.
  • 분포의 차이를 볼 수 있습니까?
    • 예. 긍정적인 예는 그렇지 않은 것 보다 훨씬 더 높은 극단적인 값을 포함합니다.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning

png

png

모델 및 메트릭 정의

촘촘하게 연결된 히든 레이어, 과적합을 줄이기 위한 drop out 레이어, 거래 사기 가능성을 반환하는 출력 sigmoid 레이어로 간단한 신경망을 생성하는 함수를 정의합니다. :

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

유용한 메트릭 이해

위에서 정의한 몇 가지 지표는 성능을 평가할 때 도움이 될 모델에 의해 계산될 수 있다는 점에 유의하십시오.

  • 거짓 음성 그리고 거짓 양성은 잘못 분류된 샘플입니다.
  • 음성 그리고 양성은 제대로 분류된 샘플입니다.
  • 정확도 는 올바르게 분류된 예제의 비율입니다. > $\frac{\text{true samples} }{\text{total samples} }$
  • 정밀도 는 올바르게 분류된 예측 긍정 비율입니다. > $\frac{\text{true positives} }{\text{true positives + false positives} }$
  • 재현 율 은 올바르게 분류된 실제 긍정 비율입니다. > $\frac{\text{true positives} }{\text{true positives + false negatives} }$
  • AUC 는 수신자 조작 특성 곡선 아래 영역(ROC-AUC)을 나타냅니다. 이 메트릭은 분류기가 무작위 음성 샘플보다 무작위 양성 샘플의 순위를 매길 확률과 동일합니다.

참고: 정확도는 이 작업에 유용한 측정 항목이 아닙니다. 항상 False를 예측해야 이 작업에서 99.8% 이상의 정확도를 얻을 수 있습니다.

Read more:

기준 모델

모델 구축

이제 이전에 정의한 함수를 통해서 모델을 만들고 학습시키십시오. 모델의 크기가 기본 배치 크기인 2048보다 큰 배치 크기를 사용하여야 적합한 것을 유의하십시오. 이는 각 배치에서 몇 개의 양성 샘플을 포함할 수 있는 적절한 기회를 확보하는데 있어서 중요하다.

참고: 이 모델은 클래스의 불균형을 잘 다루지 못합니다. 이를 이 튜토리얼의 뒷부분에서 개선하게 될 겁니다.

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Test run the model:

model.predict(train_features[:10])
array([[0.7381869 ],
       [0.3955272 ],
       [0.5119114 ],
       [0.72830456],
       [0.6611035 ],
       [0.78933924],
       [0.69018126],
       [0.82324433],
       [0.6972425 ],
       [0.7572208 ]], dtype=float32)

선택사항: 올바른 초기 바이어스를 설정합니다.

이러한 초기 추측은 좋지 못합니다. 데이터 세트가 불균형 하다는 것을 알고 있습니다. 그렇다면 이를 반영하도록 출력 계층의 바이어스를 설정합니다. (참조: 신경망 훈련을 위한 레시피: "init well"). 이것은 초기 수렴에 도움이 될 수 있습니다.

기본 바이어스 초기화를 사용하면 손실은 약 math.log(2) = 0.69314

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.4751

설정할 올바른 바이어스는 다음에서 파생 가능합니다.:

$$ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) $$
$$ b_0 = -log_e(1/p_0 - 1) $$
$$ b_0 = log_e(pos/neg)$$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

이를 초기 바이어스로 설정하면 모델이 훨씬 더 합리적인 초기 추측을 제공합니다.

가까워 야합니다.: pos/total = 0.0018

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00107778],
       [0.00055683],
       [0.00089524],
       [0.00157561],
       [0.00017924],
       [0.00187467],
       [0.00146499],
       [0.00451898],
       [0.00164567],
       [0.00056281]], dtype=float32)

이 초기화를 통해서 초기 손실은 대략 다음과 같아야합니다.:

$$-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317$$
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0128

이 초기 손실은 단순한 상태의 초기화에서 발생했을 때 보다 약 50배 적습니다.

이런 식으로 모델은 긍정적인 예시의 가능성이 낮다는 것을 배우면서 처음 몇 epoch를 보낼 필요는 없습니다. 이것은 또한 훈련 중 plot의 손실을 더 쉽게 읽어낼 수 있게 해줍니다.

초기 가중치 체크 포인트

다양한 훈련 실행을 더욱 비교 가능하도록 하고 싶다면 초기 모델의 가중치를 체크 포인트 파일에 보관하고 훈련 전에 각 모델에 로드 하십시오

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

바이어스 수정이 도움이 되는지 확인

계속 진행하기 전에 조심스러운 바이어스 초기화가 실제로 도움이 되었는지 빠르게 확인하십시오

조심스럽게 초기화를 한 것과 사용하지 않은 것의 20 epoch 동안 모델을 훈련하고 손실을 비교합니다.:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch, history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')

  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

png

위의 그림은 이를 명확하게 보여줍니다. 유효성 검사 손실 측면에서 이 문제에 대해 조심스러운 초기화는 명확한 이점을 제공합니다.

모델 훈련

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 16ms/step - loss: 0.0133 - tp: 73.4945 - fp: 48.3956 - tn: 139411.4945 - fn: 176.1868 - accuracy: 0.9985 - precision: 0.6155 - recall: 0.3264 - auc: 0.8215 - val_loss: 0.0055 - val_tp: 19.0000 - val_fp: 6.0000 - val_tn: 45488.0000 - val_fn: 56.0000 - val_accuracy: 0.9986 - val_precision: 0.7600 - val_recall: 0.2533 - val_auc: 0.9122
Epoch 2/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0079 - tp: 52.0879 - fp: 17.9011 - tn: 93956.8791 - fn: 113.7033 - accuracy: 0.9985 - precision: 0.6710 - recall: 0.2698 - auc: 0.8640 - val_loss: 0.0045 - val_tp: 36.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 39.0000 - val_accuracy: 0.9990 - val_precision: 0.8372 - val_recall: 0.4800 - val_auc: 0.9262
Epoch 3/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0057 - tp: 91.4066 - fp: 18.0000 - tn: 93955.5714 - fn: 75.5934 - accuracy: 0.9990 - precision: 0.8365 - recall: 0.5759 - auc: 0.9133 - val_loss: 0.0042 - val_tp: 46.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 29.0000 - val_accuracy: 0.9992 - val_precision: 0.8679 - val_recall: 0.6133 - val_auc: 0.9264
Epoch 4/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0059 - tp: 86.1209 - fp: 16.9780 - tn: 93958.6044 - fn: 78.8681 - accuracy: 0.9990 - precision: 0.8275 - recall: 0.4938 - auc: 0.9053 - val_loss: 0.0040 - val_tp: 45.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 30.0000 - val_accuracy: 0.9992 - val_precision: 0.8654 - val_recall: 0.6000 - val_auc: 0.9331
Epoch 5/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0058 - tp: 82.1868 - fp: 18.6923 - tn: 93960.9451 - fn: 78.7473 - accuracy: 0.9989 - precision: 0.8064 - recall: 0.5085 - auc: 0.9182 - val_loss: 0.0037 - val_tp: 43.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 32.0000 - val_accuracy: 0.9991 - val_precision: 0.8600 - val_recall: 0.5733 - val_auc: 0.9330
Epoch 6/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0050 - tp: 97.7692 - fp: 17.8022 - tn: 93946.7033 - fn: 78.2967 - accuracy: 0.9990 - precision: 0.8587 - recall: 0.5545 - auc: 0.9302 - val_loss: 0.0035 - val_tp: 51.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8793 - val_recall: 0.6800 - val_auc: 0.9264
Epoch 7/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0048 - tp: 100.8681 - fp: 16.0769 - tn: 93951.5714 - fn: 72.0549 - accuracy: 0.9991 - precision: 0.8456 - recall: 0.5922 - auc: 0.9372 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8833 - val_recall: 0.7067 - val_auc: 0.9264
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0051 - tp: 93.4286 - fp: 18.9121 - tn: 93950.3626 - fn: 77.8681 - accuracy: 0.9990 - precision: 0.8271 - recall: 0.5493 - auc: 0.9041 - val_loss: 0.0033 - val_tp: 51.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8793 - val_recall: 0.6800 - val_auc: 0.9197
Epoch 9/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0053 - tp: 98.7253 - fp: 21.7912 - tn: 93946.9011 - fn: 73.1538 - accuracy: 0.9989 - precision: 0.7869 - recall: 0.5651 - auc: 0.9359 - val_loss: 0.0032 - val_tp: 52.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8814 - val_recall: 0.6933 - val_auc: 0.9197
Epoch 10/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0046 - tp: 104.3187 - fp: 16.3407 - tn: 93954.1978 - fn: 65.7143 - accuracy: 0.9991 - precision: 0.8842 - recall: 0.6126 - auc: 0.9172 - val_loss: 0.0031 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8772 - val_recall: 0.6667 - val_auc: 0.9263
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0046 - tp: 97.5604 - fp: 19.2308 - tn: 93954.6264 - fn: 69.1538 - accuracy: 0.9991 - precision: 0.8426 - recall: 0.5893 - auc: 0.9118 - val_loss: 0.0030 - val_tp: 53.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8833 - val_recall: 0.7067 - val_auc: 0.9197
Epoch 12/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0043 - tp: 106.5385 - fp: 16.3516 - tn: 93945.9670 - fn: 71.7143 - accuracy: 0.9991 - precision: 0.8746 - recall: 0.6103 - auc: 0.9413 - val_loss: 0.0031 - val_tp: 57.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.8906 - val_recall: 0.7600 - val_auc: 0.9198
Epoch 13/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0047 - tp: 94.0000 - fp: 18.6593 - tn: 93951.6923 - fn: 76.2198 - accuracy: 0.9990 - precision: 0.8202 - recall: 0.5400 - auc: 0.9143 - val_loss: 0.0030 - val_tp: 59.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8939 - val_recall: 0.7867 - val_auc: 0.9197
Epoch 14/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0043 - tp: 105.5824 - fp: 21.2747 - tn: 93952.0330 - fn: 61.6813 - accuracy: 0.9992 - precision: 0.8249 - recall: 0.6606 - auc: 0.9360 - val_loss: 0.0029 - val_tp: 58.0000 - val_fp: 7.0000 - val_tn: 45487.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8923 - val_recall: 0.7733 - val_auc: 0.9263
Restoring model weights from the end of the best epoch.
Epoch 00014: early stopping

학습 이력 확인

이 섹션에서는 훈련 및 검증 세트에 대한 모델의 정확도와 손실에 대한 plot을 생성합니다. 이는 과적합을 확인하는데 유용하며 이 튜토리얼에서 자세한 내용을 확인할 수 있습니다.

추가적으로, 위에서 만든 모든 메트릭에 대해 이러한 plot을 생성할 수 있습니다. 거짓 음성이 포함되는 경우가 예시입니다.

def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

png

참고: 검증 곡선은 일반적으로 훈련 곡선보다 성능이 좋습니다. 이는 주로 모델을 평가할 때 drop out 레이어가 활성화 되지 않았기 때문에 발생합니다.

메트릭 평가

혼동 행렬 을 사용하여 X축이 예측 레이블이고 Y축이 실제 레이블인 실제 레이블과 예측 레이블을 요약할 수 있습니다.

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

테스트 데이터 세트에서 모델을 평가하고 위에서 만든 측정 항목의 결과를 표시합니다.

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.004399415571242571
tp :  50.0
fp :  7.0
tn :  56868.0
fn :  37.0
accuracy :  0.9992275834083557
precision :  0.8771929740905762
recall :  0.5747126340866089
auc :  0.8962478637695312

Legitimate Transactions Detected (True Negatives):  56868
Legitimate Transactions Incorrectly Detected (False Positives):  7
Fraudulent Transactions Missed (False Negatives):  37
Fraudulent Transactions Detected (True Positives):  50
Total Fraudulent Transactions:  87

png

만약 모델이 모든 것을 완벽하게 예측했다면 이것은 잘못된 예측을 나타내는 주 대각선의 값이 0이 되는 대각행렬 이 됩니다. 이러한 경우에 매트릭스가 잘못 탐지한 경우가 상대적으로 적다는 것을 보여줍니다. 즉 잘못 플래그가 지정된 합법적인 거래가 상대적으로 적은 것을 의미합니다. 그러나 거짓 양성 수를 늘릴 때 드는 비용에도 불구하고 더 적은 수의 거짓 음성을 원할 수 있습니다. 거짓 음성 판정이 부정 거래를 통과할 수 있는 반면, 거짓 긍정 판정이 고객에게 이메일을 보내 카드 활동을 확인하도록 요청할 수 있기 때문에 이러한 거래 중단이 더 바람 직 할 수 있습니다.

ROC 플로팅

이제 ROC을 플로팅 하십시오. 이 그래프는 출력 임계값을 조정하기만 해도 모델이 도달할 수 있는 성능 범위를 한눈에 보여주기 때문에 유용합니다.

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7ff30c294588>

png

비교적 정밀도가 높은 것 같지만 회수율과 ROC 곡선(AUC) 밑 면적이 마음에 들 만큼 높지 않습니다. 분류자는 정밀도와 리콜을 모두 최대화 하려고 할 때 종종 도전해야 할 문제에 직면하는데, 이는 불균형 데이터세트로 작업할 떄 특히 그러합니다. 당신이 신경쓰는 문제의 맥락에서 다른 유형의 오류의 비용을 고려하는 것이 중요합니다. 이 예시에서 거짓음성(부정 거래를 놓친 경우)은 금전적인 비용을 초래하지만 , 거짓 양성(거래가 사기 행위로 잘못 표시됨)은 사용자들의 만족도를 감소시킬 수 있습니다.

클래스 가중치

클래스 가중치 계산

목표는 부정 거래를 식별하는 것이지만, 여러분은 작업할 수 있는 긍정적인 샘플이 많지 않기 깨문에 분류자가 이용할 수 있는 몇 가지 예에 가중치를 두고자 할 것입니다. 매개 변수를 통해 각 클래스에 대한 Keras 가중치를 전달한다면 이 과정을 할 수 있습니다. 이로 인해 모델이 덜 표현된 클래스의 예에 "더 많은 주의를 기울이십시오"라고 할 수도 있습니다.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44

클래스 가중치로 모델 교육

이제 해당 모델이 예측에 어떤 영향을 미치는지 확인하기 위하여 클래스 가중치로 모델을 재 교육하고 평가해 보십시오.

참고: class_weights 를 사용하면 손실 범위가 바뀝니다. 이는 최적기에 따라 학습의 안정성에 영향을 미칠 수 있습니다. 단계 크기가 그라데이션의 크기에 따라 달라지는 optimizers.SGD 와 같은 최적화 도구는 실패할 수 있습니다. 여기서 사용되는 최적화기인 optimizers.Adam 은 스케일링 변화에 영향을 받지 않습니다. 또한 가중치 때문에 전체 손실은 두 모델 간에 비교할 수 없습니다.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 3s 17ms/step - loss: 2.5743 - tp: 80.4396 - fp: 114.1538 - tn: 150727.6374 - fn: 180.3407 - accuracy: 0.9982 - precision: 0.4724 - recall: 0.3182 - auc: 0.8169 - val_loss: 0.0069 - val_tp: 37.0000 - val_fp: 10.0000 - val_tn: 45484.0000 - val_fn: 38.0000 - val_accuracy: 0.9989 - val_precision: 0.7872 - val_recall: 0.4933 - val_auc: 0.9371
Epoch 2/100
90/90 [==============================] - 1s 8ms/step - loss: 1.0780 - tp: 88.4615 - fp: 289.8681 - tn: 93678.7473 - fn: 83.4945 - accuracy: 0.9962 - precision: 0.2386 - recall: 0.5085 - auc: 0.8830 - val_loss: 0.0100 - val_tp: 61.0000 - val_fp: 28.0000 - val_tn: 45466.0000 - val_fn: 14.0000 - val_accuracy: 0.9991 - val_precision: 0.6854 - val_recall: 0.8133 - val_auc: 0.9560
Epoch 3/100
90/90 [==============================] - 1s 8ms/step - loss: 0.7321 - tp: 122.3956 - fp: 630.8242 - tn: 93335.6813 - fn: 51.6703 - accuracy: 0.9929 - precision: 0.1632 - recall: 0.6732 - auc: 0.9112 - val_loss: 0.0158 - val_tp: 64.0000 - val_fp: 94.0000 - val_tn: 45400.0000 - val_fn: 11.0000 - val_accuracy: 0.9977 - val_precision: 0.4051 - val_recall: 0.8533 - val_auc: 0.9546
Epoch 4/100
90/90 [==============================] - 1s 8ms/step - loss: 0.5424 - tp: 133.3516 - fp: 973.8901 - tn: 92994.5604 - fn: 38.7692 - accuracy: 0.9896 - precision: 0.1345 - recall: 0.7821 - auc: 0.9303 - val_loss: 0.0227 - val_tp: 64.0000 - val_fp: 188.0000 - val_tn: 45306.0000 - val_fn: 11.0000 - val_accuracy: 0.9956 - val_precision: 0.2540 - val_recall: 0.8533 - val_auc: 0.9552
Epoch 5/100
90/90 [==============================] - 1s 8ms/step - loss: 0.5608 - tp: 138.5165 - fp: 1400.8901 - tn: 92566.2308 - fn: 34.9341 - accuracy: 0.9849 - precision: 0.0887 - recall: 0.7605 - auc: 0.9121 - val_loss: 0.0301 - val_tp: 64.0000 - val_fp: 306.0000 - val_tn: 45188.0000 - val_fn: 11.0000 - val_accuracy: 0.9930 - val_precision: 0.1730 - val_recall: 0.8533 - val_auc: 0.9594
Epoch 6/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3400 - tp: 131.6923 - fp: 1650.0549 - tn: 92328.0000 - fn: 30.8242 - accuracy: 0.9826 - precision: 0.0741 - recall: 0.8159 - auc: 0.9563 - val_loss: 0.0390 - val_tp: 64.0000 - val_fp: 408.0000 - val_tn: 45086.0000 - val_fn: 11.0000 - val_accuracy: 0.9908 - val_precision: 0.1356 - val_recall: 0.8533 - val_auc: 0.9673
Epoch 7/100
90/90 [==============================] - 1s 8ms/step - loss: 0.4131 - tp: 142.9231 - fp: 2112.7802 - tn: 91857.1319 - fn: 27.7363 - accuracy: 0.9776 - precision: 0.0647 - recall: 0.8312 - auc: 0.9405 - val_loss: 0.0476 - val_tp: 64.0000 - val_fp: 542.0000 - val_tn: 44952.0000 - val_fn: 11.0000 - val_accuracy: 0.9879 - val_precision: 0.1056 - val_recall: 0.8533 - val_auc: 0.9684
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3310 - tp: 150.6374 - fp: 2337.9780 - tn: 91629.2198 - fn: 22.7363 - accuracy: 0.9750 - precision: 0.0643 - recall: 0.8748 - auc: 0.9549 - val_loss: 0.0547 - val_tp: 64.0000 - val_fp: 629.0000 - val_tn: 44865.0000 - val_fn: 11.0000 - val_accuracy: 0.9860 - val_precision: 0.0924 - val_recall: 0.8533 - val_auc: 0.9671
Epoch 9/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2765 - tp: 145.7582 - fp: 2451.6154 - tn: 91526.0330 - fn: 17.1648 - accuracy: 0.9737 - precision: 0.0554 - recall: 0.8918 - auc: 0.9560 - val_loss: 0.0591 - val_tp: 64.0000 - val_fp: 679.0000 - val_tn: 44815.0000 - val_fn: 11.0000 - val_accuracy: 0.9849 - val_precision: 0.0861 - val_recall: 0.8533 - val_auc: 0.9673
Epoch 10/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2848 - tp: 149.7802 - fp: 2739.0220 - tn: 91229.2418 - fn: 22.5275 - accuracy: 0.9701 - precision: 0.0531 - recall: 0.8746 - auc: 0.9605 - val_loss: 0.0681 - val_tp: 65.0000 - val_fp: 751.0000 - val_tn: 44743.0000 - val_fn: 10.0000 - val_accuracy: 0.9833 - val_precision: 0.0797 - val_recall: 0.8667 - val_auc: 0.9686
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2472 - tp: 143.2637 - fp: 2914.8571 - tn: 91063.0989 - fn: 19.3516 - accuracy: 0.9688 - precision: 0.0451 - recall: 0.8980 - auc: 0.9667 - val_loss: 0.0778 - val_tp: 65.0000 - val_fp: 860.0000 - val_tn: 44634.0000 - val_fn: 10.0000 - val_accuracy: 0.9809 - val_precision: 0.0703 - val_recall: 0.8667 - val_auc: 0.9719
Epoch 12/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2669 - tp: 145.6813 - fp: 3247.7033 - tn: 90726.3187 - fn: 20.8681 - accuracy: 0.9658 - precision: 0.0433 - recall: 0.8868 - auc: 0.9611 - val_loss: 0.0849 - val_tp: 67.0000 - val_fp: 929.0000 - val_tn: 44565.0000 - val_fn: 8.0000 - val_accuracy: 0.9794 - val_precision: 0.0673 - val_recall: 0.8933 - val_auc: 0.9716
Epoch 13/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2545 - tp: 146.3626 - fp: 3248.9780 - tn: 90727.5824 - fn: 17.6484 - accuracy: 0.9652 - precision: 0.0424 - recall: 0.8949 - auc: 0.9621 - val_loss: 0.0858 - val_tp: 67.0000 - val_fp: 937.0000 - val_tn: 44557.0000 - val_fn: 8.0000 - val_accuracy: 0.9793 - val_precision: 0.0667 - val_recall: 0.8933 - val_auc: 0.9726
Epoch 14/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2694 - tp: 152.2857 - fp: 3368.4286 - tn: 90601.1099 - fn: 18.7473 - accuracy: 0.9641 - precision: 0.0422 - recall: 0.8713 - auc: 0.9678 - val_loss: 0.0855 - val_tp: 67.0000 - val_fp: 939.0000 - val_tn: 44555.0000 - val_fn: 8.0000 - val_accuracy: 0.9792 - val_precision: 0.0666 - val_recall: 0.8933 - val_auc: 0.9728
Epoch 15/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3239 - tp: 149.8571 - fp: 3446.2747 - tn: 90519.8132 - fn: 24.6264 - accuracy: 0.9632 - precision: 0.0438 - recall: 0.8489 - auc: 0.9602 - val_loss: 0.0902 - val_tp: 67.0000 - val_fp: 977.0000 - val_tn: 44517.0000 - val_fn: 8.0000 - val_accuracy: 0.9784 - val_precision: 0.0642 - val_recall: 0.8933 - val_auc: 0.9745
Epoch 16/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1902 - tp: 149.7033 - fp: 3309.5934 - tn: 90665.5495 - fn: 15.7253 - accuracy: 0.9645 - precision: 0.0416 - recall: 0.9105 - auc: 0.9638 - val_loss: 0.0893 - val_tp: 67.0000 - val_fp: 964.0000 - val_tn: 44530.0000 - val_fn: 8.0000 - val_accuracy: 0.9787 - val_precision: 0.0650 - val_recall: 0.8933 - val_auc: 0.9782
Epoch 17/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2527 - tp: 152.0769 - fp: 3314.5714 - tn: 90657.9560 - fn: 15.9670 - accuracy: 0.9639 - precision: 0.0436 - recall: 0.9053 - auc: 0.9628 - val_loss: 0.0844 - val_tp: 67.0000 - val_fp: 918.0000 - val_tn: 44576.0000 - val_fn: 8.0000 - val_accuracy: 0.9797 - val_precision: 0.0680 - val_recall: 0.8933 - val_auc: 0.9759
Epoch 18/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2879 - tp: 142.3407 - fp: 3300.9890 - tn: 90676.6923 - fn: 20.5495 - accuracy: 0.9647 - precision: 0.0407 - recall: 0.8608 - auc: 0.9590 - val_loss: 0.0904 - val_tp: 68.0000 - val_fp: 976.0000 - val_tn: 44518.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0651 - val_recall: 0.9067 - val_auc: 0.9760
Epoch 19/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3168 - tp: 149.8681 - fp: 3445.2308 - tn: 90525.9780 - fn: 19.4945 - accuracy: 0.9633 - precision: 0.0402 - recall: 0.8654 - auc: 0.9502 - val_loss: 0.0869 - val_tp: 67.0000 - val_fp: 926.0000 - val_tn: 44568.0000 - val_fn: 8.0000 - val_accuracy: 0.9795 - val_precision: 0.0675 - val_recall: 0.8933 - val_auc: 0.9763
Epoch 20/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1946 - tp: 152.5604 - fp: 3088.7253 - tn: 90884.3297 - fn: 14.9560 - accuracy: 0.9668 - precision: 0.0451 - recall: 0.9082 - auc: 0.9733 - val_loss: 0.0790 - val_tp: 67.0000 - val_fp: 837.0000 - val_tn: 44657.0000 - val_fn: 8.0000 - val_accuracy: 0.9815 - val_precision: 0.0741 - val_recall: 0.8933 - val_auc: 0.9767
Epoch 21/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1831 - tp: 165.5275 - fp: 3028.0659 - tn: 90932.2308 - fn: 14.7473 - accuracy: 0.9677 - precision: 0.0534 - recall: 0.9309 - auc: 0.9819 - val_loss: 0.0817 - val_tp: 67.0000 - val_fp: 864.0000 - val_tn: 44630.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0720 - val_recall: 0.8933 - val_auc: 0.9805
Epoch 22/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1915 - tp: 152.4835 - fp: 2960.4286 - tn: 91013.3407 - fn: 14.3187 - accuracy: 0.9684 - precision: 0.0490 - recall: 0.9172 - auc: 0.9748 - val_loss: 0.0799 - val_tp: 67.0000 - val_fp: 843.0000 - val_tn: 44651.0000 - val_fn: 8.0000 - val_accuracy: 0.9813 - val_precision: 0.0736 - val_recall: 0.8933 - val_auc: 0.9808
Epoch 23/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2210 - tp: 153.8791 - fp: 3078.8132 - tn: 90889.9670 - fn: 17.9121 - accuracy: 0.9667 - precision: 0.0474 - recall: 0.8988 - auc: 0.9721 - val_loss: 0.0806 - val_tp: 67.0000 - val_fp: 851.0000 - val_tn: 44643.0000 - val_fn: 8.0000 - val_accuracy: 0.9811 - val_precision: 0.0730 - val_recall: 0.8933 - val_auc: 0.9807
Epoch 24/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1972 - tp: 156.9780 - fp: 2981.0549 - tn: 90989.5934 - fn: 12.9451 - accuracy: 0.9679 - precision: 0.0528 - recall: 0.9299 - auc: 0.9742 - val_loss: 0.0762 - val_tp: 67.0000 - val_fp: 813.0000 - val_tn: 44681.0000 - val_fn: 8.0000 - val_accuracy: 0.9820 - val_precision: 0.0761 - val_recall: 0.8933 - val_auc: 0.9816
Epoch 25/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2061 - tp: 157.1758 - fp: 2973.5385 - tn: 90994.4725 - fn: 15.3846 - accuracy: 0.9685 - precision: 0.0513 - recall: 0.9128 - auc: 0.9749 - val_loss: 0.0785 - val_tp: 67.0000 - val_fp: 822.0000 - val_tn: 44672.0000 - val_fn: 8.0000 - val_accuracy: 0.9818 - val_precision: 0.0754 - val_recall: 0.8933 - val_auc: 0.9786
Epoch 26/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1855 - tp: 155.5275 - fp: 3008.1868 - tn: 90960.3516 - fn: 16.5055 - accuracy: 0.9680 - precision: 0.0496 - recall: 0.9169 - auc: 0.9806 - val_loss: 0.0830 - val_tp: 67.0000 - val_fp: 864.0000 - val_tn: 44630.0000 - val_fn: 8.0000 - val_accuracy: 0.9809 - val_precision: 0.0720 - val_recall: 0.8933 - val_auc: 0.9788
Epoch 27/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1958 - tp: 142.0220 - fp: 3216.1429 - tn: 90764.2088 - fn: 18.1978 - accuracy: 0.9657 - precision: 0.0414 - recall: 0.8806 - auc: 0.9791 - val_loss: 0.0842 - val_tp: 67.0000 - val_fp: 869.0000 - val_tn: 44625.0000 - val_fn: 8.0000 - val_accuracy: 0.9808 - val_precision: 0.0716 - val_recall: 0.8933 - val_auc: 0.9787
Epoch 28/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2145 - tp: 157.5495 - fp: 3212.0110 - tn: 90753.5055 - fn: 17.5055 - accuracy: 0.9655 - precision: 0.0486 - recall: 0.9063 - auc: 0.9764 - val_loss: 0.0870 - val_tp: 67.0000 - val_fp: 900.0000 - val_tn: 44594.0000 - val_fn: 8.0000 - val_accuracy: 0.9801 - val_precision: 0.0693 - val_recall: 0.8933 - val_auc: 0.9786
Epoch 29/100
90/90 [==============================] - 1s 10ms/step - loss: 0.1332 - tp: 152.7033 - fp: 2984.3407 - tn: 90994.1758 - fn: 9.3516 - accuracy: 0.9679 - precision: 0.0477 - recall: 0.9572 - auc: 0.9846 - val_loss: 0.0779 - val_tp: 67.0000 - val_fp: 806.0000 - val_tn: 44688.0000 - val_fn: 8.0000 - val_accuracy: 0.9821 - val_precision: 0.0767 - val_recall: 0.8933 - val_auc: 0.9790
Epoch 30/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2137 - tp: 162.0989 - fp: 2811.2418 - tn: 91154.1758 - fn: 13.0549 - accuracy: 0.9701 - precision: 0.0569 - recall: 0.9310 - auc: 0.9706 - val_loss: 0.0750 - val_tp: 67.0000 - val_fp: 778.0000 - val_tn: 44716.0000 - val_fn: 8.0000 - val_accuracy: 0.9828 - val_precision: 0.0793 - val_recall: 0.8933 - val_auc: 0.9790
Epoch 31/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1901 - tp: 151.2198 - fp: 2760.6813 - tn: 91211.5055 - fn: 17.1648 - accuracy: 0.9707 - precision: 0.0501 - recall: 0.9067 - auc: 0.9759 - val_loss: 0.0760 - val_tp: 67.0000 - val_fp: 774.0000 - val_tn: 44720.0000 - val_fn: 8.0000 - val_accuracy: 0.9828 - val_precision: 0.0797 - val_recall: 0.8933 - val_auc: 0.9790
Epoch 32/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1391 - tp: 151.6154 - fp: 2790.6703 - tn: 91185.2857 - fn: 13.0000 - accuracy: 0.9702 - precision: 0.0503 - recall: 0.9317 - auc: 0.9878 - val_loss: 0.0794 - val_tp: 67.0000 - val_fp: 811.0000 - val_tn: 44683.0000 - val_fn: 8.0000 - val_accuracy: 0.9820 - val_precision: 0.0763 - val_recall: 0.8933 - val_auc: 0.9793
Epoch 33/100
90/90 [==============================] - 1s 8ms/step - loss: 0.1739 - tp: 156.7473 - fp: 2926.8901 - tn: 91046.4066 - fn: 10.5275 - accuracy: 0.9689 - precision: 0.0515 - recall: 0.9423 - auc: 0.9798 - val_loss: 0.0772 - val_tp: 67.0000 - val_fp: 789.0000 - val_tn: 44705.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0783 - val_recall: 0.8933 - val_auc: 0.9792
Epoch 34/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2334 - tp: 152.7253 - fp: 2836.5055 - tn: 91137.6264 - fn: 13.7143 - accuracy: 0.9693 - precision: 0.0503 - recall: 0.9034 - auc: 0.9683 - val_loss: 0.0783 - val_tp: 67.0000 - val_fp: 804.0000 - val_tn: 44690.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0769 - val_recall: 0.8933 - val_auc: 0.9792
Restoring model weights from the end of the best epoch.
Epoch 00034: early stopping

학습 이력 조회

plot_metrics(weighted_history)

png

매트릭 평가

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.07444514334201813
tp :  73.0
fp :  1005.0
tn :  55870.0
fn :  14.0
accuracy :  0.982110857963562
precision :  0.06771799921989441
recall :  0.8390804529190063
auc :  0.9624707102775574

Legitimate Transactions Detected (True Negatives):  55870
Legitimate Transactions Incorrectly Detected (False Positives):  1005
Fraudulent Transactions Missed (False Negatives):  14
Fraudulent Transactions Detected (True Positives):  73
Total Fraudulent Transactions:  87

png

여기서 클래스 가중치를 사용하면 거짓 긍정이 더 많기 때문에 정확도와 정밀도가 낮다는 것을 알 수 있지만, 반대로 리콜과 AUC는 참 긍정이 더 많은 모델입니다. 정확도가 낮음에도 불구하고 이 모델은 리콜이 더 높습니다.(그리고 더 많은 부정 거래를 식별한다.) 물론 두 가지 유형의 오류에는 모두 비용이 발생합니다.(너무 많은 합법적인 거래를 사기로 표시하여 사용자를 괴롭히는 것을 원하지 않을 것입니다.) 응용 프로그램에 대하여 이러한 다양한 유형의 오류 간의 절충을 신중하게 고려하십시오.

ROC 플로팅

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7ff3703d37f0>

png

오버샘플링

소수 계급 과대 표본

관련된 접근 방식은 소수 클래스를 오버 샘플링 하여 데이터 세트를 리 샘플링 하는 것입니다.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

NumPy 사용

긍정적인 예에서 적절한 수의 임의 인덱스를 선택하여 데이터 세트의 균형을 수동으로 조정할 수 있습니다.:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181946, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363892, 29)

tf.data 사용

tf.data 사용하는 경우 균형 잡힌 예제를 생성하는 가장 쉬운 방법은 positive 그리고 negative 데이터 세트로 시작하여 병합하는 것입니다. 더 많은 예는 tf.data guide 를 참조하세요.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

각 데이터 세트는 (feature, label) 쌍을 제공합니다.:

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-1.71716559  1.47005786 -2.47738985  0.19498166 -1.69346885 -1.480288
 -2.50054314 -1.41973413  0.16941577 -4.57371112  3.57266458 -5.
 -1.09645923 -5.          1.56465042 -5.         -5.         -3.35739236
  1.0754794  -0.04924097  2.86749363 -0.70728706 -0.40225951  0.21121368
  0.19880498 -2.18727931 -3.03575199 -3.05769118  0.75924124]

Label:  1

experimental.sample_from_datasets 를 사용하여 두 가지를 병합합니다.:

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.48779296875

이 데이터 세트를 사용하려면 epoch 당 단계 수가 필요합니다.

이 경우 "epoch" 의 정의는 명확하지 않습니다. 각 부정적인 예를 한번 볼 때 필요한 배치 수라고 가정합니다.:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

오버 샘플링 된 데이터에 대한 학습

이제 클래스 가중치를 사용하는 대신 리 샘플링 된 데이터 세트로 모델을 학습하여 이러한 방법이 어떻게 비교되는지 확인하십시오.

참고: 긍정적인 예를 복제하여 데이터가 균형을 이루었기 때문에 총 데이터 세트 크기가 더 크고 각 세대가 더 많은 학습 단계를 위해 실행됩니다.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 9s 26ms/step - loss: 0.5210 - tp: 121561.5054 - fp: 41117.0430 - tn: 158732.2222 - fn: 22263.8889 - accuracy: 0.8122 - precision: 0.7057 - recall: 0.8153 - auc: 0.9032 - val_loss: 0.2060 - val_tp: 66.0000 - val_fp: 865.0000 - val_tn: 44629.0000 - val_fn: 9.0000 - val_accuracy: 0.9808 - val_precision: 0.0709 - val_recall: 0.8800 - val_auc: 0.9745
Epoch 2/100
278/278 [==============================] - 7s 24ms/step - loss: 0.1969 - tp: 130064.3047 - fp: 5756.9534 - tn: 137217.6487 - fn: 13673.7527 - accuracy: 0.9305 - precision: 0.9550 - recall: 0.9039 - auc: 0.9736 - val_loss: 0.1000 - val_tp: 66.0000 - val_fp: 703.0000 - val_tn: 44791.0000 - val_fn: 9.0000 - val_accuracy: 0.9844 - val_precision: 0.0858 - val_recall: 0.8800 - val_auc: 0.9791
Epoch 3/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1445 - tp: 131797.2867 - fp: 4093.8889 - tn: 139202.3082 - fn: 11619.1756 - accuracy: 0.9443 - precision: 0.9697 - recall: 0.9173 - auc: 0.9864 - val_loss: 0.0740 - val_tp: 67.0000 - val_fp: 706.0000 - val_tn: 44788.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0867 - val_recall: 0.8933 - val_auc: 0.9813
Epoch 4/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1201 - tp: 133345.2688 - fp: 3828.1900 - tn: 139604.8889 - fn: 9934.3118 - accuracy: 0.9518 - precision: 0.9721 - recall: 0.9300 - auc: 0.9911 - val_loss: 0.0624 - val_tp: 67.0000 - val_fp: 687.0000 - val_tn: 44807.0000 - val_fn: 8.0000 - val_accuracy: 0.9847 - val_precision: 0.0889 - val_recall: 0.8933 - val_auc: 0.9815
Epoch 5/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1074 - tp: 134843.6452 - fp: 3568.9821 - tn: 139417.5663 - fn: 8882.4659 - accuracy: 0.9563 - precision: 0.9741 - recall: 0.9377 - auc: 0.9933 - val_loss: 0.0561 - val_tp: 67.0000 - val_fp: 705.0000 - val_tn: 44789.0000 - val_fn: 8.0000 - val_accuracy: 0.9844 - val_precision: 0.0868 - val_recall: 0.8933 - val_auc: 0.9829
Epoch 6/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0984 - tp: 135411.1613 - fp: 3495.1183 - tn: 139764.0753 - fn: 8042.3047 - accuracy: 0.9596 - precision: 0.9746 - recall: 0.9438 - auc: 0.9944 - val_loss: 0.0499 - val_tp: 67.0000 - val_fp: 660.0000 - val_tn: 44834.0000 - val_fn: 8.0000 - val_accuracy: 0.9853 - val_precision: 0.0922 - val_recall: 0.8933 - val_auc: 0.9831
Epoch 7/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0917 - tp: 136129.1254 - fp: 3363.4695 - tn: 139901.5090 - fn: 7318.5556 - accuracy: 0.9624 - precision: 0.9758 - recall: 0.9485 - auc: 0.9952 - val_loss: 0.0457 - val_tp: 67.0000 - val_fp: 635.0000 - val_tn: 44859.0000 - val_fn: 8.0000 - val_accuracy: 0.9859 - val_precision: 0.0954 - val_recall: 0.8933 - val_auc: 0.9792
Epoch 8/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0868 - tp: 136650.6022 - fp: 3340.2581 - tn: 139912.6738 - fn: 6809.1254 - accuracy: 0.9644 - precision: 0.9762 - recall: 0.9520 - auc: 0.9959 - val_loss: 0.0417 - val_tp: 68.0000 - val_fp: 602.0000 - val_tn: 44892.0000 - val_fn: 7.0000 - val_accuracy: 0.9866 - val_precision: 0.1015 - val_recall: 0.9067 - val_auc: 0.9799
Epoch 9/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0819 - tp: 137372.2975 - fp: 3316.8996 - tn: 139692.0573 - fn: 6331.4050 - accuracy: 0.9662 - precision: 0.9764 - recall: 0.9557 - auc: 0.9964 - val_loss: 0.0394 - val_tp: 68.0000 - val_fp: 589.0000 - val_tn: 44905.0000 - val_fn: 7.0000 - val_accuracy: 0.9869 - val_precision: 0.1035 - val_recall: 0.9067 - val_auc: 0.9802
Epoch 10/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0768 - tp: 137785.7240 - fp: 3303.0108 - tn: 139765.3907 - fn: 5858.5341 - accuracy: 0.9681 - precision: 0.9766 - recall: 0.9592 - auc: 0.9968 - val_loss: 0.0371 - val_tp: 68.0000 - val_fp: 575.0000 - val_tn: 44919.0000 - val_fn: 7.0000 - val_accuracy: 0.9872 - val_precision: 0.1058 - val_recall: 0.9067 - val_auc: 0.9805
Epoch 11/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0742 - tp: 137693.1935 - fp: 3364.8495 - tn: 140214.9928 - fn: 5439.6237 - accuracy: 0.9692 - precision: 0.9762 - recall: 0.9618 - auc: 0.9970 - val_loss: 0.0351 - val_tp: 68.0000 - val_fp: 548.0000 - val_tn: 44946.0000 - val_fn: 7.0000 - val_accuracy: 0.9878 - val_precision: 0.1104 - val_recall: 0.9067 - val_auc: 0.9806
Epoch 12/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0711 - tp: 138494.5376 - fp: 3374.0323 - tn: 139773.1541 - fn: 5070.9355 - accuracy: 0.9704 - precision: 0.9762 - recall: 0.9645 - auc: 0.9972 - val_loss: 0.0325 - val_tp: 68.0000 - val_fp: 508.0000 - val_tn: 44986.0000 - val_fn: 7.0000 - val_accuracy: 0.9887 - val_precision: 0.1181 - val_recall: 0.9067 - val_auc: 0.9808
Epoch 13/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0679 - tp: 138309.3513 - fp: 3278.0323 - tn: 140222.7993 - fn: 4902.4767 - accuracy: 0.9714 - precision: 0.9768 - recall: 0.9655 - auc: 0.9974 - val_loss: 0.0311 - val_tp: 68.0000 - val_fp: 504.0000 - val_tn: 44990.0000 - val_fn: 7.0000 - val_accuracy: 0.9888 - val_precision: 0.1189 - val_recall: 0.9067 - val_auc: 0.9808
Epoch 14/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0666 - tp: 138519.0502 - fp: 3255.0502 - tn: 140034.3369 - fn: 4904.2222 - accuracy: 0.9713 - precision: 0.9769 - recall: 0.9656 - auc: 0.9974 - val_loss: 0.0301 - val_tp: 67.0000 - val_fp: 490.0000 - val_tn: 45004.0000 - val_fn: 8.0000 - val_accuracy: 0.9891 - val_precision: 0.1203 - val_recall: 0.8933 - val_auc: 0.9808
Epoch 15/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0633 - tp: 139042.3333 - fp: 3276.7455 - tn: 139769.9176 - fn: 4623.6631 - accuracy: 0.9723 - precision: 0.9769 - recall: 0.9676 - auc: 0.9977 - val_loss: 0.0284 - val_tp: 67.0000 - val_fp: 468.0000 - val_tn: 45026.0000 - val_fn: 8.0000 - val_accuracy: 0.9896 - val_precision: 0.1252 - val_recall: 0.8933 - val_auc: 0.9812
Epoch 16/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0625 - tp: 139156.4480 - fp: 3319.3943 - tn: 139969.8244 - fn: 4266.9928 - accuracy: 0.9735 - precision: 0.9766 - recall: 0.9703 - auc: 0.9977 - val_loss: 0.0266 - val_tp: 67.0000 - val_fp: 455.0000 - val_tn: 45039.0000 - val_fn: 8.0000 - val_accuracy: 0.9898 - val_precision: 0.1284 - val_recall: 0.8933 - val_auc: 0.9751
Restoring model weights from the end of the best epoch.
Epoch 00016: early stopping

만약 훈련 프로세스가 각 기울기 업데이트에서 전체 데이터 세트를 고려하는 경우, 이 오버 샘플링은 기본적으로 클래스 가중치와 동일합니다.

그러나 여기에서 한 것처럼 모델을 배치 방식으로 훈련 할 때 오버 샘플링 된 데이터는 더 부드러운 기울기 신호를 제공합니다. 각각의 긍정적인 예가 큰 가중치를 가진 하나의 배치로 표시되는 대신, 그것들은 작은 가중치로 매 회 많은 다른 배치로 보여집니다.

이 부드러운 기울기 신호는 모델을 더 쉽게 훈련 할 수 있습니다.

교육 이력 확인

학습 데이터의 분포가 검증 및 테스트 데이터와 완전히 다르기 때문에 여기서 측정 항목의 분포가 다를 수 있습니다.

plot_metrics(resampled_history)

png

재교육

균형 잡힌 데이터에 대한 훈련이 더 쉽기 때문에 위의 훈련 절차가 빠르게 과적합 될 수 있습니다.

따라서 epochs를 분리하여 callbacks.EarlyStopping을 제공하십시오.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 64ms/step - loss: 0.8553 - tp: 7544.1429 - fp: 5543.8095 - tn: 51153.8571 - fn: 3757.6667 - accuracy: 0.8716 - precision: 0.5583 - recall: 0.6545 - auc: 0.9162 - val_loss: 0.6783 - val_tp: 70.0000 - val_fp: 18281.0000 - val_tn: 27213.0000 - val_fn: 5.0000 - val_accuracy: 0.5987 - val_precision: 0.0038 - val_recall: 0.9333 - val_auc: 0.9343
Epoch 2/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.5958 - tp: 8912.1905 - fp: 4871.4762 - tn: 6467.1905 - fn: 2179.6190 - accuracy: 0.6800 - precision: 0.6406 - recall: 0.7979 - auc: 0.8074 - val_loss: 0.6313 - val_tp: 71.0000 - val_fp: 15726.0000 - val_tn: 29768.0000 - val_fn: 4.0000 - val_accuracy: 0.6548 - val_precision: 0.0045 - val_recall: 0.9467 - val_auc: 0.9452
Epoch 3/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.5067 - tp: 9444.6667 - fp: 4316.9524 - tn: 6978.1905 - fn: 1690.6667 - accuracy: 0.7280 - precision: 0.6817 - recall: 0.8441 - auc: 0.8618 - val_loss: 0.5710 - val_tp: 70.0000 - val_fp: 12275.0000 - val_tn: 33219.0000 - val_fn: 5.0000 - val_accuracy: 0.7305 - val_precision: 0.0057 - val_recall: 0.9333 - val_auc: 0.9505
Epoch 4/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.4389 - tp: 9821.0000 - fp: 3629.6190 - tn: 7496.9048 - fn: 1482.9524 - accuracy: 0.7695 - precision: 0.7279 - recall: 0.8684 - auc: 0.8942 - val_loss: 0.5116 - val_tp: 69.0000 - val_fp: 8797.0000 - val_tn: 36697.0000 - val_fn: 6.0000 - val_accuracy: 0.8068 - val_precision: 0.0078 - val_recall: 0.9200 - val_auc: 0.9550
Epoch 5/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.4016 - tp: 9779.7143 - fp: 3008.5714 - tn: 8206.9524 - fn: 1435.2381 - accuracy: 0.7993 - precision: 0.7615 - recall: 0.8710 - auc: 0.9086 - val_loss: 0.4603 - val_tp: 69.0000 - val_fp: 6221.0000 - val_tn: 39273.0000 - val_fn: 6.0000 - val_accuracy: 0.8634 - val_precision: 0.0110 - val_recall: 0.9200 - val_auc: 0.9591
Epoch 6/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.3720 - tp: 9862.1905 - fp: 2487.6667 - tn: 8702.4286 - fn: 1378.1905 - accuracy: 0.8250 - precision: 0.7945 - recall: 0.8763 - auc: 0.9198 - val_loss: 0.4176 - val_tp: 69.0000 - val_fp: 4449.0000 - val_tn: 41045.0000 - val_fn: 6.0000 - val_accuracy: 0.9022 - val_precision: 0.0153 - val_recall: 0.9200 - val_auc: 0.9629
Epoch 7/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3434 - tp: 9884.2857 - fp: 2132.1429 - tn: 9100.8571 - fn: 1313.1905 - accuracy: 0.8445 - precision: 0.8204 - recall: 0.8818 - auc: 0.9310 - val_loss: 0.3800 - val_tp: 69.0000 - val_fp: 3168.0000 - val_tn: 42326.0000 - val_fn: 6.0000 - val_accuracy: 0.9303 - val_precision: 0.0213 - val_recall: 0.9200 - val_auc: 0.9657
Epoch 8/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3185 - tp: 9896.0952 - fp: 1687.2381 - tn: 9618.8095 - fn: 1228.3333 - accuracy: 0.8688 - precision: 0.8522 - recall: 0.8900 - auc: 0.9407 - val_loss: 0.3446 - val_tp: 69.0000 - val_fp: 2300.0000 - val_tn: 43194.0000 - val_fn: 6.0000 - val_accuracy: 0.9494 - val_precision: 0.0291 - val_recall: 0.9200 - val_auc: 0.9681
Epoch 9/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2999 - tp: 10051.6190 - fp: 1368.1905 - tn: 9791.8571 - fn: 1218.8095 - accuracy: 0.8839 - precision: 0.8790 - recall: 0.8916 - auc: 0.9447 - val_loss: 0.3133 - val_tp: 67.0000 - val_fp: 1682.0000 - val_tn: 43812.0000 - val_fn: 8.0000 - val_accuracy: 0.9629 - val_precision: 0.0383 - val_recall: 0.8933 - val_auc: 0.9705
Epoch 10/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2823 - tp: 10050.8571 - fp: 1147.0476 - tn: 10040.3810 - fn: 1192.1905 - accuracy: 0.8952 - precision: 0.8972 - recall: 0.8938 - auc: 0.9503 - val_loss: 0.2866 - val_tp: 67.0000 - val_fp: 1332.0000 - val_tn: 44162.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0479 - val_recall: 0.8933 - val_auc: 0.9720
Epoch 11/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2664 - tp: 10064.9048 - fp: 943.3333 - tn: 10248.6190 - fn: 1173.6190 - accuracy: 0.9048 - precision: 0.9120 - recall: 0.8965 - auc: 0.9546 - val_loss: 0.2625 - val_tp: 67.0000 - val_fp: 1139.0000 - val_tn: 44355.0000 - val_fn: 8.0000 - val_accuracy: 0.9748 - val_precision: 0.0556 - val_recall: 0.8933 - val_auc: 0.9728
Epoch 12/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.2483 - tp: 10024.4762 - fp: 803.2381 - tn: 10445.5714 - fn: 1157.1905 - accuracy: 0.9122 - precision: 0.9257 - recall: 0.8961 - auc: 0.9603 - val_loss: 0.2407 - val_tp: 67.0000 - val_fp: 1033.0000 - val_tn: 44461.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0609 - val_recall: 0.8933 - val_auc: 0.9733
Epoch 13/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2410 - tp: 10077.0000 - fp: 735.9524 - tn: 10467.8095 - fn: 1149.7143 - accuracy: 0.9155 - precision: 0.9310 - recall: 0.8977 - auc: 0.9620 - val_loss: 0.2206 - val_tp: 67.0000 - val_fp: 930.0000 - val_tn: 44564.0000 - val_fn: 8.0000 - val_accuracy: 0.9794 - val_precision: 0.0672 - val_recall: 0.8933 - val_auc: 0.9741
Epoch 14/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2272 - tp: 10106.8571 - fp: 612.7143 - tn: 10571.2857 - fn: 1139.6190 - accuracy: 0.9214 - precision: 0.9418 - recall: 0.8988 - auc: 0.9657 - val_loss: 0.2042 - val_tp: 66.0000 - val_fp: 862.0000 - val_tn: 44632.0000 - val_fn: 9.0000 - val_accuracy: 0.9809 - val_precision: 0.0711 - val_recall: 0.8800 - val_auc: 0.9746
Epoch 15/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2199 - tp: 10124.0952 - fp: 576.3810 - tn: 10610.2857 - fn: 1119.7143 - accuracy: 0.9245 - precision: 0.9456 - recall: 0.9009 - auc: 0.9679 - val_loss: 0.1896 - val_tp: 66.0000 - val_fp: 835.0000 - val_tn: 44659.0000 - val_fn: 9.0000 - val_accuracy: 0.9815 - val_precision: 0.0733 - val_recall: 0.8800 - val_auc: 0.9752
Epoch 16/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2083 - tp: 10153.4762 - fp: 542.4286 - tn: 10636.2381 - fn: 1098.3333 - accuracy: 0.9273 - precision: 0.9491 - recall: 0.9038 - auc: 0.9711 - val_loss: 0.1764 - val_tp: 66.0000 - val_fp: 802.0000 - val_tn: 44692.0000 - val_fn: 9.0000 - val_accuracy: 0.9822 - val_precision: 0.0760 - val_recall: 0.8800 - val_auc: 0.9752
Epoch 17/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2030 - tp: 10231.3333 - fp: 491.6667 - tn: 10618.0000 - fn: 1089.4762 - accuracy: 0.9290 - precision: 0.9535 - recall: 0.9039 - auc: 0.9722 - val_loss: 0.1652 - val_tp: 66.0000 - val_fp: 777.0000 - val_tn: 44717.0000 - val_fn: 9.0000 - val_accuracy: 0.9828 - val_precision: 0.0783 - val_recall: 0.8800 - val_auc: 0.9754
Epoch 18/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1973 - tp: 10162.3810 - fp: 468.5714 - tn: 10724.2381 - fn: 1075.2857 - accuracy: 0.9306 - precision: 0.9564 - recall: 0.9034 - auc: 0.9734 - val_loss: 0.1553 - val_tp: 66.0000 - val_fp: 763.0000 - val_tn: 44731.0000 - val_fn: 9.0000 - val_accuracy: 0.9831 - val_precision: 0.0796 - val_recall: 0.8800 - val_auc: 0.9761
Epoch 19/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1903 - tp: 10121.1905 - fp: 421.9048 - tn: 10824.8095 - fn: 1062.5714 - accuracy: 0.9331 - precision: 0.9602 - recall: 0.9037 - auc: 0.9752 - val_loss: 0.1456 - val_tp: 66.0000 - val_fp: 729.0000 - val_tn: 44765.0000 - val_fn: 9.0000 - val_accuracy: 0.9838 - val_precision: 0.0830 - val_recall: 0.8800 - val_auc: 0.9764
Epoch 20/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1877 - tp: 10224.0476 - fp: 388.5714 - tn: 10744.8571 - fn: 1073.0000 - accuracy: 0.9344 - precision: 0.9636 - recall: 0.9046 - auc: 0.9763 - val_loss: 0.1381 - val_tp: 66.0000 - val_fp: 728.0000 - val_tn: 44766.0000 - val_fn: 9.0000 - val_accuracy: 0.9838 - val_precision: 0.0831 - val_recall: 0.8800 - val_auc: 0.9767
Epoch 21/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1758 - tp: 10158.1429 - fp: 391.3810 - tn: 10873.1429 - fn: 1007.8095 - accuracy: 0.9379 - precision: 0.9633 - recall: 0.9097 - auc: 0.9789 - val_loss: 0.1309 - val_tp: 66.0000 - val_fp: 727.0000 - val_tn: 44767.0000 - val_fn: 9.0000 - val_accuracy: 0.9838 - val_precision: 0.0832 - val_recall: 0.8800 - val_auc: 0.9772
Epoch 22/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1725 - tp: 10208.1905 - fp: 398.1429 - tn: 10781.5238 - fn: 1042.6190 - accuracy: 0.9360 - precision: 0.9625 - recall: 0.9077 - auc: 0.9800 - val_loss: 0.1249 - val_tp: 66.0000 - val_fp: 730.0000 - val_tn: 44764.0000 - val_fn: 9.0000 - val_accuracy: 0.9838 - val_precision: 0.0829 - val_recall: 0.8800 - val_auc: 0.9774
Epoch 23/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1704 - tp: 10120.0952 - fp: 386.0476 - tn: 10890.6667 - fn: 1033.6667 - accuracy: 0.9364 - precision: 0.9632 - recall: 0.9067 - auc: 0.9802 - val_loss: 0.1197 - val_tp: 66.0000 - val_fp: 728.0000 - val_tn: 44766.0000 - val_fn: 9.0000 - val_accuracy: 0.9838 - val_precision: 0.0831 - val_recall: 0.8800 - val_auc: 0.9778
Epoch 24/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1665 - tp: 10213.6190 - fp: 371.2381 - tn: 10848.7143 - fn: 996.9048 - accuracy: 0.9389 - precision: 0.9656 - recall: 0.9102 - auc: 0.9816 - val_loss: 0.1151 - val_tp: 66.0000 - val_fp: 734.0000 - val_tn: 44760.0000 - val_fn: 9.0000 - val_accuracy: 0.9837 - val_precision: 0.0825 - val_recall: 0.8800 - val_auc: 0.9783
Epoch 25/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1636 - tp: 10314.7143 - fp: 370.1905 - tn: 10772.6667 - fn: 972.9048 - accuracy: 0.9401 - precision: 0.9662 - recall: 0.9130 - auc: 0.9820 - val_loss: 0.1102 - val_tp: 66.0000 - val_fp: 718.0000 - val_tn: 44776.0000 - val_fn: 9.0000 - val_accuracy: 0.9840 - val_precision: 0.0842 - val_recall: 0.8800 - val_auc: 0.9789
Epoch 26/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1578 - tp: 10331.5714 - fp: 373.0952 - tn: 10761.1905 - fn: 964.6190 - accuracy: 0.9402 - precision: 0.9651 - recall: 0.9143 - auc: 0.9838 - val_loss: 0.1060 - val_tp: 66.0000 - val_fp: 710.0000 - val_tn: 44784.0000 - val_fn: 9.0000 - val_accuracy: 0.9842 - val_precision: 0.0851 - val_recall: 0.8800 - val_auc: 0.9787
Epoch 27/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1549 - tp: 10324.9524 - fp: 341.7619 - tn: 10831.1905 - fn: 932.5714 - accuracy: 0.9431 - precision: 0.9677 - recall: 0.9171 - auc: 0.9840 - val_loss: 0.1025 - val_tp: 66.0000 - val_fp: 708.0000 - val_tn: 44786.0000 - val_fn: 9.0000 - val_accuracy: 0.9843 - val_precision: 0.0853 - val_recall: 0.8800 - val_auc: 0.9791
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1514 - tp: 10375.1429 - fp: 323.7143 - tn: 10845.5714 - fn: 886.0476 - accuracy: 0.9457 - precision: 0.9694 - recall: 0.9212 - auc: 0.9846 - val_loss: 0.0999 - val_tp: 66.0000 - val_fp: 712.0000 - val_tn: 44782.0000 - val_fn: 9.0000 - val_accuracy: 0.9842 - val_precision: 0.0848 - val_recall: 0.8800 - val_auc: 0.9791
Epoch 29/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1483 - tp: 10390.6667 - fp: 317.6190 - tn: 10799.4286 - fn: 922.7619 - accuracy: 0.9447 - precision: 0.9708 - recall: 0.9182 - auc: 0.9854 - val_loss: 0.0973 - val_tp: 66.0000 - val_fp: 718.0000 - val_tn: 44776.0000 - val_fn: 9.0000 - val_accuracy: 0.9840 - val_precision: 0.0842 - val_recall: 0.8800 - val_auc: 0.9794
Epoch 30/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1463 - tp: 10395.8095 - fp: 331.7143 - tn: 10816.5238 - fn: 886.4286 - accuracy: 0.9460 - precision: 0.9697 - recall: 0.9214 - auc: 0.9864 - val_loss: 0.0937 - val_tp: 66.0000 - val_fp: 714.0000 - val_tn: 44780.0000 - val_fn: 9.0000 - val_accuracy: 0.9841 - val_precision: 0.0846 - val_recall: 0.8800 - val_auc: 0.9797
Epoch 31/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1419 - tp: 10402.0952 - fp: 317.3810 - tn: 10827.2381 - fn: 883.7619 - accuracy: 0.9464 - precision: 0.9703 - recall: 0.9219 - auc: 0.9867 - val_loss: 0.0908 - val_tp: 66.0000 - val_fp: 709.0000 - val_tn: 44785.0000 - val_fn: 9.0000 - val_accuracy: 0.9842 - val_precision: 0.0852 - val_recall: 0.8800 - val_auc: 0.9802
Epoch 32/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1426 - tp: 10358.5238 - fp: 294.6190 - tn: 10879.2857 - fn: 898.0476 - accuracy: 0.9470 - precision: 0.9727 - recall: 0.9200 - auc: 0.9864 - val_loss: 0.0885 - val_tp: 66.0000 - val_fp: 701.0000 - val_tn: 44793.0000 - val_fn: 9.0000 - val_accuracy: 0.9844 - val_precision: 0.0860 - val_recall: 0.8800 - val_auc: 0.9805
Epoch 33/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1406 - tp: 10374.8095 - fp: 340.7143 - tn: 10850.9524 - fn: 864.0000 - accuracy: 0.9463 - precision: 0.9677 - recall: 0.9237 - auc: 0.9873 - val_loss: 0.0866 - val_tp: 67.0000 - val_fp: 706.0000 - val_tn: 44788.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0867 - val_recall: 0.8933 - val_auc: 0.9794
Epoch 34/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1361 - tp: 10266.7143 - fp: 300.0952 - tn: 11002.8095 - fn: 860.8571 - accuracy: 0.9489 - precision: 0.9727 - recall: 0.9227 - auc: 0.9878 - val_loss: 0.0847 - val_tp: 67.0000 - val_fp: 711.0000 - val_tn: 44783.0000 - val_fn: 8.0000 - val_accuracy: 0.9842 - val_precision: 0.0861 - val_recall: 0.8933 - val_auc: 0.9799
Epoch 35/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1353 - tp: 10441.9524 - fp: 314.0000 - tn: 10797.2857 - fn: 877.2381 - accuracy: 0.9466 - precision: 0.9708 - recall: 0.9221 - auc: 0.9881 - val_loss: 0.0835 - val_tp: 67.0000 - val_fp: 715.0000 - val_tn: 44779.0000 - val_fn: 8.0000 - val_accuracy: 0.9841 - val_precision: 0.0857 - val_recall: 0.8933 - val_auc: 0.9803
Epoch 36/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1361 - tp: 10322.4762 - fp: 298.2857 - tn: 10957.6667 - fn: 852.0476 - accuracy: 0.9490 - precision: 0.9723 - recall: 0.9240 - auc: 0.9880 - val_loss: 0.0820 - val_tp: 67.0000 - val_fp: 719.0000 - val_tn: 44775.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0852 - val_recall: 0.8933 - val_auc: 0.9804
Epoch 37/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1319 - tp: 10396.8571 - fp: 311.3810 - tn: 10876.4286 - fn: 845.8095 - accuracy: 0.9479 - precision: 0.9707 - recall: 0.9239 - auc: 0.9892 - val_loss: 0.0810 - val_tp: 67.0000 - val_fp: 725.0000 - val_tn: 44769.0000 - val_fn: 8.0000 - val_accuracy: 0.9839 - val_precision: 0.0846 - val_recall: 0.8933 - val_auc: 0.9808
Epoch 38/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1286 - tp: 10339.0000 - fp: 301.0476 - tn: 10960.5714 - fn: 829.8571 - accuracy: 0.9498 - precision: 0.9723 - recall: 0.9259 - auc: 0.9897 - val_loss: 0.0793 - val_tp: 67.0000 - val_fp: 722.0000 - val_tn: 44772.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0849 - val_recall: 0.8933 - val_auc: 0.9808
Epoch 39/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1288 - tp: 10322.3333 - fp: 309.8095 - tn: 10955.0476 - fn: 843.2857 - accuracy: 0.9487 - precision: 0.9713 - recall: 0.9241 - auc: 0.9897 - val_loss: 0.0781 - val_tp: 67.0000 - val_fp: 724.0000 - val_tn: 44770.0000 - val_fn: 8.0000 - val_accuracy: 0.9839 - val_precision: 0.0847 - val_recall: 0.8933 - val_auc: 0.9808
Epoch 40/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1287 - tp: 10419.9048 - fp: 307.4286 - tn: 10876.5714 - fn: 826.5714 - accuracy: 0.9492 - precision: 0.9709 - recall: 0.9268 - auc: 0.9893 - val_loss: 0.0769 - val_tp: 67.0000 - val_fp: 724.0000 - val_tn: 44770.0000 - val_fn: 8.0000 - val_accuracy: 0.9839 - val_precision: 0.0847 - val_recall: 0.8933 - val_auc: 0.9812
Epoch 41/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1271 - tp: 10389.7619 - fp: 306.0000 - tn: 10920.8095 - fn: 813.9048 - accuracy: 0.9504 - precision: 0.9719 - recall: 0.9272 - auc: 0.9901 - val_loss: 0.0753 - val_tp: 67.0000 - val_fp: 717.0000 - val_tn: 44777.0000 - val_fn: 8.0000 - val_accuracy: 0.9841 - val_precision: 0.0855 - val_recall: 0.8933 - val_auc: 0.9816
Epoch 42/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1314 - tp: 10331.9048 - fp: 325.7143 - tn: 10955.4762 - fn: 817.3810 - accuracy: 0.9484 - precision: 0.9692 - recall: 0.9256 - auc: 0.9892 - val_loss: 0.0729 - val_tp: 67.0000 - val_fp: 691.0000 - val_tn: 44803.0000 - val_fn: 8.0000 - val_accuracy: 0.9847 - val_precision: 0.0884 - val_recall: 0.8933 - val_auc: 0.9815
Epoch 43/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1238 - tp: 10418.0476 - fp: 275.6190 - tn: 10904.6667 - fn: 832.1429 - accuracy: 0.9506 - precision: 0.9746 - recall: 0.9254 - auc: 0.9904 - val_loss: 0.0719 - val_tp: 67.0000 - val_fp: 691.0000 - val_tn: 44803.0000 - val_fn: 8.0000 - val_accuracy: 0.9847 - val_precision: 0.0884 - val_recall: 0.8933 - val_auc: 0.9819
Epoch 44/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1202 - tp: 10372.1905 - fp: 265.4286 - tn: 10982.9524 - fn: 809.9048 - accuracy: 0.9518 - precision: 0.9754 - recall: 0.9264 - auc: 0.9912 - val_loss: 0.0711 - val_tp: 67.0000 - val_fp: 696.0000 - val_tn: 44798.0000 - val_fn: 8.0000 - val_accuracy: 0.9846 - val_precision: 0.0878 - val_recall: 0.8933 - val_auc: 0.9817
Epoch 45/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1243 - tp: 10359.1905 - fp: 332.1905 - tn: 10949.0952 - fn: 790.0000 - accuracy: 0.9499 - precision: 0.9692 - recall: 0.9291 - auc: 0.9904 - val_loss: 0.0700 - val_tp: 67.0000 - val_fp: 696.0000 - val_tn: 44798.0000 - val_fn: 8.0000 - val_accuracy: 0.9846 - val_precision: 0.0878 - val_recall: 0.8933 - val_auc: 0.9797
Epoch 46/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1208 - tp: 10439.7143 - fp: 299.7143 - tn: 10922.0000 - fn: 769.0476 - accuracy: 0.9525 - precision: 0.9728 - recall: 0.9311 - auc: 0.9909 - val_loss: 0.0696 - val_tp: 67.0000 - val_fp: 706.0000 - val_tn: 44788.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0867 - val_recall: 0.8933 - val_auc: 0.9798
Epoch 47/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1204 - tp: 10425.1905 - fp: 291.2381 - tn: 10959.9524 - fn: 754.0952 - accuracy: 0.9531 - precision: 0.9725 - recall: 0.9325 - auc: 0.9912 - val_loss: 0.0681 - val_tp: 67.0000 - val_fp: 683.0000 - val_tn: 44811.0000 - val_fn: 8.0000 - val_accuracy: 0.9848 - val_precision: 0.0893 - val_recall: 0.8933 - val_auc: 0.9795
Epoch 48/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1150 - tp: 10465.5238 - fp: 267.5238 - tn: 10911.5238 - fn: 785.9048 - accuracy: 0.9526 - precision: 0.9751 - recall: 0.9297 - auc: 0.9919 - val_loss: 0.0681 - val_tp: 67.0000 - val_fp: 706.0000 - val_tn: 44788.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0867 - val_recall: 0.8933 - val_auc: 0.9796
Epoch 49/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1167 - tp: 10359.1429 - fp: 305.0000 - tn: 10998.6667 - fn: 767.6667 - accuracy: 0.9522 - precision: 0.9721 - recall: 0.9306 - auc: 0.9917 - val_loss: 0.0675 - val_tp: 67.0000 - val_fp: 708.0000 - val_tn: 44786.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0865 - val_recall: 0.8933 - val_auc: 0.9798
Epoch 50/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1151 - tp: 10490.0476 - fp: 304.0952 - tn: 10879.2857 - fn: 757.0476 - accuracy: 0.9529 - precision: 0.9717 - recall: 0.9330 - auc: 0.9921 - val_loss: 0.0667 - val_tp: 67.0000 - val_fp: 706.0000 - val_tn: 44788.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0867 - val_recall: 0.8933 - val_auc: 0.9800
Epoch 51/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1158 - tp: 10527.9048 - fp: 288.8095 - tn: 10840.4762 - fn: 773.2857 - accuracy: 0.9525 - precision: 0.9736 - recall: 0.9313 - auc: 0.9920 - val_loss: 0.0661 - val_tp: 67.0000 - val_fp: 701.0000 - val_tn: 44793.0000 - val_fn: 8.0000 - val_accuracy: 0.9844 - val_precision: 0.0872 - val_recall: 0.8933 - val_auc: 0.9802
Epoch 52/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1161 - tp: 10471.1429 - fp: 263.7143 - tn: 10940.3333 - fn: 755.2857 - accuracy: 0.9547 - precision: 0.9751 - recall: 0.9333 - auc: 0.9921 - val_loss: 0.0653 - val_tp: 67.0000 - val_fp: 696.0000 - val_tn: 44798.0000 - val_fn: 8.0000 - val_accuracy: 0.9846 - val_precision: 0.0878 - val_recall: 0.8933 - val_auc: 0.9803
Epoch 53/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1137 - tp: 10386.5714 - fp: 310.0000 - tn: 10969.0000 - fn: 764.9048 - accuracy: 0.9516 - precision: 0.9704 - recall: 0.9308 - auc: 0.9921 - val_loss: 0.0644 - val_tp: 67.0000 - val_fp: 686.0000 - val_tn: 44808.0000 - val_fn: 8.0000 - val_accuracy: 0.9848 - val_precision: 0.0890 - val_recall: 0.8933 - val_auc: 0.9805
Restoring model weights from the end of the best epoch.
Epoch 00053: early stopping

훈련 이력 재확인

plot_metrics(resampled_history)

png

메트릭 평가

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.0707845687866211
tp :  74.0
fp :  828.0
tn :  56047.0
fn :  13.0
accuracy :  0.9852357506752014
precision :  0.08203991502523422
recall :  0.8505747318267822
auc :  0.944765031337738

Legitimate Transactions Detected (True Negatives):  56047
Legitimate Transactions Incorrectly Detected (False Positives):  828
Fraudulent Transactions Missed (False Negatives):  13
Fraudulent Transactions Detected (True Positives):  74
Total Fraudulent Transactions:  87

png

ROC 플로팅

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7ff2a836b860>

png

튜토리얼을 이 문제에 적용

불균형 데이터 분류는 학습 할 샘플이 너무 적기 때문에 본질적으로 어려운 작업입니다. 항상 데이터부터 시작하여 가능한 한 많은 샘플을 수집하고 모델이 소수 클래스를 최대한 활용할 수 있도록 어떤 기능이 관련 될 수 있는지에 대해 실질적인 생각을 하도록 최선을 다해야 합니다. 어떤 시점에서 모델은 원하는 결과를 개선하고 산출하는데 어려움을 겪을 수 있으므로 문제의 컨텍스트와 다양한 유형의 오류 간의 균형을 염두에 두는 것이 중요합니다.