
Classification on imbalanced data


This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class vastly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Try common techniques for dealing with imbalanced data, such as:
    • Class weighting
    • Oversampling

Setup

 import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
 
 mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
 

Data processing and exploration

Download the Kaggle Credit Card Fraud data set

Pandas is a Python library with many helpful utilities for loading and working with structured data, and can be used to download CSVs into a DataFrame.

 raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
 
 raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
 

Examine the class label imbalance

Let's look at the dataset imbalance:

 neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
 
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)


This shows the small fraction of positive samples.

Clean, split and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.

 cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
 

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics, however the model is never fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.

 # Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
 

Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and standard deviation to 1.

 scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)

 
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input and these are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes, the positive examples contain a much higher rate of extreme values.
 pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("Negative distribution")
 

[figure: "Positive distribution" hex jointplot of V5 vs V6]

[figure: "Negative distribution" hex jointplot of V5 vs V6]

Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:

 METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
 

Understand useful metrics

Notice that there are a few metrics defined above that can be computed by the model and that will be helpful when evaluating the performance. A short sketch computing accuracy, precision, and recall by hand follows the list below.

  • False negatives and false positives are samples that were incorrectly classified
  • True negatives and true positives are samples that were correctly classified
  • Accuracy is the percentage of examples correctly classified: $\frac{\text{true samples}}{\text{total samples}}$
  • Precision is the percentage of predicted positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false positives}}$
  • Recall is the percentage of actual positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false negatives}}$
  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
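As a quick illustration (an addition, not part of the original notebook), the metric definitions above can be computed by hand from the four confusion-matrix counts. The counts used here are hypothetical placeholders:

 # Hypothetical confusion-matrix counts, for illustration only.
tp, fp, tn, fn = 69, 5, 56866, 22

accuracy = (tp + tn) / (tp + fp + tn + fn)  # fraction of all examples classified correctly
precision = tp / (tp + fp)                  # fraction of predicted positives that are real
recall = tp / (tp + fn)                     # fraction of real positives that are found

print('accuracy: {:.4f}  precision: {:.4f}  recall: {:.4f}'.format(
    accuracy, precision, recall))
 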


Baseline model

Build the model

Now build and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, they would likely have no fraudulent transactions to learn from. A rough estimate of this follows below.
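As a rough sanity check (an addition, not from the original text), you can estimate how many positive samples an average batch of 2048 will contain, using the pos and total counts computed earlier:

 # Expected number of fraudulent examples in a random batch of 2048.
expected_pos_per_batch = 2048 * pos / total
print('Expected positives per batch: {:.1f}'.format(expected_pos_per_batch))  # ~3.5
 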

 EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
 
 model = make_model()
model.summary()
 
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________

Test run the model:

 model.predict(train_features[:10])
 
array([[0.5788107 ],
       [0.44979692],
       [0.5427961 ],
       [0.5985188 ],
       [0.7758075 ],
       [0.3417888 ],
       [0.39359283],
       [0.5399953 ],
       [0.3551327 ],
       [0.47230086]], dtype=float32)

Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced, so set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

With the default bias initialization, the loss should be about math.log(2) = 0.69314.

 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.7817

The correct bias to set can be derived from:

$$ p_0 = \frac{pos}{pos + neg} = \frac{1}{1 + e^{-b_0}} $$
$$ b_0 = -\log_e\left(\frac{1}{p_0} - 1\right) $$
$$ b_0 = \log_e\left(\frac{pos}{neg}\right) $$
 initial_bias = np.log([pos/neg])
initial_bias
 
array([-6.35935934])

Setting that as the initial bias gives the model much more reasonable initial guesses.

It should be near: pos/total = 0.0018

 model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
 
array([[0.00093563],
       [0.00187903],
       [0.00109238],
       [0.00117128],
       [0.00134988],
       [0.00090826],
       [0.00099455],
       [0.00154405],
       [0.00100204],
       [0.0004291 ]], dtype=float32)

With this initialization the initial loss should be approximately:

$$ -p_0 \log(p_0) - (1 - p_0) \log(1 - p_0) = 0.01317 $$
 results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
 
Loss: 0.0146

This initial loss is about 50 times less than it would have been with naive initialization (0.7817 / 0.0146 ≈ 54).

This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes plots of the loss during training easier to read.

Checkpoint the initial weights

To make the various training runs more comparable, keep the initial model's weights in a checkpoint file, and load them into each model before training:

 initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)
 

Confirm that the bias fix helps

Before moving on, confirm quickly that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses:

 model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
 
 def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  
  plt.legend()
 
 plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
 

[figure: training and validation loss, zero bias vs. careful bias]

The above figure makes it clear: in terms of validation loss, this careful initialization gives a clear advantage on this problem.

Train the model

 model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
 
Epoch 1/100
90/90 [==============================] - 1s 13ms/step - loss: 0.0112 - tp: 100.0000 - fp: 25.0000 - tn: 227419.0000 - fn: 301.0000 - accuracy: 0.9986 - precision: 0.8000 - recall: 0.2494 - auc: 0.7615 - val_loss: 0.0067 - val_tp: 15.0000 - val_fp: 2.0000 - val_tn: 45480.0000 - val_fn: 72.0000 - val_accuracy: 0.9984 - val_precision: 0.8824 - val_recall: 0.1724 - val_auc: 0.9077
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0075 - tp: 108.0000 - fp: 24.0000 - tn: 181938.0000 - fn: 206.0000 - accuracy: 0.9987 - precision: 0.8182 - recall: 0.3439 - auc: 0.8491 - val_loss: 0.0046 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 42.0000 - val_accuracy: 0.9989 - val_precision: 0.8824 - val_recall: 0.5172 - val_auc: 0.9308
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0065 - tp: 138.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 176.0000 - accuracy: 0.9989 - precision: 0.8364 - recall: 0.4395 - auc: 0.8567 - val_loss: 0.0040 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 33.0000 - val_accuracy: 0.9991 - val_precision: 0.8852 - val_recall: 0.6207 - val_auc: 0.9365
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0060 - tp: 154.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 160.0000 - accuracy: 0.9989 - precision: 0.8235 - recall: 0.4904 - auc: 0.8848 - val_loss: 0.0037 - val_tp: 61.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8841 - val_recall: 0.7011 - val_auc: 0.9422
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 157.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 157.0000 - accuracy: 0.9989 - precision: 0.8135 - recall: 0.5000 - auc: 0.8982 - val_loss: 0.0035 - val_tp: 62.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8857 - val_recall: 0.7126 - val_auc: 0.9422
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0057 - tp: 152.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8261 - recall: 0.4841 - auc: 0.8934 - val_loss: 0.0033 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0052 - tp: 174.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 140.0000 - accuracy: 0.9991 - precision: 0.8529 - recall: 0.5541 - auc: 0.8983 - val_loss: 0.0032 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 161.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8342 - recall: 0.5127 - auc: 0.8983 - val_loss: 0.0031 - val_tp: 66.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8919 - val_recall: 0.7586 - val_auc: 0.9479
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 167.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 147.0000 - accuracy: 0.9990 - precision: 0.8186 - recall: 0.5318 - auc: 0.9064 - val_loss: 0.0030 - val_tp: 65.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8904 - val_recall: 0.7471 - val_auc: 0.9479
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 156.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 158.0000 - accuracy: 0.9989 - precision: 0.8211 - recall: 0.4968 - auc: 0.9046 - val_loss: 0.0029 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9479
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 165.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8376 - recall: 0.5255 - auc: 0.9063 - val_loss: 0.0029 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 165.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8250 - recall: 0.5255 - auc: 0.9110 - val_loss: 0.0028 - val_tp: 67.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8933 - val_recall: 0.7701 - val_auc: 0.9480
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 157.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8441 - recall: 0.5000 - auc: 0.9031 - val_loss: 0.0028 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 160.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 154.0000 - accuracy: 0.9990 - precision: 0.8205 - recall: 0.5096 - auc: 0.8934 - val_loss: 0.0027 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9479
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 168.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8235 - recall: 0.5350 - auc: 0.9031 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9479
Epoch 16/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0046 - tp: 169.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 145.0000 - accuracy: 0.9990 - precision: 0.8492 - recall: 0.5382 - auc: 0.9143 - val_loss: 0.0027 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8498 - recall: 0.5764 - auc: 0.9144 - val_loss: 0.0027 - val_tp: 70.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8974 - val_recall: 0.8046 - val_auc: 0.9537
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 181.0000 - fp: 29.0000 - tn: 181933.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8619 - recall: 0.5764 - auc: 0.9112 - val_loss: 0.0026 - val_tp: 69.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8961 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 68.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8947 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 177.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 137.0000 - accuracy: 0.9991 - precision: 0.8349 - recall: 0.5637 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 21/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0045 - tp: 176.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8462 - recall: 0.5605 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9167 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 163.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8316 - recall: 0.5191 - auc: 0.9096 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 183.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8281 - recall: 0.5828 - auc: 0.9113 - val_loss: 0.0026 - val_tp: 66.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.9041 - val_recall: 0.7586 - val_auc: 0.9537
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8400 - recall: 0.5350 - auc: 0.9128 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 179.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 135.0000 - accuracy: 0.9991 - precision: 0.8483 - recall: 0.5701 - auc: 0.9161 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 173.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 141.0000 - accuracy: 0.9990 - precision: 0.8199 - recall: 0.5510 - auc: 0.9208 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8431 - recall: 0.5478 - auc: 0.9081 - val_loss: 0.0026 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 181.0000 - fp: 39.0000 - tn: 181923.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8227 - recall: 0.5764 - auc: 0.9193 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9537
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 177.0000 - fp: 38.0000 - tn: 181924.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8233 - recall: 0.5637 - auc: 0.9305 - val_loss: 0.0025 - val_tp: 67.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.9054 - val_recall: 0.7701 - val_auc: 0.9538
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 168.0000 - fp: 31.0000 - tn: 181931.0000 - fn: 146.0000 - accuracy: 0.9990 - precision: 0.8442 - recall: 0.5350 - auc: 0.9161 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 172.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8309 - recall: 0.5478 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5987 - auc: 0.9162 - val_loss: 0.0025 - val_tp: 70.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9091 - val_recall: 0.8046 - val_auc: 0.9538
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 184.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8720 - recall: 0.5860 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8486 - recall: 0.5892 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 178.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8318 - recall: 0.5669 - auc: 0.9160 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 171.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 143.0000 - accuracy: 0.9990 - precision: 0.8382 - recall: 0.5446 - auc: 0.9192 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 189.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 125.0000 - accuracy: 0.9991 - precision: 0.8438 - recall: 0.6019 - auc: 0.9242 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 185.0000 - fp: 25.0000 - tn: 181937.0000 - fn: 129.0000 - accuracy: 0.9992 - precision: 0.8810 - recall: 0.5892 - auc: 0.9176 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 181.0000 - fp: 35.0000 - tn: 181927.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8380 - recall: 0.5764 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9189 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8537 - recall: 0.5573 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 180.0000 - fp: 32.0000 - tn: 181930.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8491 - recall: 0.5732 - auc: 0.9320 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9537
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 188.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8468 - recall: 0.5987 - auc: 0.9209 - val_loss: 0.0025 - val_tp: 71.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8987 - val_recall: 0.8161 - val_auc: 0.9538
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9225 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 172.0000 - fp: 37.0000 - tn: 181925.0000 - fn: 142.0000 - accuracy: 0.9990 - precision: 0.8230 - recall: 0.5478 - auc: 0.9129 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 175.0000 - fp: 36.0000 - tn: 181926.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8294 - recall: 0.5573 - auc: 0.9368 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 176.0000 - fp: 33.0000 - tn: 181929.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8421 - recall: 0.5605 - auc: 0.9240 - val_loss: 0.0025 - val_tp: 69.0000 - val_fp: 7.0000 - val_tn: 45475.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9079 - val_recall: 0.7931 - val_auc: 0.9538
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 178.0000 - fp: 27.0000 - tn: 181935.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8683 - recall: 0.5669 - auc: 0.9273 - val_loss: 0.0025 - val_tp: 72.0000 - val_fp: 8.0000 - val_tn: 45474.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.8276 - val_auc: 0.9537
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 198.0000 - fp: 34.0000 - tn: 181928.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8534 - recall: 0.6306 - auc: 0.9256 - val_loss: 0.0025 - val_tp: 68.0000 - val_fp: 5.0000 - val_tn: 45477.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9315 - val_recall: 0.7816 - val_auc: 0.9538
Epoch 49/100
85/90 [===========================>..] - ETA: 0s - loss: 0.0043 - tp: 162.0000 - fp: 29.0000 - tn: 173750.0000 - fn: 139.0000 - accuracy: 0.9990 - precision: 0.8482 - recall: 0.5382 - auc: 0.9157Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 171.0000 - fp: 30.0000 - tn: 181932.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8507 - recall: 0.5446 - auc: 0.9191 - val_loss: 0.0024 - val_tp: 69.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9200 - val_recall: 0.7931 - val_auc: 0.9537
Epoch 00049: early stopping

Check training history

In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in this tutorial.

Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.

 def plot_metrics(history):
  metrics =  ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()

 
 plot_metrics(baseline_history)
 

[figure: loss, AUC, precision, and recall over training epochs]

Evaluate metrics

You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label:

 train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
 
 def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
 

Evaluate your model on the test dataset and display the results for the metrics you created above:

 baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
 
loss :  0.002310588490217924
tp :  69.0
fp :  5.0
tn :  56866.0
fn :  22.0
accuracy :  0.9995260238647461
precision :  0.9324324131011963
recall :  0.7582417726516724
auc :  0.9557874202728271

Legitimate Transactions Detected (True Negatives):  56866
Legitimate Transactions Incorrectly Detected (False Positives):  5
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  69
Total Fraudulent Transactions:  91

[figure: baseline confusion matrix at threshold 0.50]

If the model had predicted everything perfectly, this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer asking them to verify their card activity. A quick way to explore this trade-off is sketched below.
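As an illustrative aside (not part of the original flow), you can re-plot the confusion matrix at a lower decision threshold using the plot_cm helper defined above, trading false negatives for false positives:

 # Lowering the threshold from 0.5 to 0.1 catches more fraud (fewer false
# negatives) at the cost of flagging more legitimate transactions.
plot_cm(test_labels, test_predictions_baseline, p=0.1)
 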

Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.

 def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
 
 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa50c5adef0>

[figure: ROC curves for the train and test baseline]

It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness. The precision-recall curve sketched below shows this trade-off across thresholds.
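To see this tension directly, here is a minimal sketch (an addition, not from the original tutorial) that plots a precision-recall curve for the baseline test predictions using sklearn.metrics.precision_recall_curve:

 # Precision-recall curve: each point corresponds to a different output threshold.
from sklearn.metrics import precision_recall_curve

precision_pts, recall_pts, _ = precision_recall_curve(
    test_labels, test_predictions_baseline)
plt.plot(recall_pts, precision_pts, linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True)
 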

Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from the under-represented class.

 # Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
 
Weight for class 0: 0.50
Weight for class 1: 289.44
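For reference, a minimal sketch (an aside, not from the original tutorial): scikit-learn's compute_class_weight with the 'balanced' option implements the same total / (2 * count) heuristic. Because it is computed on the training split here rather than the full dataset, the numbers will be close to, but not exactly, the ones above:

 from sklearn.utils.class_weight import compute_class_weight

# 'balanced' computes n_samples / (n_classes * count_per_class).
balanced = compute_class_weight('balanced', classes=np.array([0, 1]), y=train_labels)
print('Weight for class 0: {:.2f}'.format(balanced[0]))
print('Weight for class 1: {:.2f}'.format(balanced[1]))
 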

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

 weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
 
Epoch 1/100
90/90 [==============================] - 1s 15ms/step - loss: 2.5149 - tp: 105.0000 - fp: 66.0000 - tn: 238767.0000 - fn: 300.0000 - accuracy: 0.9985 - precision: 0.6140 - recall: 0.2593 - auc: 0.7803 - val_loss: 0.0067 - val_tp: 25.0000 - val_fp: 6.0000 - val_tn: 45476.0000 - val_fn: 62.0000 - val_accuracy: 0.9985 - val_precision: 0.8065 - val_recall: 0.2874 - val_auc: 0.9211
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 1.2482 - tp: 145.0000 - fp: 124.0000 - tn: 181838.0000 - fn: 169.0000 - accuracy: 0.9984 - precision: 0.5390 - recall: 0.4618 - auc: 0.8560 - val_loss: 0.0062 - val_tp: 68.0000 - val_fp: 12.0000 - val_tn: 45470.0000 - val_fn: 19.0000 - val_accuracy: 0.9993 - val_precision: 0.8500 - val_recall: 0.7816 - val_auc: 0.9408
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.8972 - tp: 177.0000 - fp: 237.0000 - tn: 181725.0000 - fn: 137.0000 - accuracy: 0.9979 - precision: 0.4275 - recall: 0.5637 - auc: 0.8876 - val_loss: 0.0079 - val_tp: 73.0000 - val_fp: 16.0000 - val_tn: 45466.0000 - val_fn: 14.0000 - val_accuracy: 0.9993 - val_precision: 0.8202 - val_recall: 0.8391 - val_auc: 0.9518
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6983 - tp: 210.0000 - fp: 387.0000 - tn: 181575.0000 - fn: 104.0000 - accuracy: 0.9973 - precision: 0.3518 - recall: 0.6688 - auc: 0.9028 - val_loss: 0.0098 - val_tp: 74.0000 - val_fp: 19.0000 - val_tn: 45463.0000 - val_fn: 13.0000 - val_accuracy: 0.9993 - val_precision: 0.7957 - val_recall: 0.8506 - val_auc: 0.9600
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6417 - tp: 220.0000 - fp: 583.0000 - tn: 181379.0000 - fn: 94.0000 - accuracy: 0.9963 - precision: 0.2740 - recall: 0.7006 - auc: 0.9084 - val_loss: 0.0119 - val_tp: 74.0000 - val_fp: 25.0000 - val_tn: 45457.0000 - val_fn: 13.0000 - val_accuracy: 0.9992 - val_precision: 0.7475 - val_recall: 0.8506 - val_auc: 0.9777
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5846 - tp: 232.0000 - fp: 977.0000 - tn: 180985.0000 - fn: 82.0000 - accuracy: 0.9942 - precision: 0.1919 - recall: 0.7389 - auc: 0.9048 - val_loss: 0.0148 - val_tp: 74.0000 - val_fp: 34.0000 - val_tn: 45448.0000 - val_fn: 13.0000 - val_accuracy: 0.9990 - val_precision: 0.6852 - val_recall: 0.8506 - val_auc: 0.9802
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.5404 - tp: 234.0000 - fp: 1464.0000 - tn: 180498.0000 - fn: 80.0000 - accuracy: 0.9915 - precision: 0.1378 - recall: 0.7452 - auc: 0.9190 - val_loss: 0.0183 - val_tp: 74.0000 - val_fp: 50.0000 - val_tn: 45432.0000 - val_fn: 13.0000 - val_accuracy: 0.9986 - val_precision: 0.5968 - val_recall: 0.8506 - val_auc: 0.9823
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4714 - tp: 241.0000 - fp: 1862.0000 - tn: 180100.0000 - fn: 73.0000 - accuracy: 0.9894 - precision: 0.1146 - recall: 0.7675 - auc: 0.9252 - val_loss: 0.0225 - val_tp: 76.0000 - val_fp: 84.0000 - val_tn: 45398.0000 - val_fn: 11.0000 - val_accuracy: 0.9979 - val_precision: 0.4750 - val_recall: 0.8736 - val_auc: 0.9851
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4329 - tp: 247.0000 - fp: 2508.0000 - tn: 179454.0000 - fn: 67.0000 - accuracy: 0.9859 - precision: 0.0897 - recall: 0.7866 - auc: 0.9345 - val_loss: 0.0282 - val_tp: 76.0000 - val_fp: 170.0000 - val_tn: 45312.0000 - val_fn: 11.0000 - val_accuracy: 0.9960 - val_precision: 0.3089 - val_recall: 0.8736 - val_auc: 0.9873
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4467 - tp: 249.0000 - fp: 3175.0000 - tn: 178787.0000 - fn: 65.0000 - accuracy: 0.9822 - precision: 0.0727 - recall: 0.7930 - auc: 0.9210 - val_loss: 0.0341 - val_tp: 78.0000 - val_fp: 282.0000 - val_tn: 45200.0000 - val_fn: 9.0000 - val_accuracy: 0.9936 - val_precision: 0.2167 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3947 - tp: 260.0000 - fp: 3569.0000 - tn: 178393.0000 - fn: 54.0000 - accuracy: 0.9801 - precision: 0.0679 - recall: 0.8280 - auc: 0.9290 - val_loss: 0.0394 - val_tp: 78.0000 - val_fp: 346.0000 - val_tn: 45136.0000 - val_fn: 9.0000 - val_accuracy: 0.9922 - val_precision: 0.1840 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3694 - tp: 257.0000 - fp: 4294.0000 - tn: 177668.0000 - fn: 57.0000 - accuracy: 0.9761 - precision: 0.0565 - recall: 0.8185 - auc: 0.9418 - val_loss: 0.0473 - val_tp: 78.0000 - val_fp: 504.0000 - val_tn: 44978.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 0.1340 - val_recall: 0.8966 - val_auc: 0.9879
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3479 - tp: 262.0000 - fp: 4886.0000 - tn: 177076.0000 - fn: 52.0000 - accuracy: 0.9729 - precision: 0.0509 - recall: 0.8344 - auc: 0.9403 - val_loss: 0.0539 - val_tp: 78.0000 - val_fp: 586.0000 - val_tn: 44896.0000 - val_fn: 9.0000 - val_accuracy: 0.9869 - val_precision: 0.1175 - val_recall: 0.8966 - val_auc: 0.9881
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3653 - tp: 263.0000 - fp: 5360.0000 - tn: 176602.0000 - fn: 51.0000 - accuracy: 0.9703 - precision: 0.0468 - recall: 0.8376 - auc: 0.9370 - val_loss: 0.0610 - val_tp: 78.0000 - val_fp: 664.0000 - val_tn: 44818.0000 - val_fn: 9.0000 - val_accuracy: 0.9852 - val_precision: 0.1051 - val_recall: 0.8966 - val_auc: 0.9876
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3673 - tp: 262.0000 - fp: 5820.0000 - tn: 176142.0000 - fn: 52.0000 - accuracy: 0.9678 - precision: 0.0431 - recall: 0.8344 - auc: 0.9316 - val_loss: 0.0658 - val_tp: 78.0000 - val_fp: 715.0000 - val_tn: 44767.0000 - val_fn: 9.0000 - val_accuracy: 0.9841 - val_precision: 0.0984 - val_recall: 0.8966 - val_auc: 0.9877
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3228 - tp: 262.0000 - fp: 6230.0000 - tn: 175732.0000 - fn: 52.0000 - accuracy: 0.9655 - precision: 0.0404 - recall: 0.8344 - auc: 0.9445 - val_loss: 0.0716 - val_tp: 79.0000 - val_fp: 805.0000 - val_tn: 44677.0000 - val_fn: 8.0000 - val_accuracy: 0.9822 - val_precision: 0.0894 - val_recall: 0.9080 - val_auc: 0.9877
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3299 - tp: 268.0000 - fp: 6572.0000 - tn: 175390.0000 - fn: 46.0000 - accuracy: 0.9637 - precision: 0.0392 - recall: 0.8535 - auc: 0.9423 - val_loss: 0.0757 - val_tp: 81.0000 - val_fp: 846.0000 - val_tn: 44636.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0874 - val_recall: 0.9310 - val_auc: 0.9878
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2522 - tp: 276.0000 - fp: 6934.0000 - tn: 175028.0000 - fn: 38.0000 - accuracy: 0.9618 - precision: 0.0383 - recall: 0.8790 - auc: 0.9610 - val_loss: 0.0779 - val_tp: 81.0000 - val_fp: 874.0000 - val_tn: 44608.0000 - val_fn: 6.0000 - val_accuracy: 0.9807 - val_precision: 0.0848 - val_recall: 0.9310 - val_auc: 0.9877
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3607 - tp: 264.0000 - fp: 6790.0000 - tn: 175172.0000 - fn: 50.0000 - accuracy: 0.9625 - precision: 0.0374 - recall: 0.8408 - auc: 0.9303 - val_loss: 0.0781 - val_tp: 81.0000 - val_fp: 865.0000 - val_tn: 44617.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0856 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 20/100
89/90 [============================>.] - ETA: 0s - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175189.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488Restoring model weights from the end of the best epoch.
90/90 [==============================] - 1s 6ms/step - loss: 0.2977 - tp: 269.0000 - fp: 6769.0000 - tn: 175193.0000 - fn: 45.0000 - accuracy: 0.9626 - precision: 0.0382 - recall: 0.8567 - auc: 0.9488 - val_loss: 0.0780 - val_tp: 81.0000 - val_fp: 853.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9811 - val_precision: 0.0867 - val_recall: 0.9310 - val_auc: 0.9879
Epoch 00020: early stopping

Check training history

 plot_metrics(weighted_history)
 

[figure: training history with class weights]

Evaluate metrics

 train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
 
 weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
 
loss :  0.03226418048143387
tp :  82.0
fp :  352.0
tn :  56519.0
fn :  9.0
accuracy :  0.993662416934967
precision :  0.18894009292125702
recall :  0.901098906993866
auc :  0.9671803712844849

Legitimate Transactions Detected (True Negatives):  56519
Legitimate Transactions Incorrectly Detected (False Positives):  352
Fraudulent Transactions Missed (False Negatives):  9
Fraudulent Transactions Detected (True Positives):  82
Total Fraudulent Transactions:  91

[figure: confusion matrix with class weights]

Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower precision, the model's higher recall means it identifies more fraudulent transactions. Of course, there is a cost to both types of error (you would not want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.

Plot the ROC

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa54c0729e8>

[figure: ROC curves for the baseline and weighted models]

Oversampling

Oversample the minority class

A related approach would be to resample the dataset by oversampling the minority class.

 pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
 

Using NumPy

You can balance the dataset manually by choosing the right number of random indices from the positive examples:

 ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
 
(181962, 29)
 resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
 
(363924, 29)

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.

 BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
 

Each dataset provides (feature, label) pairs:

 for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
 
Features:
 [ 0.23104754  0.83661044 -0.31875356  1.9796369   1.28403692  0.07389102
  1.03350673 -0.11568355 -1.54396817  0.88004244 -1.66944551 -0.24324391
  0.45900013  0.14583622 -2.06637388  0.42470592 -0.94489216 -0.83112221
 -1.83416278 -0.34138858  0.14130878  0.51019975  0.08224586  0.6642136
 -1.39031637 -0.42194185  0.22525572  0.28277796 -4.86369823]

Label:  1

Merge the two together using experimental.sample_from_datasets:

 resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
 
 for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
 
0.49609375

To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:

 resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
 
278.0

Train on the oversampled data

Now try training the model with the resampled data set instead of using class weights, to see how these methods compare.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
 
Epoch 1/100
278/278 [==============================] - 6s 23ms/step - loss: 0.4356 - tp: 223484.0000 - fp: 51288.0000 - tn: 290777.0000 - fn: 60757.0000 - accuracy: 0.8211 - precision: 0.8133 - recall: 0.7862 - auc: 0.8933 - val_loss: 0.2172 - val_tp: 79.0000 - val_fp: 1076.0000 - val_tn: 44406.0000 - val_fn: 8.0000 - val_accuracy: 0.9762 - val_precision: 0.0684 - val_recall: 0.9080 - val_auc: 0.9792
Epoch 2/100
278/278 [==============================] - 6s 20ms/step - loss: 0.2177 - tp: 246785.0000 - fp: 12557.0000 - tn: 271871.0000 - fn: 38131.0000 - accuracy: 0.9110 - precision: 0.9516 - recall: 0.8662 - auc: 0.9686 - val_loss: 0.1226 - val_tp: 80.0000 - val_fp: 951.0000 - val_tn: 44531.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0776 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1751 - tp: 250631.0000 - fp: 9797.0000 - tn: 275174.0000 - fn: 33742.0000 - accuracy: 0.9235 - precision: 0.9624 - recall: 0.8813 - auc: 0.9810 - val_loss: 0.0940 - val_tp: 82.0000 - val_fp: 966.0000 - val_tn: 44516.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0782 - val_recall: 0.9425 - val_auc: 0.9836
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1532 - tp: 254169.0000 - fp: 9171.0000 - tn: 275694.0000 - fn: 30310.0000 - accuracy: 0.9307 - precision: 0.9652 - recall: 0.8935 - auc: 0.9861 - val_loss: 0.0802 - val_tp: 82.0000 - val_fp: 918.0000 - val_tn: 44564.0000 - val_fn: 5.0000 - val_accuracy: 0.9797 - val_precision: 0.0820 - val_recall: 0.9425 - val_auc: 0.9847
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1372 - tp: 257034.0000 - fp: 9061.0000 - tn: 275758.0000 - fn: 27491.0000 - accuracy: 0.9358 - precision: 0.9659 - recall: 0.9034 - auc: 0.9892 - val_loss: 0.0720 - val_tp: 82.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0827 - val_recall: 0.9425 - val_auc: 0.9854
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1260 - tp: 258997.0000 - fp: 9079.0000 - tn: 275819.0000 - fn: 25449.0000 - accuracy: 0.9394 - precision: 0.9661 - recall: 0.9105 - auc: 0.9911 - val_loss: 0.0666 - val_tp: 81.0000 - val_fp: 915.0000 - val_tn: 44567.0000 - val_fn: 6.0000 - val_accuracy: 0.9798 - val_precision: 0.0813 - val_recall: 0.9310 - val_auc: 0.9856
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1167 - tp: 261100.0000 - fp: 9112.0000 - tn: 276180.0000 - fn: 22952.0000 - accuracy: 0.9437 - precision: 0.9663 - recall: 0.9192 - auc: 0.9925 - val_loss: 0.0623 - val_tp: 81.0000 - val_fp: 911.0000 - val_tn: 44571.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9858
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1082 - tp: 263945.0000 - fp: 9428.0000 - tn: 275276.0000 - fn: 20695.0000 - accuracy: 0.9471 - precision: 0.9655 - recall: 0.9273 - auc: 0.9937 - val_loss: 0.0587 - val_tp: 81.0000 - val_fp: 910.0000 - val_tn: 44572.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0817 - val_recall: 0.9310 - val_auc: 0.9857
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1014 - tp: 268108.0000 - fp: 10376.0000 - tn: 274312.0000 - fn: 16548.0000 - accuracy: 0.9527 - precision: 0.9627 - recall: 0.9419 - auc: 0.9944 - val_loss: 0.0543 - val_tp: 80.0000 - val_fp: 873.0000 - val_tn: 44609.0000 - val_fn: 7.0000 - val_accuracy: 0.9807 - val_precision: 0.0839 - val_recall: 0.9195 - val_auc: 0.9857
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0951 - tp: 277520.0000 - fp: 12692.0000 - tn: 271795.0000 - fn: 7337.0000 - accuracy: 0.9648 - precision: 0.9563 - recall: 0.9742 - auc: 0.9950 - val_loss: 0.0495 - val_tp: 79.0000 - val_fp: 829.0000 - val_tn: 44653.0000 - val_fn: 8.0000 - val_accuracy: 0.9816 - val_precision: 0.0870 - val_recall: 0.9080 - val_auc: 0.9855
Epoch 11/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0895 - tp: 278865.0000 - fp: 12938.0000 - tn: 271719.0000 - fn: 5822.0000 - accuracy: 0.9670 - precision: 0.9557 - recall: 0.9795 - auc: 0.9955 - val_loss: 0.0450 - val_tp: 79.0000 - val_fp: 789.0000 - val_tn: 44693.0000 - val_fn: 8.0000 - val_accuracy: 0.9825 - val_precision: 0.0910 - val_recall: 0.9080 - val_auc: 0.9859
Epoch 12/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0842 - tp: 279845.0000 - fp: 13187.0000 - tn: 272121.0000 - fn: 4191.0000 - accuracy: 0.9695 - precision: 0.9550 - recall: 0.9852 - auc: 0.9960 - val_loss: 0.0410 - val_tp: 79.0000 - val_fp: 733.0000 - val_tn: 44749.0000 - val_fn: 8.0000 - val_accuracy: 0.9837 - val_precision: 0.0973 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 13/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0792 - tp: 281765.0000 - fp: 12977.0000 - tn: 271393.0000 - fn: 3209.0000 - accuracy: 0.9716 - precision: 0.9560 - recall: 0.9887 - auc: 0.9963 - val_loss: 0.0389 - val_tp: 79.0000 - val_fp: 721.0000 - val_tn: 44761.0000 - val_fn: 8.0000 - val_accuracy: 0.9840 - val_precision: 0.0988 - val_recall: 0.9080 - val_auc: 0.9814
Epoch 14/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0754 - tp: 281962.0000 - fp: 13026.0000 - tn: 272154.0000 - fn: 2202.0000 - accuracy: 0.9733 - precision: 0.9558 - recall: 0.9923 - auc: 0.9966 - val_loss: 0.0348 - val_tp: 79.0000 - val_fp: 646.0000 - val_tn: 44836.0000 - val_fn: 8.0000 - val_accuracy: 0.9856 - val_precision: 0.1090 - val_recall: 0.9080 - val_auc: 0.9763
Epoch 15/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0722 - tp: 283858.0000 - fp: 12932.0000 - tn: 271419.0000 - fn: 1135.0000 - accuracy: 0.9753 - precision: 0.9564 - recall: 0.9960 - auc: 0.9967 - val_loss: 0.0331 - val_tp: 79.0000 - val_fp: 640.0000 - val_tn: 44842.0000 - val_fn: 8.0000 - val_accuracy: 0.9858 - val_precision: 0.1099 - val_recall: 0.9080 - val_auc: 0.9714
Epoch 16/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0689 - tp: 283059.0000 - fp: 12757.0000 - tn: 273004.0000 - fn: 524.0000 - accuracy: 0.9767 - precision: 0.9569 - recall: 0.9982 - auc: 0.9970 - val_loss: 0.0308 - val_tp: 79.0000 - val_fp: 583.0000 - val_tn: 44899.0000 - val_fn: 8.0000 - val_accuracy: 0.9870 - val_precision: 0.1193 - val_recall: 0.9080 - val_auc: 0.9667
Epoch 17/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0661 - tp: 283879.0000 - fp: 12340.0000 - tn: 272779.0000 - fn: 346.0000 - accuracy: 0.9777 - precision: 0.9583 - recall: 0.9988 - auc: 0.9971 - val_loss: 0.0289 - val_tp: 79.0000 - val_fp: 542.0000 - val_tn: 44940.0000 - val_fn: 8.0000 - val_accuracy: 0.9879 - val_precision: 0.1272 - val_recall: 0.9080 - val_auc: 0.9618
Epoch 18/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0635 - tp: 284858.0000 - fp: 12157.0000 - tn: 272120.0000 - fn: 209.0000 - accuracy: 0.9783 - precision: 0.9591 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0277 - val_tp: 79.0000 - val_fp: 511.0000 - val_tn: 44971.0000 - val_fn: 8.0000 - val_accuracy: 0.9886 - val_precision: 0.1339 - val_recall: 0.9080 - val_auc: 0.9621
Epoch 19/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0620 - tp: 284459.0000 - fp: 11978.0000 - tn: 272718.0000 - fn: 189.0000 - accuracy: 0.9786 - precision: 0.9596 - recall: 0.9993 - auc: 0.9973 - val_loss: 0.0261 - val_tp: 79.0000 - val_fp: 478.0000 - val_tn: 45004.0000 - val_fn: 8.0000 - val_accuracy: 0.9893 - val_precision: 0.1418 - val_recall: 0.9080 - val_auc: 0.9624
Epoch 20/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0600 - tp: 284950.0000 - fp: 11793.0000 - tn: 272572.0000 - fn: 29.0000 - accuracy: 0.9792 - precision: 0.9603 - recall: 0.9999 - auc: 0.9974 - val_loss: 0.0252 - val_tp: 79.0000 - val_fp: 463.0000 - val_tn: 45019.0000 - val_fn: 8.0000 - val_accuracy: 0.9897 - val_precision: 0.1458 - val_recall: 0.9080 - val_auc: 0.9626
Epoch 21/100
276/278 [============================>.] - ETA: 0s - loss: 0.0581 - tp: 282210.0000 - fp: 11270.0000 - tn: 271768.0000 - fn: 0.0000e+00 - accuracy: 0.9801 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975Restoring model weights from the end of the best epoch.
278/278 [==============================] - 6s 22ms/step - loss: 0.0581 - tp: 284274.0000 - fp: 11360.0000 - tn: 273710.0000 - fn: 0.0000e+00 - accuracy: 0.9800 - precision: 0.9616 - recall: 1.0000 - auc: 0.9975 - val_loss: 0.0241 - val_tp: 79.0000 - val_fp: 444.0000 - val_tn: 45038.0000 - val_fn: 8.0000 - val_accuracy: 0.9901 - val_precision: 0.1511 - val_recall: 0.9080 - val_auc: 0.9628
Epoch 00021: early stopping

If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.

But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.

This smoother gradient signal makes it easier to train the model. A back-of-the-envelope comparison of the two schemes follows below.
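As a back-of-the-envelope check (an illustration added here, not from the original tutorial), the relative weight a positive example receives under class weighting roughly matches how many times it is expected to be drawn per "epoch" under 50/50 oversampling:

 # Relative emphasis on each positive example under the two schemes.
print('Class weighting, pos/neg weight ratio: {:.0f}'.format(
    weight_for_1 / weight_for_0))             # ~578
print('Oversampling, expected draws per positive: {:.0f}'.format(
    len(neg_features) / len(pos_features)))   # ~579
 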

Check training history

Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.

 plot_metrics(resampled_history )
 

[figure: training history on the oversampled data]

Re-train

Because training is easier on the balanced data, the above training procedure may overfit quickly.

So break up the epochs to give the callbacks.EarlyStopping finer control over when to stop training.

 resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_ds))
 
Epoch 1/1000
20/20 [==============================] - 1s 60ms/step - loss: 1.0656 - tp: 9507.0000 - fp: 7370.0000 - tn: 58667.0000 - fn: 10985.0000 - accuracy: 0.7879 - precision: 0.5633 - recall: 0.4639 - auc: 0.8255 - val_loss: 0.5792 - val_tp: 66.0000 - val_fp: 13452.0000 - val_tn: 32030.0000 - val_fn: 21.0000 - val_accuracy: 0.7043 - val_precision: 0.0049 - val_recall: 0.7586 - val_auc: 0.7866
Epoch 2/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.6996 - tp: 13383.0000 - fp: 7208.0000 - tn: 13397.0000 - fn: 6972.0000 - accuracy: 0.6538 - precision: 0.6499 - recall: 0.6575 - auc: 0.7027 - val_loss: 0.5702 - val_tp: 76.0000 - val_fp: 12408.0000 - val_tn: 33074.0000 - val_fn: 11.0000 - val_accuracy: 0.7275 - val_precision: 0.0061 - val_recall: 0.8736 - val_auc: 0.9076
Epoch 3/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.5532 - tp: 15127.0000 - fp: 6665.0000 - tn: 14055.0000 - fn: 5113.0000 - accuracy: 0.7125 - precision: 0.6942 - recall: 0.7474 - auc: 0.7952 - val_loss: 0.5335 - val_tp: 79.0000 - val_fp: 9006.0000 - val_tn: 36476.0000 - val_fn: 8.0000 - val_accuracy: 0.8022 - val_precision: 0.0087 - val_recall: 0.9080 - val_auc: 0.9408
Epoch 4/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4738 - tp: 16061.0000 - fp: 5669.0000 - tn: 14890.0000 - fn: 4340.0000 - accuracy: 0.7556 - precision: 0.7391 - recall: 0.7873 - auc: 0.8495 - val_loss: 0.4883 - val_tp: 78.0000 - val_fp: 5756.0000 - val_tn: 39726.0000 - val_fn: 9.0000 - val_accuracy: 0.8735 - val_precision: 0.0134 - val_recall: 0.8966 - val_auc: 0.9489
Epoch 5/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.4266 - tp: 16612.0000 - fp: 4719.0000 - tn: 15715.0000 - fn: 3914.0000 - accuracy: 0.7892 - precision: 0.7788 - recall: 0.8093 - auc: 0.8786 - val_loss: 0.4435 - val_tp: 78.0000 - val_fp: 3758.0000 - val_tn: 41724.0000 - val_fn: 9.0000 - val_accuracy: 0.9173 - val_precision: 0.0203 - val_recall: 0.8966 - val_auc: 0.9539
Epoch 6/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.3908 - tp: 16911.0000 - fp: 3861.0000 - tn: 16514.0000 - fn: 3674.0000 - accuracy: 0.8160 - precision: 0.8141 - recall: 0.8215 - auc: 0.8976 - val_loss: 0.4032 - val_tp: 79.0000 - val_fp: 2770.0000 - val_tn: 42712.0000 - val_fn: 8.0000 - val_accuracy: 0.9390 - val_precision: 0.0277 - val_recall: 0.9080 - val_auc: 0.9590
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3664 - tp: 17049.0000 - fp: 3209.0000 - tn: 17179.0000 - fn: 3523.0000 - accuracy: 0.8356 - precision: 0.8416 - recall: 0.8287 - auc: 0.9108 - val_loss: 0.3682 - val_tp: 79.0000 - val_fp: 2119.0000 - val_tn: 43363.0000 - val_fn: 8.0000 - val_accuracy: 0.9533 - val_precision: 0.0359 - val_recall: 0.9080 - val_auc: 0.9634
Epoch 8/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3467 - tp: 17100.0000 - fp: 2699.0000 - tn: 17686.0000 - fn: 3475.0000 - accuracy: 0.8493 - precision: 0.8637 - recall: 0.8311 - auc: 0.9193 - val_loss: 0.3373 - val_tp: 79.0000 - val_fp: 1753.0000 - val_tn: 43729.0000 - val_fn: 8.0000 - val_accuracy: 0.9614 - val_precision: 0.0431 - val_recall: 0.9080 - val_auc: 0.9675
Epoch 9/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.3285 - tp: 17043.0000 - fp: 2345.0000 - tn: 18228.0000 - fn: 3344.0000 - accuracy: 0.8611 - precision: 0.8790 - recall: 0.8360 - auc: 0.9271 - val_loss: 0.3104 - val_tp: 79.0000 - val_fp: 1495.0000 - val_tn: 43987.0000 - val_fn: 8.0000 - val_accuracy: 0.9670 - val_precision: 0.0502 - val_recall: 0.9080 - val_auc: 0.9702
Epoch 10/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.3094 - tp: 17322.0000 - fp: 2012.0000 - tn: 18405.0000 - fn: 3221.0000 - accuracy: 0.8722 - precision: 0.8959 - recall: 0.8432 - auc: 0.9361 - val_loss: 0.2865 - val_tp: 79.0000 - val_fp: 1332.0000 - val_tn: 44150.0000 - val_fn: 8.0000 - val_accuracy: 0.9706 - val_precision: 0.0560 - val_recall: 0.9080 - val_auc: 0.9721
Epoch 11/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2962 - tp: 17184.0000 - fp: 1757.0000 - tn: 18853.0000 - fn: 3166.0000 - accuracy: 0.8798 - precision: 0.9072 - recall: 0.8444 - auc: 0.9406 - val_loss: 0.2654 - val_tp: 79.0000 - val_fp: 1228.0000 - val_tn: 44254.0000 - val_fn: 8.0000 - val_accuracy: 0.9729 - val_precision: 0.0604 - val_recall: 0.9080 - val_auc: 0.9739
Epoch 12/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2835 - tp: 17373.0000 - fp: 1543.0000 - tn: 18909.0000 - fn: 3135.0000 - accuracy: 0.8858 - precision: 0.9184 - recall: 0.8471 - auc: 0.9458 - val_loss: 0.2469 - val_tp: 79.0000 - val_fp: 1155.0000 - val_tn: 44327.0000 - val_fn: 8.0000 - val_accuracy: 0.9745 - val_precision: 0.0640 - val_recall: 0.9080 - val_auc: 0.9759
Epoch 13/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2710 - tp: 17386.0000 - fp: 1395.0000 - tn: 19124.0000 - fn: 3055.0000 - accuracy: 0.8914 - precision: 0.9257 - recall: 0.8505 - auc: 0.9502 - val_loss: 0.2302 - val_tp: 79.0000 - val_fp: 1092.0000 - val_tn: 44390.0000 - val_fn: 8.0000 - val_accuracy: 0.9759 - val_precision: 0.0675 - val_recall: 0.9080 - val_auc: 0.9782
Epoch 14/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2618 - tp: 17336.0000 - fp: 1343.0000 - tn: 19296.0000 - fn: 2985.0000 - accuracy: 0.8943 - precision: 0.9281 - recall: 0.8531 - auc: 0.9541 - val_loss: 0.2156 - val_tp: 79.0000 - val_fp: 1053.0000 - val_tn: 44429.0000 - val_fn: 8.0000 - val_accuracy: 0.9767 - val_precision: 0.0698 - val_recall: 0.9080 - val_auc: 0.9797
Epoch 15/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2529 - tp: 17466.0000 - fp: 1154.0000 - tn: 19366.0000 - fn: 2974.0000 - accuracy: 0.8992 - precision: 0.9380 - recall: 0.8545 - auc: 0.9574 - val_loss: 0.2026 - val_tp: 79.0000 - val_fp: 1029.0000 - val_tn: 44453.0000 - val_fn: 8.0000 - val_accuracy: 0.9772 - val_precision: 0.0713 - val_recall: 0.9080 - val_auc: 0.9806
Epoch 16/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2456 - tp: 17579.0000 - fp: 1075.0000 - tn: 19322.0000 - fn: 2984.0000 - accuracy: 0.9009 - precision: 0.9424 - recall: 0.8549 - auc: 0.9590 - val_loss: 0.1923 - val_tp: 79.0000 - val_fp: 1017.0000 - val_tn: 44465.0000 - val_fn: 8.0000 - val_accuracy: 0.9775 - val_precision: 0.0721 - val_recall: 0.9080 - val_auc: 0.9813
Epoch 17/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2382 - tp: 17573.0000 - fp: 982.0000 - tn: 19540.0000 - fn: 2865.0000 - accuracy: 0.9061 - precision: 0.9471 - recall: 0.8598 - auc: 0.9620 - val_loss: 0.1828 - val_tp: 79.0000 - val_fp: 1005.0000 - val_tn: 44477.0000 - val_fn: 8.0000 - val_accuracy: 0.9778 - val_precision: 0.0729 - val_recall: 0.9080 - val_auc: 0.9819
Epoch 18/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2307 - tp: 17711.0000 - fp: 966.0000 - tn: 19448.0000 - fn: 2835.0000 - accuracy: 0.9072 - precision: 0.9483 - recall: 0.8620 - auc: 0.9644 - val_loss: 0.1736 - val_tp: 80.0000 - val_fp: 990.0000 - val_tn: 44492.0000 - val_fn: 7.0000 - val_accuracy: 0.9781 - val_precision: 0.0748 - val_recall: 0.9195 - val_auc: 0.9825
Epoch 19/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2280 - tp: 17732.0000 - fp: 952.0000 - tn: 19442.0000 - fn: 2834.0000 - accuracy: 0.9076 - precision: 0.9490 - recall: 0.8622 - auc: 0.9653 - val_loss: 0.1660 - val_tp: 80.0000 - val_fp: 974.0000 - val_tn: 44508.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0759 - val_recall: 0.9195 - val_auc: 0.9826
Epoch 20/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2224 - tp: 17725.0000 - fp: 939.0000 - tn: 19538.0000 - fn: 2758.0000 - accuracy: 0.9097 - precision: 0.9497 - recall: 0.8654 - auc: 0.9667 - val_loss: 0.1591 - val_tp: 80.0000 - val_fp: 962.0000 - val_tn: 44520.0000 - val_fn: 7.0000 - val_accuracy: 0.9787 - val_precision: 0.0768 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 21/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2168 - tp: 17757.0000 - fp: 826.0000 - tn: 19618.0000 - fn: 2759.0000 - accuracy: 0.9125 - precision: 0.9556 - recall: 0.8655 - auc: 0.9689 - val_loss: 0.1531 - val_tp: 80.0000 - val_fp: 967.0000 - val_tn: 44515.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0764 - val_recall: 0.9195 - val_auc: 0.9831
Epoch 22/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.2112 - tp: 17833.0000 - fp: 883.0000 - tn: 19522.0000 - fn: 2722.0000 - accuracy: 0.9120 - precision: 0.9528 - recall: 0.8676 - auc: 0.9703 - val_loss: 0.1479 - val_tp: 80.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 7.0000 - val_accuracy: 0.9785 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9832
Epoch 23/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2058 - tp: 17865.0000 - fp: 835.0000 - tn: 19580.0000 - fn: 2680.0000 - accuracy: 0.9142 - precision: 0.9553 - recall: 0.8696 - auc: 0.9723 - val_loss: 0.1427 - val_tp: 80.0000 - val_fp: 977.0000 - val_tn: 44505.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0757 - val_recall: 0.9195 - val_auc: 0.9834
Epoch 24/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2053 - tp: 17856.0000 - fp: 802.0000 - tn: 19599.0000 - fn: 2703.0000 - accuracy: 0.9144 - precision: 0.9570 - recall: 0.8685 - auc: 0.9727 - val_loss: 0.1375 - val_tp: 80.0000 - val_fp: 969.0000 - val_tn: 44513.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9833
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2004 - tp: 17854.0000 - fp: 809.0000 - tn: 19690.0000 - fn: 2607.0000 - accuracy: 0.9166 - precision: 0.9567 - recall: 0.8726 - auc: 0.9740 - val_loss: 0.1331 - val_tp: 80.0000 - val_fp: 976.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9784 - val_precision: 0.0758 - val_recall: 0.9195 - val_auc: 0.9837
Epoch 26/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1991 - tp: 17857.0000 - fp: 793.0000 - tn: 19690.0000 - fn: 2620.0000 - accuracy: 0.9167 - precision: 0.9575 - recall: 0.8721 - auc: 0.9747 - val_loss: 0.1291 - val_tp: 80.0000 - val_fp: 968.0000 - val_tn: 44514.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0763 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 27/1000
20/20 [==============================] - 1s 40ms/step - loss: 0.1929 - tp: 17836.0000 - fp: 750.0000 - tn: 19833.0000 - fn: 2541.0000 - accuracy: 0.9197 - precision: 0.9596 - recall: 0.8753 - auc: 0.9760 - val_loss: 0.1252 - val_tp: 80.0000 - val_fp: 960.0000 - val_tn: 44522.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0769 - val_recall: 0.9195 - val_auc: 0.9839
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.1935 - tp: 17776.0000 - fp: 753.0000 - tn: 19827.0000 - fn: 2604.0000 - accuracy: 0.9180 - precision: 0.9594 - recall: 0.8722 - auc: 0.9763 - val_loss: 0.1215 - val_tp: 80.0000 - val_fp: 946.0000 - val_tn: 44536.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0780 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 29/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1892 - tp: 17877.0000 - fp: 746.0000 - tn: 19791.0000 - fn: 2546.0000 - accuracy: 0.9196 - precision: 0.9599 - recall: 0.8753 - auc: 0.9773 - val_loss: 0.1183 - val_tp: 80.0000 - val_fp: 944.0000 - val_tn: 44538.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0781 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 30/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1855 - tp: 18053.0000 - fp: 746.0000 - tn: 19673.0000 - fn: 2488.0000 - accuracy: 0.9210 - precision: 0.9603 - recall: 0.8789 - auc: 0.9779 - val_loss: 0.1157 - val_tp: 80.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 7.0000 - val_accuracy: 0.9790 - val_precision: 0.0777 - val_recall: 0.9195 - val_auc: 0.9835
Epoch 31/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1843 - tp: 18042.0000 - fp: 723.0000 - tn: 19656.0000 - fn: 2539.0000 - accuracy: 0.9204 - precision: 0.9615 - recall: 0.8766 - auc: 0.9783 - val_loss: 0.1137 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9836
Epoch 32/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1831 - tp: 17974.0000 - fp: 743.0000 - tn: 19741.0000 - fn: 2502.0000 - accuracy: 0.9208 - precision: 0.9603 - recall: 0.8778 - auc: 0.9789 - val_loss: 0.1112 - val_tp: 80.0000 - val_fp: 958.0000 - val_tn: 44524.0000 - val_fn: 7.0000 - val_accuracy: 0.9788 - val_precision: 0.0771 - val_recall: 0.9195 - val_auc: 0.9840
Epoch 33/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1805 - tp: 18172.0000 - fp: 775.0000 - tn: 19591.0000 - fn: 2422.0000 - accuracy: 0.9219 - precision: 0.9591 - recall: 0.8824 - auc: 0.9796 - val_loss: 0.1088 - val_tp: 81.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 6.0000 - val_accuracy: 0.9789 - val_precision: 0.0781 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 34/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1749 - tp: 18125.0000 - fp: 715.0000 - tn: 19698.0000 - fn: 2422.0000 - accuracy: 0.9234 - precision: 0.9620 - recall: 0.8821 - auc: 0.9812 - val_loss: 0.1068 - val_tp: 81.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0775 - val_recall: 0.9310 - val_auc: 0.9836
Epoch 35/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1769 - tp: 18135.0000 - fp: 715.0000 - tn: 19694.0000 - fn: 2416.0000 - accuracy: 0.9236 - precision: 0.9621 - recall: 0.8824 - auc: 0.9809 - val_loss: 0.1048 - val_tp: 81.0000 - val_fp: 978.0000 - val_tn: 44504.0000 - val_fn: 6.0000 - val_accuracy: 0.9784 - val_precision: 0.0765 - val_recall: 0.9310 - val_auc: 0.9838
Epoch 36/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1739 - tp: 18006.0000 - fp: 704.0000 - tn: 19827.0000 - fn: 2423.0000 - accuracy: 0.9237 - precision: 0.9624 - recall: 0.8814 - auc: 0.9814 - val_loss: 0.1029 - val_tp: 81.0000 - val_fp: 986.0000 - val_tn: 44496.0000 - val_fn: 6.0000 - val_accuracy: 0.9782 - val_precision: 0.0759 - val_recall: 0.9310 - val_auc: 0.9839
Epoch 37/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1687 - tp: 18002.0000 - fp: 660.0000 - tn: 19879.0000 - fn: 2419.0000 - accuracy: 0.9248 - precision: 0.9646 - recall: 0.8815 - auc: 0.9826 - val_loss: 0.1011 - val_tp: 81.0000 - val_fp: 984.0000 - val_tn: 44498.0000 - val_fn: 6.0000 - val_accuracy: 0.9783 - val_precision: 0.0761 - val_recall: 0.9310 - val_auc: 0.9841
Epoch 38/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1699 - tp: 17932.0000 - fp: 677.0000 - tn: 19986.0000 - fn: 2365.0000 - accuracy: 0.9257 - precision: 0.9636 - recall: 0.8835 - auc: 0.9825 - val_loss: 0.0995 - val_tp: 82.0000 - val_fp: 979.0000 - val_tn: 44503.0000 - val_fn: 5.0000 - val_accuracy: 0.9784 - val_precision: 0.0773 - val_recall: 0.9425 - val_auc: 0.9842
Epoch 39/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1676 - tp: 18086.0000 - fp: 736.0000 - tn: 19780.0000 - fn: 2358.0000 - accuracy: 0.9245 - precision: 0.9609 - recall: 0.8847 - auc: 0.9826 - val_loss: 0.0980 - val_tp: 82.0000 - val_fp: 975.0000 - val_tn: 44507.0000 - val_fn: 5.0000 - val_accuracy: 0.9785 - val_precision: 0.0776 - val_recall: 0.9425 - val_auc: 0.9844
Epoch 40/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1670 - tp: 18066.0000 - fp: 685.0000 - tn: 19868.0000 - fn: 2341.0000 - accuracy: 0.9261 - precision: 0.9635 - recall: 0.8853 - auc: 0.9832 - val_loss: 0.0964 - val_tp: 82.0000 - val_fp: 965.0000 - val_tn: 44517.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0783 - val_recall: 0.9425 - val_auc: 0.9845
Epoch 41/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1640 - tp: 17950.0000 - fp: 645.0000 - tn: 19995.0000 - fn: 2370.0000 - accuracy: 0.9264 - precision: 0.9653 - recall: 0.8834 - auc: 0.9839 - val_loss: 0.0950 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9835
Epoch 42/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1641 - tp: 18083.0000 - fp: 665.0000 - tn: 19842.0000 - fn: 2370.0000 - accuracy: 0.9259 - precision: 0.9645 - recall: 0.8841 - auc: 0.9839 - val_loss: 0.0938 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 43/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1600 - tp: 18012.0000 - fp: 684.0000 - tn: 19970.0000 - fn: 2294.0000 - accuracy: 0.9273 - precision: 0.9634 - recall: 0.8870 - auc: 0.9845 - val_loss: 0.0925 - val_tp: 82.0000 - val_fp: 949.0000 - val_tn: 44533.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0795 - val_recall: 0.9425 - val_auc: 0.9837
Epoch 44/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1597 - tp: 18346.0000 - fp: 657.0000 - tn: 19657.0000 - fn: 2300.0000 - accuracy: 0.9278 - precision: 0.9654 - recall: 0.8886 - auc: 0.9847 - val_loss: 0.0919 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9838
Epoch 45/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1607 - tp: 18109.0000 - fp: 726.0000 - tn: 19836.0000 - fn: 2289.0000 - accuracy: 0.9264 - precision: 0.9615 - recall: 0.8878 - auc: 0.9846 - val_loss: 0.0908 - val_tp: 82.0000 - val_fp: 948.0000 - val_tn: 44534.0000 - val_fn: 5.0000 - val_accuracy: 0.9791 - val_precision: 0.0796 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 46/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1581 - tp: 18192.0000 - fp: 650.0000 - tn: 19833.0000 - fn: 2285.0000 - accuracy: 0.9283 - precision: 0.9655 - recall: 0.8884 - auc: 0.9849 - val_loss: 0.0902 - val_tp: 82.0000 - val_fp: 955.0000 - val_tn: 44527.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0791 - val_recall: 0.9425 - val_auc: 0.9839
Epoch 47/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1579 - tp: 18301.0000 - fp: 676.0000 - tn: 19760.0000 - fn: 2223.0000 - accuracy: 0.9292 - precision: 0.9644 - recall: 0.8917 - auc: 0.9853 - val_loss: 0.0892 - val_tp: 82.0000 - val_fp: 956.0000 - val_tn: 44526.0000 - val_fn: 5.0000 - val_accuracy: 0.9789 - val_precision: 0.0790 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 48/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1503 - tp: 18172.0000 - fp: 593.0000 - tn: 19959.0000 - fn: 2236.0000 - accuracy: 0.9309 - precision: 0.9684 - recall: 0.8904 - auc: 0.9867 - val_loss: 0.0887 - val_tp: 82.0000 - val_fp: 970.0000 - val_tn: 44512.0000 - val_fn: 5.0000 - val_accuracy: 0.9786 - val_precision: 0.0779 - val_recall: 0.9425 - val_auc: 0.9840
Epoch 49/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1572 - tp: 18217.0000 - fp: 750.0000 - tn: 19709.0000 - fn: 2284.0000 - accuracy: 0.9259 - precision: 0.9605 - recall: 0.8886 - auc: 0.9852 - val_loss: 0.0876 - val_tp: 82.0000 - val_fp: 964.0000 - val_tn: 44518.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.0784 - val_recall: 0.9425 - val_auc: 0.9841
Epoch 50/1000
20/20 [==============================] - ETA: 0s - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860Restoring model weights from the end of the best epoch.
20/20 [==============================] - 0s 23ms/step - loss: 0.1529 - tp: 18230.0000 - fp: 696.0000 - tn: 19874.0000 - fn: 2160.0000 - accuracy: 0.9303 - precision: 0.9632 - recall: 0.8941 - auc: 0.9860 - val_loss: 0.0860 - val_tp: 82.0000 - val_fp: 941.0000 - val_tn: 44541.0000 - val_fn: 5.0000 - val_accuracy: 0.9792 - val_precision: 0.0802 - val_recall: 0.9425 - val_auc: 0.9843
Epoch 00050: early stopping

Re-check training history

 plot_metrics(resampled_history)
 

[Figure: training and validation metric curves for the re-trained resampled model]

Evaluate metrics

 train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
 
 resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
 
loss :  0.09607589244842529
tp :  84.0
fp :  1195.0
tn :  55676.0
fn :  7.0
accuracy :  0.9788982272148132
precision :  0.06567630916833878
recall :  0.9230769276618958
auc :  0.9697299599647522

Legitimate Transactions Detected (True Negatives):  55676
Legitimate Transactions Incorrectly Detected (False Positives):  1195
Fraudulent Transactions Missed (False Negatives):  7
Fraudulent Transactions Detected (True Positives):  84
Total Fraudulent Transactions:  91

[Figure: confusion matrix for the resampled model on the test set]
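The plot_cm helper used above was defined in an earlier section. A minimal sketch of such a helper, assuming a default decision threshold of 0.5:

 def plot_cm(labels, predictions, p=0.5):
  # Threshold the predicted probabilities at p, then plot the counts.
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt='d')
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')
 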

Plot the ROC
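The plot_roc helper called below was likewise defined earlier. A minimal sketch built on sklearn.metrics.roc_curve:

 def plot_roc(name, labels, predictions, **kwargs):
  # Compute the ROC curve and plot it as percentages.
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)
  plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.grid(True)
 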

 plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')
 
<matplotlib.legend.Legend at 0x7fa4bc66c9b0>

[Figure: ROC curves for the baseline, class-weighted, and resampled models (train solid, test dashed)]

Applying this tutorial to your problem

Imbalanced data classification is an inherently difficult task, since there are so few samples to learn from. You should always start with the data first: do your best to collect as many samples as possible, and give substantial thought to which features may be relevant so the model can get the most out of the minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.