이 페이지는 Cloud Translation API를 통해 번역되었습니다.
Switch to English

불균형 데이터 분류

TensorFlow.org에서보기 Google Colab에서 실행 GitHub에서 소스보기 노트북 다운로드

이 튜토리얼에서는 한 클래스의 예제 수가 다른 클래스의 예제보다 훨씬 큰 불균형 데이터 세트를 분류하는 방법을 보여줍니다. Kaggle에서 호스팅되는 신용 카드 사기 탐지 데이터 세트로 작업합니다. 총 284,807 건의 거래에서 492 건의 사기 거래 만 탐지하는 것이 목표입니다. Keras 를 사용하여 모델과 클래스 가중치 를 정의하여 모델이 불균형 데이터에서 학습하는 데 도움을줍니다. .

이 튜토리얼에는 다음을위한 완전한 코드가 포함되어 있습니다.

  • Pandas를 사용하여 CSV 파일을로드하십시오.
  • 교육, 검증 및 테스트 세트를 작성하십시오.
  • Keras (클래스 가중치 설정 포함)를 사용하여 모델을 정의하고 학습하십시오.
  • 다양한 메트릭 (정밀도 및 리콜 포함)을 사용하여 모델을 평가하십시오.
  • 불균형 데이터를 처리하는 일반적인 기술을 다음과 같이 시도하십시오.
    • 클래스 가중치
    • 오버 샘플링

설정

 import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
 
 mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
 

데이터 처리 및 탐색

Kaggle Credit Card Fraud 데이터 세트 다운로드

Pandas는 구조화 된 데이터를로드하고 작업하는 데 유용한 유틸리티가 많은 Python 라이브러리이며 CSV를 데이터 프레임으로 다운로드하는 데 사용할 수 있습니다.

 file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
 
 raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
 

클래스 라벨 불균형 검사

데이터 세트 불균형을 살펴 봅시다 :

 neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
 
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


이것은 양성 샘플의 작은 부분을 보여줍니다.

데이터 정리, 분할 및 표준화

원시 데이터에는 몇 가지 문제가 있습니다. 먼저 TimeAmount 열이 너무 가변적 Amount 직접 사용할 수 없습니다. 의미 열이 없기 때문에 Time 열을 삭제하고 Amount 열의 로그를 사용하여 범위를 줄입니다.

 cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps=0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)
 

데이터 집합을 학습, 유효성 검사 및 테스트 집합으로 나눕니다. 검증 세트는 모델 피팅 중에 손실 및 모든 메트릭을 평가하는 데 사용되지만 모델이이 데이터에 적합하지 않습니다. 테스트 세트는 교육 단계에서 완전히 사용되지 않으며 모델이 새 데이터로 일반화되는 정도를 평가하기 위해 마지막에만 사용됩니다. 이것은 훈련 데이터가 부족하여 과적 합 이 중요한 관심사 인 불균형 데이터 세트에서 특히 중요합니다.

 # Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
 

sklearn StandardScaler를 사용하여 입력 기능을 정규화하십시오. 평균이 0으로 설정되고 표준 편차가 1로 설정됩니다.

 scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)

 
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

데이터 분포를보십시오

다음으로 몇 가지 기능에 대한 긍정적 인 예와 부정적인 예의 분포를 비교하십시오. 이 시점에서 스스로에게 좋은 질문은 다음과 같습니다.

  • 이 분포가 의미가 있습니까?
    • 예. 입력을 정규화했으며 이들은 대부분 +/- 2 범위에 집중되어 있습니다.
  • 분포의 차이를 볼 수 있습니까?
    • 그렇습니다. 긍정적 인 예에는 훨씬 더 높은 극단 값이 포함되어 있습니다.
 pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("Negative distribution")
 

png

png

모델 및 지표 정의

조밀하게 연결된 숨겨진 계층, 과적 합을 줄이기위한 드롭 아웃 계층 및 거래가 사기 될 확률을 반환하는 출력 시그 모이 드 계층으로 간단한 신경망을 만드는 함수를 정의하십시오.

 METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
 

유용한 지표 이해

위에 정의 된 몇 가지 메트릭이 있으며 모델을 통해 계산할 수 있으며 성능을 평가할 때 도움이됩니다.

  • 오탐오탐잘못 분류 된 표본입니다.
  • 진음진양올바르게 분류 된 표본입니다.
  • 정확도 는 올바르게 분류 된 예제의 비율입니다.> $ \ frac {\ text {true samples}} {\ text {total samples}} $
  • 정밀도 는 올바르게 분류 된 예측 된 양성의 백분율> $ \ frac {\ text {true positives}} {\ text {true positives + false positives}} $
  • 리콜 은 올바르게 분류 된 실제 긍정의 백분율> $ \ frac {\ text {true positives}} {\ text {true positives + false negatives}} $
  • AUC 는 ROC-AUC (수신기 동작 특성 곡선의 곡선 아래 면적)을 나타냅니다. 이 메트릭은 분류 기가 임의의 음수 샘플보다 임의의 양수 샘플의 순위를 매길 확률과 같습니다.

더 읽어보기 :

베이스 라인 모델

모델 구축

이제 이전에 정의 된 기능을 사용하여 모델을 작성하고 학습하십시오. 모형이 기본 배치 크기보다 큰 2048을 사용하여 적합하다는 점에 유의하십시오. 이는 각 배치에 양성 샘플이 몇 개 포함될 가능성이 있는지 확인하는 데 중요합니다. 배치 크기가 너무 작 으면 배울 사기 거래가 없을 수 있습니다.

 EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
 
 model = make_model()
model.summary()
 
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

모델을 테스트 실행하십시오.

 model.predict(train_features[:10])
 
array([[0.5788107 ],
       [0.44979692],
       [0.5427961 ],
       [0.5985188 ],
       [0.7758075 ],
       [0.3417888 ],
       [0.39359283],
       [0.5399953 ],
       [0.3551327 ],
       [0.47230086]], dtype=float32)

선택 사항 : 올바른 초기 바이어스를 설정하십시오.

이 초기 추측은 좋지 않습니다. 데이터 세트가 불균형하다는 것을 알고 있습니다. 출력 레이어의 바이어스를 설정하여이를 반영하십시오 ( 신경망 훈련을위한 레시피 : "init well"참조 ). 이것은 초기 수렴에 도움이 될 수 있습니다.

기본 바이어스 초기화로 손실은 math.log(2) = 0.69314

 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.7817

설정할 올바른 바이어스는 다음에서 파생 될 수 있습니다.

$$ p_0 = pos / (pos + neg) = 1 / (1 + e ^ {-b_0}) $$
$$ b_0 = -log_e (1 / p_0-1) $$
$$ b_0 = log_e (pos / neg) $$
 initial_bias = np.log([pos/neg])
initial_bias
 
array([-6.35935934])

이를 초기 편향으로 설정하면 모델이 훨씬 더 합리적인 초기 추측을 제공합니다.

가까운 거리에 있어야합니다 : pos/total = 0.0018

 model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
 
array([[0.00093563],
       [0.00187903],
       [0.00109238],
       [0.00117128],
       [0.00134988],
       [0.00090826],
       [0.00099455],
       [0.00154405],
       [0.00100204],
       [0.0004291 ]], dtype=float32)

이 초기화로 초기 손실은 대략 다음과 같아야합니다.

$$-p_0log (p_0)-(1-p_0) log (1-p_0) = 0.01317 $$
 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.0146

이 초기 손실은 순진한 초기화보다 약 50 배 적습니다.

이런 식으로 모델은 긍정적 인 예가 거의 없다는 것을 배우기 위해 처음 몇 시대를 소비 할 필요가 없습니다. 또한 훈련 중 손실의 플롯을보다 쉽게 ​​읽을 수 있습니다.

초기 무게 체크 포인트

다양한 트레이닝 실행을 비교할 수 있도록이 초기 모델의 가중치를 체크 포인트 파일에 보관하고 트레이닝 전에 각 모델에로드합니다.

 initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)
 

바이어스 수정이 도움이되는지 확인

계속 진행하기 전에 신중한 바이어스 초기화가 실제로 도움이되었는지 신속하게 확인하십시오.

이 신중한 초기화를 사용하거나 사용하지 않고 20 에포크 모델을 학습하고 손실을 비교하십시오.

 model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  
  plt.legend()
 
 plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
 

png

위의 그림은이를 명확하게 보여줍니다. 검증 손실 측면에서이 문제에 대해이 신중한 초기화는 명확한 이점을 제공합니다.

모델 훈련

 model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
 
Epoch 1/100
90/90 [==============================] - 1s 13ms/step - loss: 0.0112 - tp: 100.0000 - fp: 25.0000 - tn: 227419.0000 - fn: 301.0000 - accuracy: 0.9986 - precision: 0.8000 - recall: 0.2494 - auc: 0.7615 - val_loss: 0.0067 - val_tp: 15.0000 - val_fp: 2.0000 - val_tn: 45480.0000 - val_fn: 72.0000 - val_accuracy: 0.9984 - val_precision: 0.8824 - val_recall: 0.1724 - val_auc: 0.9077
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0075 - tp: 108.0000 - fp: 24.0000 - tn: 181938.0000 - fn: 206.0000 - accuracy: 0.9987 - precision: 0.8182 - recall: 0.3439 - auc: 0.8491 - val_loss: 0.0046 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8824 - val_recall: 0.5172 - val_auc: 0.9308
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0065 - tp: 138.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 176.0000 - accuracy: 0.9989 - precision: 0.8364 - recall: 0.4395 - auc: 0.8567 - val_loss: 0.0040 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 33.0000 - val_accuracy: 0.9991 - val_precision: 0.8852 - val_recall: 0.6207 - val_auc: 0.9365
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0060 - tp: 154.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 160.0000 - accuracy: 0.9989 - precision: 0.8235 - recall: 0.4904 - auc: 0.8848 - val_loss: 0.0037 - val_tp: 61.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8841 - val_recall: 0.7011 - val_auc: 0.9422
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 157.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 157.0000 - accuracy: 0.9989 - precision: 0.8135 - recall: 0.5000 - auc: 0.8982 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8857 - val_recall: 0.7126 - val_auc: 0.9422
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 152.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8261 - recall: 0.4841 - auc: 0.8934 - val_loss: 0.0033 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0052 - tp: 174.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 140.0000 - accuracy: 0.9991 - precision: 0.8529 - recall: 0.5541 - auc: 0.8983 - val_loss: 0.0032 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 161.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8342 - recall: 0.5127 - auc: 0.8983 - val_loss: 0.0031 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 167.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8186 - recall: 0.5318 - auc: 0.9064 - val_loss: 0.0030 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 156.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 158.0000 - accuracy: 0.9989 - precision: 0.8211 - recall: 0.4968 - auc: 0.9046 - val_loss: 0.0029 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9479
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 165.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8376 - recall: 0.5255 - auc: 0.9063 - val_loss: 0.0029 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 165.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8250 - recall: 0.5255 - auc: 0.9110 - val_loss: 0.0028 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9480
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 157.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5000 - auc: 0.9031 - val_loss: 0.0028 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 160.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 154.0000 - accuracy: 0.9990 - precision: 0.8205 - recall: 0.5096 - auc: 0.8934 - val_loss: 0.0027 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 168.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8235 - recall: 0.5350 - auc: 0.9031 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 169.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8492 - recall: 0.5382 - auc: 0.9143 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8498 - recall: 0.5764 - auc: 0.9144 - val_loss: 0.0027 - val_tp: 70.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8974 - val_recall: 0.8046 - val_auc: 0.9537
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8619 - recall: 0.5764 - auc: 0.9112 - val_loss: 0.0026 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 177.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 137.0000 - accuracy: 0.9991 - precision: 0.8349 - recall: 0.5637 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 176.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8462 - recall: 0.5605 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9167 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 163.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5191 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 183.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8281 - recall: 0.5828 - auc: 0.9113 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9041 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8400 - recall: 0.5350 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 179.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 135.0000 - accuracy: 0.9991 - precision: 0.8483 - recall: 0.5701 - auc: 0.9161 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 173.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 141.0000 - accuracy: 0.9990 - precision: 0.8199 - recall: 0.5510 - auc: 0.9208 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9081 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 181.0000 - fp: 39.0000 - tn: 181923.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8227 - recall: 0.5764 - auc: 0.9193 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 177.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8233 - recall: 0.5637 - auc: 0.9305 - val_loss: 0.0025 - val_tp: 67.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.9054 - val_recall: 0.7701 - val_auc: 0.9538
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 31.0000 - tn: 181931.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8442 - recall: 0.5350 - auc: 0.9161 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8309 - recall: 0.5478 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5987 - auc: 0.9162 - val_loss: 0.0025 - val_tp: 70.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9091 - val_recall: 0.8046 - val_auc: 0.9538
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 184.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8720 - recall: 0.5860 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8486 - recall: 0.5892 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 178.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8318 - recall: 0.5669 - auc: 0.9160 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 171.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 143.0000 - accuracy: 0.9990 - precision: 0.8382 - recall: 0.5446 - auc: 0.9192 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 189.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 125.0000 - accuracy: 0.9991 - precision: 0.8438 - recall: 0.6019 - auc: 0.9242 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 25.0000 - tn: 181937.0000 - fn: 129.0000 - accuracy: 0.9992 - precision: 0.8810 - recall: 0.5892 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 181.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8380 - recall: 0.5764 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8537 - recall: 0.5573 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 180.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8491 - recall: 0.5732 - auc: 0.9320 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8468 - recall: 0.5987 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 172.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8230 - recall: 0.5478 - auc: 0.9129 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8294 - recall: 0.5573 - auc: 0.9368 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9240 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 178.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8683 - recall: 0.5669 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 198.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8534 - recall: 0.6306 - auc: 0.9256 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 5.0000 - val_tn: 45477.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9315 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 49/100
85/90 [===========================>..] - ETA: 0s - loss: 0.0043 - tp: 162.0000 - fp: 29.0000 - tn: 173750.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8482 - recall: 0.5382 - auc: 0.9157Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 171.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5446 - auc: 0.9191 - val_loss: 0.0024 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 00049: early stopping

훈련 이력 확인

이 섹션에서는 학습 및 검증 세트에서 모델의 정확성 및 손실에 대한 도표를 생성합니다. 이 옵션은 과적 합을 확인하는 데 유용하며이 학습서 에서 자세히 배울 수 있습니다.

또한 위에서 만든 모든 메트릭에 대해 이러한 플롯을 생성 할 수 있습니다. 허위 네거티브가 예로 포함되어 있습니다.

 def plot_metrics(history):
  metrics =  ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()

 
 plot_metrics(baseline_history)
 

png

측정 항목 평가

혼동 행렬 을 사용하여 X 축이 예측 레이블이고 Y 축이 실제 레이블 인 실제 대 예측 레이블을 요약 할 수 있습니다.

 train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
 
 def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
 

테스트 데이터 집합에서 모델을 평가하고 위에서 만든 메트릭의 결과를 표시하십시오.

 baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
 
loss :  0.002310588490217924
tp :  69.0
fp :  5.0
tn :  56866.0
fn :  22.0
accuracy :  0.9995260238647461
precision :  0.9324324131011963
recall :  0.7582417726516724
auc :  0.9557874202728271

Legitimate Transactions Detected (True Negatives):  56866
Legitimate Transactions Incorrectly Detected (False Positives):  5
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  69
Total Fraudulent Transactions:  91

png

모형이 모든 것을 완벽하게 예측 한 경우 잘못된 예측을 나타내는 주 대각선의 값이 0이되는 대각 행렬 입니다. 이 경우 매트릭스는 오 탐지가 상대적으로 적다는 것을 보여줍니다. 이는 잘못 플래그 된 합법적 인 거래가 상대적으로 적다는 것을 의미합니다. 그러나 오 탐지 수를 늘리는 비용에도 불구하고 더 적은 오 탐지를 원할 수 있습니다. 이 부정은 부정 거래로 인해 사기 거래가 이루어질 수 있기 때문에 바람직 할 수 있지만, 부정 거래는 고객에게 카드 활동을 확인하도록 요청하는 이메일을 보낼 수 있습니다.

ROC 플롯

이제 ROC를 플로팅하십시오. 이 플롯은 출력 임계 값을 조정하여 모델이 도달 할 수있는 성능 범위를 한눈에 보여주기 때문에 유용합니다.

 def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
 
 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa50c5adef0>

png

정밀도가 비교적 높은 것처럼 보이지만 리콜 및 ROC 곡선 아래 면적 (AUC)은 원하는만큼 높지 않습니다. 분류기는 종종 정밀도와 리콜을 극대화하려고 할 때 어려움에 직면하며, 특히 불균형 데이터 세트로 작업 할 때 더욱 그렇습니다. 관심있는 문제의 맥락에서 다양한 유형의 오류 비용을 고려하는 것이 중요합니다. 이 예에서 허위 부정 (사기 거래가 누락 됨)에는 재정 비용이있을 수 있으며, 위양성 (거래가 사기로 잘못 표시됨)은 사용자의 행복을 감소시킬 수 있습니다.

클래스 가중치

클래스 가중치 계산

목표는 사기 거래를 식별하는 것이지만, 사용할 긍정적 인 샘플이 많지 않기 때문에 사용 가능한 몇 가지 예를 분류 자에게 크게 가중시키고 싶을 것입니다. 각 클래스의 Keras 가중치를 매개 변수를 통해 전달하면됩니다. 이로 인해 모델이 대표가 부족한 클래스의 예제에 "더 많은주의를 기울여야"합니다.

 # Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
 
Weight for class 0: 0.50
Weight for class 1: 289.44

수업 가중치로 모델 훈련

이제 클래스 가중치를 사용하여 모델을 재교육하고 평가하여 이것이 예측에 어떤 영향을 미치는지 확인하십시오.

 weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
 
Epoch 1/100
90/90 [==============================] - 1s 15ms/step - loss: 2.5149 - tp: 105.0000 - fp: 66.0000 - tn: 238767.0000 - fn: 300.0000 - accuracy: 0.9985 - precision: 0.6140 - recall: 0.2593 - auc: 0.7803 - val_loss: 0.0067 - val_tp: 25.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 62.0000 - val_accuracy: 0.9985 - val_precision: 0.8065 - val_recall: 0.2874 - val_auc: 0.9211
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 1.2482 - tp: 145.0000 - fp: 124.0000 - tn: 181838.0000 - fn: 169.0000 - accuracy: 0.9984 - precision: 0.5390 - recall: 0.4618 - auc: 0.8560 - val_loss: 0.0062 - val_tp: 68.0000 - val_fp: 12.0000 - val_tn: 45470.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8500 - val_recall: 0.7816 - val_auc: 0.9408
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.8972 - tp: 177.0000 - fp: 237.0000 - tn: 181725.0000 - fn: 137.0000 - accuracy: 0.9979 - precision: 0.4275 - recall: 0.5637 - auc: 0.8876 - val_loss: 0.0079 - val_tp: 73.0000 - val_fp: 16.0000 - val_tn: 45466.0000 - val_fn: 14.0000 - val_accuracy: 0.9993 - val_precision: 0.8202 - val_recall: 0.8391 - val_auc: 0.9518
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6983 - tp: 210.0000 - fp: 387.0000 - tn: 181575.0000 - fn: 104.0000 - accuracy: 0.9973 - precision: 0.3518 - recall: 0.6688 - auc: 0.9028 - val_loss: 0.0098 - val_tp: 74.0000 - val_fp: 19.0000 - val_tn: 45463.0000 - val_fn: 13.0000 - val_accuracy: 0.9993 - val_precision: 0.7957 - val_recall: 0.8506 - val_auc: 0.9600
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6417 - tp: 220.0000 - fp: 583.0000 - tn: 181379.0000 - fn: 94.0000 - accuracy: 0.9963 - precision: 0.2740 - recall: 0.7006 - auc: 0.9084 - val_loss: 0.0119 - val_tp: 74.0000 - val_fp: 25.0000 - val_tn: 45457.0000 - val_fn: 13.0000 - val_accuracy: 0.9992 - val_precision: 0.7475 - val_recall: 0.8506 - val_auc: 0.9777
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5846 - tp: 232.0000 - fp: 977.0000 - tn: 180985.0000 - fn: 82.0000 - accuracy: 0.9942 - precision: 0.1919 - recall: 0.7389 - auc: 0.9048 - val_loss: 0.0148 - val_tp: 74.0000 - val_fp: 34.0000 - val_tn: 45448.0000 - val_fn: 13.0000 - val_accuracy: 0.9990 - val_precision: 0.6852 - val_recall: 0.8506 - val_auc: 0.9802
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5404 - tp: 234.0000 - fp: 1464.0000 - tn: 180498.0000 - fn: 80.0000 - accuracy: 0.9915 - precision: 0.1378 - recall: 0.7452 - auc: 0.9190 - val_loss: 0.0183 - val_tp: 74.0000 - val_fp: 50.0000 - val_tn: 45432.0000 - val_fn: 13.0000 - val_accuracy: 0.9986 - val_precision: 0.5968 - val_recall: 0.8506 - val_auc: 0.9823
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4714 - tp: 241.0000 - fp: 1862.0000 - tn: 180100.0000 - fn: 73.0000 - accuracy: 0.9894 - precision: 0.1146 - recall: 0.7675 - auc: 0.9252 - val_loss: 0.0225 - val_tp: 76.0000 - val_fp: 84.0000 - val_tn: 45398.0000 - val_fn: 11.0000 - val_accuracy: 0.9979 - val_precision: 0.4750 - val_recall: 0.8736 - val_auc: 0.9851
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4329 - tp: 247.0000 - fp: 2508.0000 - tn: 179454.0000 - fn: 67.0000 - accuracy: 0.9859 - precision: 0.0897 - recall: 0.7866 - auc: 0.9345 - val_loss: 0.0282 - val_tp: 76.0000 - val_fp: 170.0000 - val_tn: 45312.0000 - val_fn: 11.0000 - val_accuracy: 0.9960 - val_precision: 0.3089 - val_recall: 0.8736 - val_auc: 0.9873
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4467 - tp: 249.0000 - fp: 3175.0000 - tn: 178787.0000 - fn: 65.0000 - accuracy: 0.9822 - precision: 0.0727 - recall: 0.7930 - auc: 0.9210 - val_loss: 0.0341 - val_tp: 78.0000 - val_fp: 282.0000 - val_tn: 45200.0000 - val_fn: 9.0000 - val_accuracy: 0.9936 - val_precision: 0.2167 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3947 - tp: 260.0000 - fp: 3569.0000 - tn: 178393.0000 - fn: 54.0000 - accuracy: 0.9801 - precision: 0.0679 - recall: 0.8280 - auc: 0.9290 - val_loss: 0.0394 - val_tp: 78.0000 - val_fp: 346.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9922 - val_precision: 0.1840 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3694 - tp: 257.0000 - fp: 4294.0000 - tn: 177668.0000 - fn: 57.0000 - accuracy: 0.9761 - precision: 0.0565 - recall: 0.8185 - auc: 0.9418 - val_loss: 0.0473 - val_tp: 78.0000 - val_fp: 504.0000 - val_tn: 44978.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1340 - val_recall: 0.8966 - val_auc: 0.9879
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3479 - tp: 262.0000 - fp: 4886.0000 - tn: 177076.0000 - fn: 52.0000 - accuracy: 0.9729 - precision: 0.0509 - recall: 0.8344 - auc: 0.9403 - val_loss: 0.0539 - val_tp: 78.0000 - val_fp: 586.0000 - val_tn: 44896.0000 - val_fn: 9.0000 - val_accuracy: 0.9869 - val_precision: 0.1175 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3653 - tp: 263.0000 - fp: 5360.0000 - tn: 176602.0000 - fn: 51.0000 - accuracy: 0.9703 - precision: 0.0468 - recall: 0.8376 - auc: 0.9370 - val_loss: 0.0610 - val_tp: 78.0000 - val_fp: 664.0000 - val_tn: 44818.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.1051 - val_recall: 0.8966 - val_auc: 0.9876
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3673 - tp: 262.0000 - fp: 5820.0000 - tn: 176142.0000 - fn: 52.0000 - accuracy: 0.9678 - precision: 0.0431 - recall: 0.8344 - auc: 0.9316 - val_loss: 0.0658 - val_tp: 78.0000 - val_fp: 715.0000 - val_tn: 44767.0000 - val_fn: 9.0000 - val_accuracy: 0.9841 - val_precision: 0.0984 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3228 - tp: 262.0000 - fp: 6230.0000 - tn: 175732.0000 - fn: 52.0000 - accuracy: 0.9655 - precision: 0.0404 - recall: 0.8344 - auc: 0.9445 - val_loss: 0.0716 - val_tp: 79.0000 - val_fp: 805.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0894 - val_recall: 0.9080 - val_auc: 0.9877
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3299 - tp: 268.0000 - fp: 6572.0000 - tn: 175390.0000 - fn: 46.0000 - accuracy: 0.9637 - precision: 0.0392 - recall: 0.8535 - auc: 0.9423 - val_loss: 0.0757 - val_tp: 81.0000 - val_fp: 846.0000 - val_tn: 44636.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0874 - val_recall: 0.9310 - val_auc: 0.9878
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2522 - tp: 276.0000 - fp: 6934.0000 - tn: 175028.0000 - fn: 38.0000 - accuracy: 0.9618 - precision: 0.0383 - recall: 0.8790 - auc: 0.9610 - val_loss: 0.0779 - val_tp: 81.0000 - val_fp: 874.0000 - val_tn: 44608.0000 - val_fn: 6.0000 - val_accuracy: 0.9807 - val_precision: 0.0848 - val_recall: 0.9310 - val_auc: 0.9877
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3607 - tp: 264.0000 - fp: 6790.0000 - tn: 175172.0000 - fn: 50.0000 - accuracy: 0.9625 - precision: 0.0374 - recall: 0.8408 - auc: 0.9303 - val_loss: 0.0781 - val_tp: 81.0000 - val_fp: 865.0000 - val_tn: 44617.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0856 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 20/100
89/90 [============================>.] - ETA: 0s - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175189.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175193.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488 - val_loss: 0.0780 - val_tp: 81.0000 - val_fp: 853.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9811 - val_precision: 0.0867 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 00020: early stopping

훈련 이력 확인

 plot_metrics(weighted_history)
 

png

측정 항목 평가

 train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
 
 weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
 
loss :  0.03226418048143387
tp :  82.0
fp :  352.0
tn :  56519.0
fn :  9.0
accuracy :  0.993662416934967
precision :  0.18894009292125702
recall :  0.901098906993866
auc :  0.9671803712844849

Legitimate Transactions Detected (True Negatives):  56519
Legitimate Transactions Incorrectly Detected (False Positives):  352
Fraudulent Transactions Missed (False Negatives):  9
Fraudulent Transactions Detected (True Positives):  82
Total Fraudulent Transactions:  91

png

여기에서 클래스 가중치를 사용하면 오 탐지가 더 많기 때문에 정확도와 정밀도가 낮아 지지만 반대로 모델이 더 많은 포지티브를 찾았 기 때문에 리콜 및 AUC가 더 높습니다. 정확도가 낮음에도 불구하고이 모델은 회수율이 높으며 사기 거래가 더 많습니다. 물론 두 가지 유형의 오류에는 비용이 발생합니다 (합법적 인 거래를 너무 많은 사기로 신고하여 사용자를 버그로 만들고 싶지는 않습니다). 응용 프로그램에 대한 이러한 다른 유형의 오류 간의 상충 관계를 신중하게 고려하십시오.

ROC 플롯

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa54c0729e8>

png

오버 샘플링

소수 클래스 오버 샘플링

관련 접근 방식은 소수 클래스를 오버 샘플링하여 데이터 세트를 다시 샘플링하는 것입니다.

 pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
 

NumPy 사용

긍정적 인 예에서 올바른 수의 랜덤 인덱스를 선택하여 수동으로 데이터 집합의 균형을 맞출 수 있습니다.

 ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
 
(181962, 29)
 resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
 
(363924, 29)

tf.data 사용

tf.data 사용하는 경우 균형 잡힌 예제를 생성하는 가장 쉬운 방법은 positivenegative 데이터 세트로 시작하여 병합하는 것입니다. 더 많은 예 는 tf.data 안내서 를 참조하십시오.

 BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
 

각 데이터 세트는 (feature, label) 쌍을 제공합니다.

 for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
 
Features:
 [ 0.23104754  0.83661044 -0.31875356  1.9796369   1.28403692  0.07389102
  1.03350673 -0.11568355 -1.54396817  0.88004244 -1.66944551 -0.24324391
  0.45900013  0.14583622 -2.06637388  0.42470592 -0.94489216 -0.83112221
 -1.83416278 -0.34138858  0.14130878  0.51019975  0.08224586  0.6642136
 -1.39031637 -0.42194185  0.22525572  0.28277796 -4.86369823]

Label:  1

experimental.sample_from_datasets 사용하여 둘을 병합하십시오.

 resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
 
 for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
 
0.49609375

이 데이터 세트를 사용하려면 에포크 당 단계 수가 필요합니다.

이 경우 "에포크"의 정의는 명확하지 않습니다. 각 부정적인 예를 한 번 보는 데 필요한 배치 수라고 가정하십시오.

 resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
 
278.0

오버 샘플링 된 데이터에 대한 교육

이제 클래스 가중치를 사용하는 대신 리샘플링 된 데이터 세트로 모델을 학습하여 이러한 방법을 비교하십시오.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
 
Epoch 1/100
278/278 [==============================] - 6s 23ms/step - loss: 0.4356 - tp: 223484.0000 - fp: 51288.0000 - tn: 290777.0000 - fn: 60757.0000 - accuracy: 0.8211 - precision: 0.8133 - recall: 0.7862 - auc: 0.8933 - val_loss: 0.2172 - val_tp: 79.0000 - val_fp: 1076.0000 - val_tn: 44406.0000 - val_fn: 8.0000 - val_accuracy: 0.9762 - val_precision: 0.0684 - val_recall: 0.9080 - val_auc: 0.9792
Epoch 2/100
278/278 [==============================] - 6s 20ms/step - loss: 0.2177 - tp: 246785.0000 - fp: 12557.0000 - tn: 271871.0000 - fn: 38131.0000 - accuracy: 0.9110 - precision: 0.9516 - recall: 0.8662 - auc: 0.9686 - val_loss: 0.1226 - val_tp: 80.0000 - val_fp: 951.0000 - val_tn: 44531.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0776 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1751 - tp: 250631.0000 - fp: 9797.0000 - tn: 275174.0000 - fn: 33742.0000 - accuracy: 0.9235 - precision: 0.9624 - recall: 0.8813 - auc: 0.9810 - val_loss: 0.0940 - val_tp: 82.0000 - val_fp: 966.0000 - val_tn: 44516.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0782 - val_recall: 0.9425 - val_auc: 0.9836
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1532 - tp: 254169.0000 - fp: 9171.0000 - tn: 275694.0000 - fn: 30310.0000 - accuracy: 0.9307 - precision: 0.9652 - recall: 0.8935 - auc: 0.9861 - val_loss: 0.0802 - val_tp: 82.0000 - val_fp: 918.0000 - val_tn: 44564.0000 - val_fn: 5.0000 - val_accuracy: 0.9797 - val_precision: 0.0820 - val_recall: 0.9425 - val_auc: 0.9847
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1372 - tp: 257034.0000 - fp: 9061.0000 - tn: 275758.0000 - fn: 27491.0000 - accuracy: 0.9358 - precision: 0.9659 - recall: 0.9034 - auc: 0.9892 - val_loss: 0.0720 - val_tp: 82.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0827 - val_recall: 0.9425 - val_auc: 0.9854
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1260 - tp: 258997.0000 - fp: 9079.0000 - tn: 275819.0000 - fn: 25449.0000 - accuracy: 0.9394 - precision: 0.9661 - recall: 0.9105 - auc: 0.9911 - val_loss: 0.0666 - val_tp: 81.0000 - val_fp: 915.0000 - val_tn: 44567.0000 - val_fn: 6.0000 - val_accuracy: 0.9798 - val_precision: 0.0813 - val_recall: 0.9310 - val_auc: 0.9856
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1167 - tp: 261100.0000 - fp: 9112.0000 - tn: 276180.0000 - fn: 22952.0000 - accuracy: 0.9437 - precision: 0.9663 - recall: 0.9192 - auc: 0.9925 - val_loss: 0.0623 - val_tp: 81.0000 - val_fp: 911.0000 - val_tn: 44571.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9858
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1082 - tp: 263945.0000 - fp: 9428.0000 - tn: 275276.0000 - fn: 20695.0000 - accuracy: 0.9471 - precision: 0.9655 - recall: 0.9273 - auc: 0.9937 - val_loss: 0.0587 - val_tp: 81.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9857
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1014 - tp: 268108.0000 - fp: 10376.0000 - tn: 274312.0000 - fn: 16548.0000 - accuracy: 0.9527 - precision: 0.9627 - recall: 0.9419 - auc: 0.9944 - val_loss: 0.0543 - val_tp: 80.0000 - val_fp: 873.0000 - val_tn: 44609.0000 - val_fn: 7.0000 - val_accuracy: 0.9807 - val_precision: 0.0839 - val_recall: 0.9195 - val_auc: 0.9857
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0951 - tp: 277520.0000 - fp: 12692.0000 - tn: 271795.0000 - fn: 7337.0000 - accuracy: 0.9648 - precision: 0.9563 - recall: 0.9742 - auc: 0.9950 - val_loss: 0.0495 - val_tp: 79.0000 - val_fp: 829.0000 - val_tn: 44653.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0870 - val_recall: 0.9080 - val_auc: 0.9855
Epoch 11/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0895 - tp: 278865.0000 - fp: 12938.0000 - tn: 271719.0000 - fn: 5822.0000 - accuracy: 0.9670 - precision: 0.9557 - recall: 0.9795 - auc: 0.9955 - val_loss: 0.0450 - val_tp: 79.0000 - val_fp: 789.0000 - val_tn: 44693.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0910 - val_recall: 0.9080 - val_auc: 0.9859
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0842 - tp: 279845.0000 - fp: 13187.0000 - tn: 272121.0000 - fn: 4191.0000 - accuracy: 0.9695 - precision: 0.9550 - recall: 0.9852 - auc: 0.9960 - val_loss: 0.0410 - val_tp: 79.0000 - val_fp: 733.0000 - val_tn: 44749.0000 - val_fn: 8.0000 - val_accuracy: 0.9837 - val_precision: 0.0973 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 13/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0792 - tp: 281765.0000 - fp: 12977.0000 - tn: 271393.0000 - fn: 3209.0000 - accuracy: 0.9716 - precision: 0.9560 - recall: 0.9887 - auc: 0.9963 - val_loss: 0.0389 - val_tp: 79.0000 - val_fp: 721.0000 - val_tn: 44761.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0988 - val_recall: 0.9080 - val_auc: 0.9814
Epoch 14/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0754 - tp: 281962.0000 - fp: 13026.0000 - tn: 272154.0000 - fn: 2202.0000 - accuracy: 0.9733 - precision: 0.9558 - recall: 0.9923 - auc: 0.9966 - val_loss: 0.0348 - val_tp: 79.0000 - val_fp: 646.0000 - val_tn: 44836.0000 - val_fn: 8.0000 - val_accuracy: 0.9856 - val_precision: 0.1090 - val_recall: 0.9080 - val_auc: 0.9763
Epoch 15/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0722 - tp: 283858.0000 - fp: 12932.0000 - tn: 271419.0000 - fn: 1135.0000 - accuracy: 0.9753 - precision: 0.9564 - recall: 0.9960 - auc: 0.9967 - val_loss: 0.0331 - val_tp: 79.0000 - val_fp: 640.0000 - val_tn: 44842.0000 - val_fn: 8.0000 - val_accuracy: 0.9858 - val_precision: 0.1099 - val_recall: 0.9080 - val_auc: 0.9714
Epoch 16/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0689 - tp: 283059.0000 - fp: 12757.0000 - tn: 273004.0000 - fn: 524.0000 - accuracy: 0.9767 - precision: 0.9569 - recall: 0.9982 - auc: 0.9970 - val_loss: 0.0308 - val_tp: 79.0000 - val_fp: 583.0000 - val_tn: 44899.0000 - val_fn: 8.0000 - val_accuracy: 0.9870 - val_precision: 0.1193 - val_recall: 0.9080 - val_auc: 0.9667
Epoch 17/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0661 - tp: 283879.0000 - fp: 12340.0000 - tn: 272779.0000 - fn: 346.0000 - accuracy: 0.9777 - precision: 0.9583 - recall: 0.9988 - auc: 0.9971 - val_loss: 0.0289 - val_tp: 79.0000 - val_fp: 542.0000 - val_tn: 44940.0000 - val_fn: 8.0000 - val_accuracy: 0.9879 - val_precision: 0.1272 - val_recall: 0.9080 - val_auc: 0.9618
Epoch 18/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0635 - tp: 284858.0000 - fp: 12157.0000 - tn: 272120.0000 - fn: 209.0000 - accuracy: 0.9783 - precision: 0.9591 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 79.0000 - val_fp: 511.0000 - val_tn: 44971.0000 - val_fn: 8.0000 - val_accuracy: 0.9886 - val_precision: 0.1339 - val_recall: 0.9080 - val_auc: 0.9621
Epoch 19/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0620 - tp: 284459.0000 - fp: 11978.0000 - tn: 272718.0000 - fn: 189.0000 - accuracy: 0.9786 - precision: 0.9596 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0261 - val_tp: 79.0000 - val_fp: 478.0000 - val_tn: 45004.0000 - val_fn: 8.0000 - val_accuracy: 0.9893 - val_precision: 0.1418 - val_recall: 0.9080 - val_auc: 0.9624
Epoch 20/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0600 - tp: 284950.0000 - fp: 11793.0000 - tn: 272572.0000 - fn: 29.0000 - accuracy: 0.9792 - precision: 0.9603 - recall: 0.9999 - auc: 0.9974 - val_loss: 0.0252 - val_tp: 79.0000 - val_fp: 463.0000 - val_tn: 45019.0000 - val_fn: 8.0000 - val_accuracy: 0.9897 - val_precision: 0.1458 - val_recall: 0.9080 - val_auc: 0.9626
Epoch 21/100
276/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 282210.0000 - fp: 11270.0000 - tn: 271768.0000 - fn: 0.0000e+00 - accuracy: 0.9801 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975Restoring model weights from the end of the best epoch.
278/278 [==============================] - 6s 22ms/step - loss: 0.0581 - tp: 284274.0000 - fp: 11360.0000 - tn: 273710.0000 - fn: 0.0000e+00 - accuracy: 0.9800 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975 - val_loss: 0.0241 - val_tp: 79.0000 - val_fp: 444.0000 - val_tn: 45038.0000 - val_fn: 8.0000 - val_accuracy: 0.9901 - val_precision: 0.1511 - val_recall: 0.9080 - val_auc: 0.9628
Epoch 00021: early stopping

트레이닝 프로세스가 각 그라디언트 업데이트에서 전체 데이터 세트를 고려하는 경우이 오버 샘플링은 기본적으로 클래스 가중치와 동일합니다.

그러나 모델을 배치 방식으로 학습 할 때 여기서와 같이 오버 샘플링 된 데이터는보다 부드러운 기울기 신호를 제공합니다. 각 긍정적 인 예제가 큰 가중치를 가진 하나의 배치로 표시되는 대신 매번 다른 배치로 표시됩니다. 작은 무게.

이 부드러운 그라디언트 신호는 모델 훈련을 더 쉽게 만듭니다.

훈련 이력 확인

교육 데이터는 유효성 검사 및 테스트 데이터와 완전히 다른 분포이므로 메트릭 분포는 여기에서 다릅니다.

 plot_metrics(resampled_history )
 

png

재 훈련

균형 잡힌 데이터에 대한 교육이 더 쉬우므로 위의 교육 절차가 빠르게 초과 될 수 있습니다.

따라서 에포크를 해체하여 callbacks.EarlyStopping 을 제공합니다.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_ds))
 
Epoch 1/1000
20/20 [==============================] - 1s 60ms/step - loss: 1.0656 - tp: 9507.0000 - fp: 7370.0000 - tn: 58667.0000 - fn: 10985.0000 - accuracy: 0.7879 - precision: 0.5633 - recall: 0.4639 - auc: 0.8255 - val_loss: 0.5792 - val_tp: 66.0000 - val_fp: 13452.0000 - val_tn: 32030.0000 - val_fn: 21.0000 - val_accuracy: 0.7043 - val_precision: 0.0049 - val_recall: 0.7586 - val_auc: 0.7866
Epoch 2/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.6996 - tp: 13383.0000 - fp: 7208.0000 - tn: 13397.0000 - fn: 6972.0000 - accuracy: 0.6538 - precision: 0.6499 - recall: 0.6575 - auc: 0.7027 - val_loss: 0.5702 - val_tp: 76.0000 - val_fp: 12408.0000 - val_tn: 33074.0000 - val_fn: 11.0000 - val_accuracy: 0.7275 - val_precision: 0.0061 - val_recall: 0.8736 - val_auc: 0.9076
Epoch 3/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.5532 - tp: 15127.0000 - fp: 6665.0000 - tn: 14055.0000 - fn: 5113.0000 - accuracy: 0.7125 - precision: 0.6942 - recall: 0.7474 - auc: 0.7952 - val_loss: 0.5335 - val_tp: 79.0000 - val_fp: 9006.0000 - val_tn: 36476.0000 - val_fn: 8.0000 - val_accuracy: 0.8022 - val_precision: 0.0087 - val_recall: 0.9080 - val_auc: 0.9408
Epoch 4/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4738 - tp: 16061.0000 - fp: 5669.0000 - tn: 14890.0000 - fn: 4340.0000 - accuracy: 0.7556 - precision: 0.7391 - recall: 0.7873 - auc: 0.8495 - val_loss: 0.4883 - val_tp: 78.0000 - val_fp: 5756.0000 - val_tn: 39726.0000 - val_fn: 9.0000 - val_accuracy: 0.8735 - val_precision: 0.0134 - val_recall: 0.8966 - val_auc: 0.9489
Epoch 5/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.4266 - tp: 16612.0000 - fp: 4719.0000 - tn: 15715.0000 - fn: 3914.0000 - accuracy: 0.7892 - precision: 0.7788 - recall: 0.8093 - auc: 0.8786 - val_loss: 0.4435 - val_tp: 78.0000 - val_fp: 3758.0000 - val_tn: 41724.0000 - val_fn: 9.0000 - val_accuracy: 0.9173 - val_precision: 0.0203 - val_recall: 0.8966 - val_auc: 0.9539
Epoch 6/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.3908 - tp: 16911.0000 - fp: 3861.0000 - tn: 16514.0000 - fn: 3674.0000 - accuracy: 0.8160 - precision: 0.8141 - recall: 0.8215 - auc: 0.8976 - val_loss: 0.4032 - val_tp: 79.0000 - val_fp: 2770.0000 - val_tn: 42712.0000 - val_fn: 8.0000 - val_accuracy: 0.9390 - val_precision: 0.0277 - val_recall: 0.9080 - val_auc: 0.9590
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3664 - tp: 17049.0000 - fp: 3209.0000 - tn: 17179.0000 - fn: 3523.0000 - accuracy: 0.8356 - precision: 0.8416 - recall: 0.8287 - auc: 0.9108 - val_loss: 0.3682 - val_tp: 79.0000 - val_fp: 2119.0000 - val_tn: 43363.0000 - val_fn: 8.0000 - val_accuracy: 0.9533 - val_precision: 0.0359 - val_recall: 0.9080 - val_auc: 0.9634
Epoch 8/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3467 - tp: 17100.0000 - fp: 2699.0000 - tn: 17686.0000 - fn: 3475.0000 - accuracy: 0.8493 - precision: 0.8637 - recall: 0.8311 - auc: 0.9193 - val_loss: 0.3373 - val_tp: 79.0000 - val_fp: 1753.0000 - val_tn: 43729.0000 - val_fn: 8.0000 - val_accuracy: 0.9614 - val_precision: 0.0431 - val_recall: 0.9080 - val_auc: 0.9675
Epoch 9/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.3285 - tp: 17043.0000 - fp: 2345.0000 - tn: 18228.0000 - fn: 3344.0000 - accuracy: 0.8611 - precision: 0.8790 - recall: 0.8360 - auc: 0.9271 - val_loss: 0.3104 - val_tp: 79.0000 - val_fp: 1495.0000 - val_tn: 43987.0000 - val_fn: 8.0000 - val_accuracy: 0.9670 - val_precision: 0.0502 - val_recall: 0.9080 - val_auc: 0.9702
Epoch 10/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3094 - tp: 17322.0000 - fp: 2012.0000 - tn: 18405.0000 - fn: 3221.0000 - accuracy: 0.8722 - precision: 0.8959 - recall: 0.8432 - auc: 0.9361 - val_loss: 0.2865 - val_tp: 79.0000 - val_fp: 1332.0000 - val_tn: 44150.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0560 - val_recall: 0.9080 - val_auc: 0.9721
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2962 - tp: 17184.0000 - fp: 1757.0000 - tn: 18853.0000 - fn: 3166.0000 - accuracy: 0.8798 - precision: 0.9072 - recall: 0.8444 - auc: 0.9406 - val_loss: 0.2654 - val_tp: 79.0000 - val_fp: 1228.0000 - val_tn: 44254.0000 - val_fn: 8.0000 - val_accuracy: 0.9729 - val_precision: 0.0604 - val_recall: 0.9080 - val_auc: 0.9739
Epoch 12/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2835 - tp: 17373.0000 - fp: 1543.0000 - tn: 18909.0000 - fn: 3135.0000 - accuracy: 0.8858 - precision: 0.9184 - recall: 0.8471 - auc: 0.9458 - val_loss: 0.2469 - val_tp: 79.0000 - val_fp: 1155.0000 - val_tn: 44327.0000 - val_fn: 8.0000 - val_accuracy: 0.9745 - val_precision: 0.0640 - val_recall: 0.9080 - val_auc: 0.9759
Epoch 13/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2710 - tp: 17386.0000 - fp: 1395.0000 - tn: 19124.0000 - fn: 3055.0000 - accuracy: 0.8914 - precision: 0.9257 - recall: 0.8505 - auc: 0.9502 - val_loss: 0.2302 - val_tp: 79.0000 - val_fp: 1092.0000 - val_tn: 44390.0000 - val_fn: 8.0000 - val_accuracy: 0.9759 - val_precision: 0.0675 - val_recall: 0.9080 - val_auc: 0.9782
Epoch 14/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2618 - tp: 17336.0000 - fp: 1343.0000 - tn: 19296.0000 - fn: 2985.0000 - accuracy: 0.8943 - precision: 0.9281 - recall: 0.8531 - auc: 0.9541 - val_loss: 0.2156 - val_tp: 79.0000 - val_fp: 1053.0000 - val_tn: 44429.0000 - val_fn: 8.0000 - val_accuracy: 0.9767 - val_precision: 0.0698 - val_recall: 0.9080 - val_auc: 0.9797
Epoch 15/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2529 - tp: 17466.0000 - fp: 1154.0000 - tn: 19366.0000 - fn: 2974.0000 - accuracy: 0.8992 - precision: 0.9380 - recall: 0.8545 - auc: 0.9574 - val_loss: 0.2026 - val_tp: 79.0000 - val_fp: 1029.0000 - val_tn: 44453.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0713 - val_recall: 0.9080 - val_auc: 0.9806
Epoch 16/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2456 - tp: 17579.0000 - fp: 1075.0000 - tn: 19322.0000 - fn: 2984.0000 - accuracy: 0.9009 - precision: 0.9424 - recall: 0.8549 - auc: 0.9590 - val_loss: 0.1923 - val_tp: 79.0000 - val_fp: 1017.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0721 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 17/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2382 - tp: 17573.0000 - fp: 982.0000 - tn: 19540.0000 - fn: 2865.0000 - accuracy: 0.9061 - precision: 0.9471 - recall: 0.8598 - auc: 0.9620 - val_loss: 0.1828 - val_tp: 79.0000 - val_fp: 1005.0000 - val_tn: 44477.0000 - val_fn: 8.0000 - val_accuracy: 0.9778 - val_precision: 0.0729 - val_recall: 0.9080 - val_auc: 0.9819
Epoch 18/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2307 - tp: 17711.0000 - fp: 966.0000 - tn: 19448.0000 - fn: 2835.0000 - accuracy: 0.9072 - precision: 0.9483 - recall: 0.8620 - auc: 0.9644 - val_loss: 0.1736 - val_tp: 80.0000 - val_fp: 990.0000 - val_tn: 44492.0000 - val_fn: 7.0000 - val_accuracy: 0.9781 - val_precision: 0.0748 - val_recall: 0.9195 - val_auc: 0.9825
Epoch 19/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2280 - tp: 17732.0000 - fp: 952.0000 - tn: 19442.0000 - fn: 2834.0000 - accuracy: 0.9076 - precision: 0.9490 - recall: 0.8622 - auc: 0.9653 - val_loss: 0.1660 - val_tp: 80.0000 - val_fp: 974.0000 - val_tn: 44508.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0759 - val_recall: 0.9195 - val_auc: 0.9826
Epoch 20/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2224 - tp: 17725.0000 - fp: 939.0000 - tn: 19538.0000 - fn: 2758.0000 - accuracy: 0.9097 - precision: 0.9497 - recall: 0.8654 - auc: 0.9667 - val_loss: 0.1591 - val_tp: 80.0000 - val_fp: 962.0000 - val_tn: 44520.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0768 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 21/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2168 - tp: 17757.0000 - fp: 826.0000 - tn: 19618.0000 - fn: 2759.0000 - accuracy: 0.9125 - precision: 0.9556 - recall: 0.8655 - auc: 0.9689 - val_loss: 0.1531 - val_tp: 80.0000 - val_fp: 967.0000 - val_tn: 44515.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0764 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 22/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2112 - tp: 17833.0000 - fp: 883.0000 - tn: 19522.0000 - fn: 2722.0000 - accuracy: 0.9120 - precision: 0.9528 - recall: 0.8676 - auc: 0.9703 - val_loss: 0.1479 - val_tp: 80.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9832
Epoch 23/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2058 - tp: 17865.0000 - fp: 835.0000 - tn: 19580.0000 - fn: 2680.0000 - accuracy: 0.9142 - precision: 0.9553 - recall: 0.8696 - auc: 0.9723 - val_loss: 0.1427 - val_tp: 80.0000 - val_fp: 977.0000 - val_tn: 44505.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0757 - val_recall: 0.9195 - val_auc: 0.9834
Epoch 24/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2053 - tp: 17856.0000 - fp: 802.0000 - tn: 19599.0000 - fn: 2703.0000 - accuracy: 0.9144 - precision: 0.9570 - recall: 0.8685 - auc: 0.9727 - val_loss: 0.1375 - val_tp: 80.0000 - val_fp: 969.0000 - val_tn: 44513.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9833
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2004 - tp: 17854.0000 - fp: 809.0000 - tn: 19690.0000 - fn: 2607.0000 - accuracy: 0.9166 - precision: 0.9567 - recall: 0.8726 - auc: 0.9740 - val_loss: 0.1331 - val_tp: 80.0000 - val_fp: 976.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9837
Epoch 26/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1991 - tp: 17857.0000 - fp: 793.0000 - tn: 19690.0000 - fn: 2620.0000 - accuracy: 0.9167 - precision: 0.9575 - recall: 0.8721 - auc: 0.9747 - val_loss: 0.1291 - val_tp: 80.0000 - val_fp: 968.0000 - val_tn: 44514.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 27/1000
20/20 [==============================] - 1s 40ms/step - loss: 0.1929 - tp: 17836.0000 - fp: 750.0000 - tn: 19833.0000 - fn: 2541.0000 - accuracy: 0.9197 - precision: 0.9596 - recall: 0.8753 - auc: 0.9760 - val_loss: 0.1252 - val_tp: 80.0000 - val_fp: 960.0000 - val_tn: 44522.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0769 - val_recall: 0.9195 - val_auc: 0.9839
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1935 - tp: 17776.0000 - fp: 753.0000 - tn: 19827.0000 - fn: 2604.0000 - accuracy: 0.9180 - precision: 0.9594 - recall: 0.8722 - auc: 0.9763 - val_loss: 0.1215 - val_tp: 80.0000 - val_fp: 946.0000 - val_tn: 44536.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0780 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 29/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1892 - tp: 17877.0000 - fp: 746.0000 - tn: 19791.0000 - fn: 2546.0000 - accuracy: 0.9196 - precision: 0.9599 - recall: 0.8753 - auc: 0.9773 - val_loss: 0.1183 - val_tp: 80.0000 - val_fp: 944.0000 - val_tn: 44538.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0781 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 30/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1855 - tp: 18053.0000 - fp: 746.0000 - tn: 19673.0000 - fn: 2488.0000 - accuracy: 0.9210 - precision: 0.9603 - recall: 0.8789 - auc: 0.9779 - val_loss: 0.1157 - val_tp: 80.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0777 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 31/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1843 - tp: 18042.0000 - fp: 723.0000 - tn: 19656.0000 - fn: 2539.0000 - accuracy: 0.9204 - precision: 0.9615 - recall: 0.8766 - auc: 0.9783 - val_loss: 0.1137 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 32/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1831 - tp: 17974.0000 - fp: 743.0000 - tn: 19741.0000 - fn: 2502.0000 - accuracy: 0.9208 - precision: 0.9603 - recall: 0.8778 - auc: 0.9789 - val_loss: 0.1112 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 33/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1805 - tp: 18172.0000 - fp: 775.0000 - tn: 19591.0000 - fn: 2422.0000 - accuracy: 0.9219 - precision: 0.9591 - recall: 0.8824 - auc: 0.9796 - val_loss: 0.1088 - val_tp: 81.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 6.0000 - val_accuracy: 0.9789 - val_precision: 0.0781 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 34/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1749 - tp: 18125.0000 - fp: 715.0000 - tn: 19698.0000 - fn: 2422.0000 - accuracy: 0.9234 - precision: 0.9620 - recall: 0.8821 - auc: 0.9812 - val_loss: 0.1068 - val_tp: 81.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0775 - val_recall: 0.9310 - val_auc: 0.9836
Epoch 35/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1769 - tp: 18135.0000 - fp: 715.0000 - tn: 19694.0000 - fn: 2416.0000 - accuracy: 0.9236 - precision: 0.9621 - recall: 0.8824 - auc: 0.9809 - val_loss: 0.1048 - val_tp: 81.0000 - val_fp: 978.0000 - val_tn: 44504.0000 - val_fn: 6.0000 - val_accuracy: 0.9784 - val_precision: 0.0765 - val_recall: 0.9310 - val_auc: 0.9838
Epoch 36/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1739 - tp: 18006.0000 - fp: 704.0000 - tn: 19827.0000 - fn: 2423.0000 - accuracy: 0.9237 - precision: 0.9624 - recall: 0.8814 - auc: 0.9814 - val_loss: 0.1029 - val_tp: 81.0000 - val_fp: 986.0000 - val_tn: 44496.0000 - val_fn: 6.0000 - val_accuracy: 0.9782 - val_precision: 0.0759 - val_recall: 0.9310 - val_auc: 0.9839
Epoch 37/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1687 - tp: 18002.0000 - fp: 660.0000 - tn: 19879.0000 - fn: 2419.0000 - accuracy: 0.9248 - precision: 0.9646 - recall: 0.8815 - auc: 0.9826 - val_loss: 0.1011 - val_tp: 81.0000 - val_fp: 984.0000 - val_tn: 44498.0000 - val_fn: 6.0000 - val_accuracy: 0.9783 - val_precision: 0.0761 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 38/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1699 - tp: 17932.0000 - fp: 677.0000 - tn: 19986.0000 - fn: 2365.0000 - accuracy: 0.9257 - precision: 0.9636 - recall: 0.8835 - auc: 0.9825 - val_loss: 0.0995 - val_tp: 82.0000 - val_fp: 979.0000 - val_tn: 44503.0000 - val_fn: 5.0000 - val_accuracy: 0.9784 - val_precision: 0.0773 - val_recall: 0.9425 - val_auc: 0.9842
Epoch 39/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1676 - tp: 18086.0000 - fp: 736.0000 - tn: 19780.0000 - fn: 2358.0000 - accuracy: 0.9245 - precision: 0.9609 - recall: 0.8847 - auc: 0.9826 - val_loss: 0.0980 - val_tp: 82.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 5.0000 - val_accuracy: 0.9785 - val_precision: 0.0776 - val_recall: 0.9425 - val_auc: 0.9844
Epoch 40/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1670 - tp: 18066.0000 - fp: 685.0000 - tn: 19868.0000 - fn: 2341.0000 - accuracy: 0.9261 - precision: 0.9635 - recall: 0.8853 - auc: 0.9832 - val_loss: 0.0964 - val_tp: 82.0000 - val_fp: 965.0000 - val_tn: 44517.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0783 - val_recall: 0.9425 - val_auc: 0.9845
Epoch 41/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1640 - tp: 17950.0000 - fp: 645.0000 - tn: 19995.0000 - fn: 2370.0000 - accuracy: 0.9264 - precision: 0.9653 - recall: 0.8834 - auc: 0.9839 - val_loss: 0.0950 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9835
Epoch 42/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1641 - tp: 18083.0000 - fp: 665.0000 - tn: 19842.0000 - fn: 2370.0000 - accuracy: 0.9259 - precision: 0.9645 - recall: 0.8841 - auc: 0.9839 - val_loss: 0.0938 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 43/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1600 - tp: 18012.0000 - fp: 684.0000 - tn: 19970.0000 - fn: 2294.0000 - accuracy: 0.9273 - precision: 0.9634 - recall: 0.8870 - auc: 0.9845 - val_loss: 0.0925 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 44/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1597 - tp: 18346.0000 - fp: 657.0000 - tn: 19657.0000 - fn: 2300.0000 - accuracy: 0.9278 - precision: 0.9654 - recall: 0.8886 - auc: 0.9847 - val_loss: 0.0919 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9838
Epoch 45/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1607 - tp: 18109.0000 - fp: 726.0000 - tn: 19836.0000 - fn: 2289.0000 - accuracy: 0.9264 - precision: 0.9615 - recall: 0.8878 - auc: 0.9846 - val_loss: 0.0908 - val_tp: 82.0000 - val_fp: 948.0000 - val_tn: 44534.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0796 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 46/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1581 - tp: 18192.0000 - fp: 650.0000 - tn: 19833.0000 - fn: 2285.0000 - accuracy: 0.9283 - precision: 0.9655 - recall: 0.8884 - auc: 0.9849 - val_loss: 0.0902 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 47/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1579 - tp: 18301.0000 - fp: 676.0000 - tn: 19760.0000 - fn: 2223.0000 - accuracy: 0.9292 - precision: 0.9644 - recall: 0.8917 - auc: 0.9853 - val_loss: 0.0892 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 48/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1503 - tp: 18172.0000 - fp: 593.0000 - tn: 19959.0000 - fn: 2236.0000 - accuracy: 0.9309 - precision: 0.9684 - recall: 0.8904 - auc: 0.9867 - val_loss: 0.0887 - val_tp: 82.0000 - val_fp: 970.0000 - val_tn: 44512.0000 - val_fn: 5.0000 - val_accuracy: 0.9786 - val_precision: 0.0779 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 49/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1572 - tp: 18217.0000 - fp: 750.0000 - tn: 19709.0000 - fn: 2284.0000 - accuracy: 0.9259 - precision: 0.9605 - recall: 0.8886 - auc: 0.9852 - val_loss: 0.0876 - val_tp: 82.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0784 - val_recall: 0.9425 - val_auc: 0.9841
Epoch 50/1000
20/20 [==============================] - ETA: 0s - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860Restoring model weights from the end of the best epoch.
20/20 [==============================] - 0s 23ms/step - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860 - val_loss: 0.0860 - val_tp: 82.0000 - val_fp: 941.0000 - val_tn: 44541.0000 - val_fn: 5.0000 - val_accuracy: 0.9792 - val_precision: 0.0802 - val_recall: 0.9425 - val_auc: 0.9843
Epoch 00050: early stopping

훈련 이력 확인

 plot_metrics(resampled_history)
 

png

측정 항목 평가

 train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
 
 resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
 
loss :  0.09607589244842529
tp :  84.0
fp :  1195.0
tn :  55676.0
fn :  7.0
accuracy :  0.9788982272148132
precision :  0.06567630916833878
recall :  0.9230769276618958
auc :  0.9697299599647522

Legitimate Transactions Detected (True Negatives):  55676
Legitimate Transactions Incorrectly Detected (False Positives):  1195
Fraudulent Transactions Missed (False Negatives):  7
Fraudulent Transactions Detected (True Positives):  84
Total Fraudulent Transactions:  91

png

ROC 플롯

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa4bc66c9b0>

png

이 학습서를 문제점에 적용

불균형 데이터 분류는 배울 샘플이 거의 없기 때문에 본질적으로 어려운 작업입니다. 항상 데이터부터 먼저 시작하고 가능한 한 많은 샘플을 수집하고 모델이 소수 클래스를 최대한 활용할 수 있도록 어떤 기능이 관련 될 수 있는지에 대해 실질적으로 생각해야합니다. 어떤 시점에서 모델이 원하는 결과를 개선하고 산출하는 데 어려움을 겪을 수 있으므로 문제의 상황과 다른 유형의 오류 간의 상충 관계를 염두에 두는 것이 중요합니다.