Klasyfikacja na niezrównoważonych danych

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

W tym samouczku pokazano, jak sklasyfikować wysoce niezrównoważony zestaw danych, w którym liczba przykładów w jednej klasie znacznie przewyższa liczbę przykładów w innej. Będziesz pracować z zestawem danych do wykrywania oszustw kart kredytowych hostowanym na Kaggle. Celem jest wykrycie zaledwie 492 nieuczciwych transakcji z łącznej liczby 284.807 transakcji. Użyjesz Keras do zdefiniowania wag modelu i klasy, aby pomóc modelowi uczyć się na podstawie niezrównoważonych danych. .

Ten samouczek zawiera kompletny kod do:

  • Załaduj plik CSV za pomocą Pandy.
  • Twórz zestawy treningowe, walidacyjne i testowe.
  • Zdefiniuj i wytrenuj model za pomocą Keras (w tym ustawienie wag klas).
  • Oceń model za pomocą różnych metryk (w tym precyzji i przypominania).
  • Wypróbuj popularne techniki radzenia sobie z niezrównoważonymi danymi, takie jak:
    • Ważenie klas
    • Nadpróbkowanie

Ustawiać

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Przetwarzanie i eksploracja danych

Pobierz zestaw danych dotyczących oszustw związanych z kartami kredytowymi Kaggle

Pandas to biblioteka Pythona z wieloma pomocnymi narzędziami do ładowania i pracy z danymi strukturalnymi. Może być używany do pobierania plików CSV do Pandas DataFrame .

file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Sprawdź brak równowagi na etykiecie klasy

Spójrzmy na nierównowagę zbioru danych:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

To pokazuje niewielką część próbek pozytywnych.

Wyczyść, podziel i znormalizuj dane

Surowe dane mają kilka problemów. Po pierwsze kolumny Time i Amount są zbyt zmienne, aby można było ich użyć bezpośrednio. Usuń kolumnę Time (ponieważ nie jest jasne, co to znaczy) i weź dziennik kolumny Amount aby zmniejszyć jej zakres.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

Podziel zbiór danych na zestawy do pociągów, walidacji i testów. Zestaw walidacyjny jest używany podczas dopasowywania modelu do oceny strat i wszelkich metryk, jednak model nie jest dopasowany do tych danych. Zestaw testowy jest całkowicie nieużywany w fazie uczenia i jest używany dopiero na końcu do oceny, jak dobrze model uogólnia się na nowe dane. Jest to szczególnie ważne w przypadku niezrównoważonych zestawów danych, w których nadmierne dopasowanie jest poważnym problemem z powodu braku danych treningowych.

# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

Normalizuj funkcje wejściowe za pomocą sklearn StandardScaler. To ustawi średnią na 0, a odchylenie standardowe na 1.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Spójrz na dystrybucję danych

Następnie porównaj rozkłady pozytywnych i negatywnych przykładów w kilku cechach. Dobre pytania do zadania sobie w tym momencie to:

  • Czy te dystrybucje mają sens?
    • Tak. Znormalizowałeś dane wejściowe i są one w większości skoncentrowane w zakresie +/- 2 .
  • Czy widzisz różnicę między dystrybucjami?
    • Tak, pozytywne przykłady zawierają znacznie wyższy wskaźnik wartości ekstremalnych.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
/home/kbuilder/.local/lib/python3.7/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
/home/kbuilder/.local/lib/python3.7/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning

png

png

Zdefiniuj model i metryki

Zdefiniuj funkcję, która tworzy prostą sieć neuronową z gęsto połączoną warstwą ukrytą, warstwą porzucania w celu zmniejszenia nadmiernego dopasowania oraz sigmoidalną warstwą wyjściową, która zwraca prawdopodobieństwo oszustwa transakcji:

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

Zrozumienie przydatnych wskaźników

Zwróć uwagę, że istnieje kilka zdefiniowanych powyżej metryk, które mogą być obliczane przez model, które będą pomocne podczas oceny wydajności.

  • Fałszywie ujemne i fałszywie dodatnie to próbki, które zostały nieprawidłowo sklasyfikowane
  • Prawdziwie negatywne i prawdziwie pozytywne to próbki, które zostały prawidłowo sklasyfikowane
  • Dokładność to procent przykładów poprawnie sklasyfikowanych > $\frac{\text{próbki prawdziwe} }{\text{próbki ogółem} }$
  • Dokładność to procent przewidywanych wyników pozytywnych, które zostały poprawnie sklasyfikowane > $\frac{\text{prawdziwe pozytywne} }{\text{prawdziwie pozytywne + fałszywe alarmy} }$
  • Przypomnij to odsetek rzeczywistych wyników pozytywnych, które zostały poprawnie sklasyfikowane > $\frac{\text{prawdziwe pozytywne} }{\text{prawdziwe pozytywne + fałszywe negatywne} }$
  • AUC odnosi się do obszaru pod krzywą krzywej charakterystyki pracy odbiornika (ROC-AUC). Ta metryka jest równa prawdopodobieństwu, że klasyfikator oceni losową próbkę dodatnią wyżej niż losową próbkę ujemną.
  • AUPRC odnosi się do obszaru pod krzywą krzywej precyzyjnego przywracania . Ta metryka oblicza pary precyzja-odwołanie dla różnych progów prawdopodobieństwa.

Czytaj więcej:

Model podstawowy

Zbuduj model

Teraz utwórz i wytrenuj swój model za pomocą funkcji, która została zdefiniowana wcześniej. Zauważ, że model jest dopasowany przy użyciu większej niż domyślna wielkości partii 2048, jest to ważne, aby upewnić się, że każda partia ma przyzwoitą szansę na zawieranie kilku pozytywnych próbek. Jeśli wielkość partii byłaby zbyt mała, prawdopodobnie nie mieliby żadnych oszukańczych transakcji, z których mogliby się uczyć.

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")

Przetestuj model:

model.predict(train_features[:10])
array([[0.88874197],
       [0.86864525],
       [0.7491154 ],
       [0.8249989 ],
       [0.8674842 ],
       [0.93236715],
       [0.8916911 ],
       [0.925606  ],
       [0.8452092 ],
       [0.9233874 ]], dtype=float32)

Opcjonalnie: ustaw prawidłowe początkowe odchylenie.

Te wstępne domysły nie są świetne. Wiesz, że zbiór danych jest niezrównoważony. Ustaw odchylenie warstwy wyjściowej, aby to odzwierciedlić (patrz: Przepis na uczenie sieci neuronowych: "init well" ). Może to pomóc w początkowej konwergencji.

Przy domyślnej inicjalizacji odchylenia strata powinna wynosić około math.log(2) = 0.69314

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.9253

Prawidłowe odchylenie do ustawienia można wyprowadzić z:

$$ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) $$
$$ b_0 = -log_e(1/p_0 - 1) $$
$$ b_0 = log_e(pos/neg)$$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

Ustaw to jako początkowe odchylenie, a model da znacznie bardziej rozsądne początkowe domysły.

Powinno być blisko: pos/total = 0.0018

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
array([[0.00436974],
       [0.00134233],
       [0.00460862],
       [0.00144072],
       [0.00280632],
       [0.00290238],
       [0.00553038],
       [0.00227821],
       [0.00264283],
       [0.00139598]], dtype=float32)

Przy tej inicjalizacji początkowa strata powinna wynosić w przybliżeniu:

$$-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317$$
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0155

Ta początkowa strata jest około 50 razy mniejsza niż w przypadku naiwnej inicjalizacji.

W ten sposób model nie musi spędzić kilku pierwszych epok na uczeniu się, że pozytywne przykłady są mało prawdopodobne. Ułatwia to również odczytywanie wykresów strat podczas treningu.

Sprawdź początkowe wagi

Aby różne przebiegi szkolenia były bardziej porównywalne, zachowaj wagi tego modelu początkowego w pliku punktu kontrolnego i załaduj je do każdego modelu przed szkoleniem:

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

Potwierdź, że poprawka stronniczości pomaga

Zanim przejdziesz dalej, szybko potwierdź, że ostrożna inicjalizacja stronniczości rzeczywiście pomogła.

Trenuj model przez 20 epok, z tą staranną inicjalizacją i bez niej, i porównaj straty:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

png

Powyższy rysunek wyjaśnia: W przypadku utraty walidacji w tym problemie ta ostrożna inicjalizacja daje wyraźną przewagę.

Trenuj modelkę

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 15ms/step - loss: 0.0119 - tp: 81.0000 - fp: 40.0000 - tn: 227422.0000 - fn: 302.0000 - accuracy: 0.9985 - precision: 0.6694 - recall: 0.2115 - auc: 0.7766 - prc: 0.2450 - val_loss: 0.0052 - val_tp: 16.0000 - val_fp: 3.0000 - val_tn: 45504.0000 - val_fn: 46.0000 - val_accuracy: 0.9989 - val_precision: 0.8421 - val_recall: 0.2581 - val_auc: 0.9017 - val_prc: 0.6837
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0073 - tp: 119.0000 - fp: 23.0000 - tn: 181932.0000 - fn: 202.0000 - accuracy: 0.9988 - precision: 0.8380 - recall: 0.3707 - auc: 0.8848 - prc: 0.5190 - val_loss: 0.0041 - val_tp: 36.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8780 - val_recall: 0.5806 - val_auc: 0.9109 - val_prc: 0.7107
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0064 - tp: 155.0000 - fp: 36.0000 - tn: 181919.0000 - fn: 166.0000 - accuracy: 0.9989 - precision: 0.8115 - recall: 0.4829 - auc: 0.9191 - prc: 0.5920 - val_loss: 0.0037 - val_tp: 37.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8810 - val_recall: 0.5968 - val_auc: 0.9111 - val_prc: 0.7402
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0056 - tp: 171.0000 - fp: 34.0000 - tn: 181921.0000 - fn: 150.0000 - accuracy: 0.9990 - precision: 0.8341 - recall: 0.5327 - auc: 0.9166 - prc: 0.6295 - val_loss: 0.0035 - val_tp: 40.0000 - val_fp: 4.0000 - val_tn: 45503.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9091 - val_recall: 0.6452 - val_auc: 0.9191 - val_prc: 0.7474
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0054 - tp: 171.0000 - fp: 37.0000 - tn: 181918.0000 - fn: 150.0000 - accuracy: 0.9990 - precision: 0.8221 - recall: 0.5327 - auc: 0.9235 - prc: 0.6673 - val_loss: 0.0033 - val_tp: 41.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8913 - val_recall: 0.6613 - val_auc: 0.9191 - val_prc: 0.7619
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 184.0000 - fp: 38.0000 - tn: 181917.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8288 - recall: 0.5732 - auc: 0.9254 - prc: 0.6839 - val_loss: 0.0031 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9191 - val_prc: 0.7643
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 177.0000 - fp: 35.0000 - tn: 181920.0000 - fn: 144.0000 - accuracy: 0.9990 - precision: 0.8349 - recall: 0.5514 - auc: 0.9365 - prc: 0.6994 - val_loss: 0.0030 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7686
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 184.0000 - fp: 43.0000 - tn: 181912.0000 - fn: 137.0000 - accuracy: 0.9990 - precision: 0.8106 - recall: 0.5732 - auc: 0.9224 - prc: 0.6641 - val_loss: 0.0030 - val_tp: 40.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8889 - val_recall: 0.6452 - val_auc: 0.9190 - val_prc: 0.7767
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 177.0000 - fp: 34.0000 - tn: 181921.0000 - fn: 144.0000 - accuracy: 0.9990 - precision: 0.8389 - recall: 0.5514 - auc: 0.9271 - prc: 0.6757 - val_loss: 0.0029 - val_tp: 44.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.8980 - val_recall: 0.7097 - val_auc: 0.9190 - val_prc: 0.7805
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 186.0000 - fp: 36.0000 - tn: 181919.0000 - fn: 135.0000 - accuracy: 0.9991 - precision: 0.8378 - recall: 0.5794 - auc: 0.9428 - prc: 0.7279 - val_loss: 0.0028 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7817
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0042 - tp: 198.0000 - fp: 37.0000 - tn: 181918.0000 - fn: 123.0000 - accuracy: 0.9991 - precision: 0.8426 - recall: 0.6168 - auc: 0.9321 - prc: 0.7221 - val_loss: 0.0028 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7824
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 200.0000 - fp: 37.0000 - tn: 181918.0000 - fn: 121.0000 - accuracy: 0.9991 - precision: 0.8439 - recall: 0.6231 - auc: 0.9368 - prc: 0.7310 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7851
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 206.0000 - fp: 33.0000 - tn: 181922.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8619 - recall: 0.6417 - auc: 0.9353 - prc: 0.7336 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7872
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 197.0000 - fp: 35.0000 - tn: 181920.0000 - fn: 124.0000 - accuracy: 0.9991 - precision: 0.8491 - recall: 0.6137 - auc: 0.9322 - prc: 0.7303 - val_loss: 0.0027 - val_tp: 46.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9020 - val_recall: 0.7419 - val_auc: 0.9190 - val_prc: 0.7869
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 194.0000 - fp: 35.0000 - tn: 181920.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8472 - recall: 0.6044 - auc: 0.9259 - prc: 0.7179 - val_loss: 0.0027 - val_tp: 48.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9057 - val_recall: 0.7742 - val_auc: 0.9190 - val_prc: 0.7873
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0040 - tp: 190.0000 - fp: 36.0000 - tn: 181919.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8407 - recall: 0.5919 - auc: 0.9384 - prc: 0.7206 - val_loss: 0.0027 - val_tp: 49.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.9074 - val_recall: 0.7903 - val_auc: 0.9190 - val_prc: 0.7881
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0041 - tp: 194.0000 - fp: 31.0000 - tn: 181924.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8622 - recall: 0.6044 - auc: 0.9322 - prc: 0.7257 - val_loss: 0.0027 - val_tp: 49.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.9074 - val_recall: 0.7903 - val_auc: 0.9190 - val_prc: 0.7894
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 206.0000 - fp: 37.0000 - tn: 181918.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8477 - recall: 0.6417 - auc: 0.9431 - prc: 0.7428 - val_loss: 0.0027 - val_tp: 49.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.8909 - val_recall: 0.7903 - val_auc: 0.9190 - val_prc: 0.7866
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0039 - tp: 198.0000 - fp: 39.0000 - tn: 181916.0000 - fn: 123.0000 - accuracy: 0.9991 - precision: 0.8354 - recall: 0.6168 - auc: 0.9463 - prc: 0.7405 - val_loss: 0.0027 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.8889 - val_recall: 0.7742 - val_auc: 0.9190 - val_prc: 0.7864
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 206.0000 - fp: 38.0000 - tn: 181917.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8443 - recall: 0.6417 - auc: 0.9431 - prc: 0.7465 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7876
Epoch 21/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 197.0000 - fp: 42.0000 - tn: 181913.0000 - fn: 124.0000 - accuracy: 0.9991 - precision: 0.8243 - recall: 0.6137 - auc: 0.9431 - prc: 0.7496 - val_loss: 0.0027 - val_tp: 48.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9057 - val_recall: 0.7742 - val_auc: 0.9191 - val_prc: 0.7875
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0038 - tp: 191.0000 - fp: 39.0000 - tn: 181916.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8304 - recall: 0.5950 - auc: 0.9416 - prc: 0.7384 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 4.0000 - val_tn: 45503.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9184 - val_recall: 0.7258 - val_auc: 0.9191 - val_prc: 0.7916
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0035 - tp: 209.0000 - fp: 30.0000 - tn: 181925.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8745 - recall: 0.6511 - auc: 0.9401 - prc: 0.7721 - val_loss: 0.0027 - val_tp: 50.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 12.0000 - val_accuracy: 0.9996 - val_precision: 0.8929 - val_recall: 0.8065 - val_auc: 0.9191 - val_prc: 0.7875
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 215.0000 - fp: 33.0000 - tn: 181922.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8669 - recall: 0.6698 - auc: 0.9416 - prc: 0.7693 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7896
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 209.0000 - fp: 31.0000 - tn: 181924.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8708 - recall: 0.6511 - auc: 0.9401 - prc: 0.7703 - val_loss: 0.0027 - val_tp: 46.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9020 - val_recall: 0.7419 - val_auc: 0.9191 - val_prc: 0.7908
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0034 - tp: 208.0000 - fp: 31.0000 - tn: 181924.0000 - fn: 113.0000 - accuracy: 0.9992 - precision: 0.8703 - recall: 0.6480 - auc: 0.9401 - prc: 0.7746 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7916
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0034 - tp: 219.0000 - fp: 27.0000 - tn: 181928.0000 - fn: 102.0000 - accuracy: 0.9993 - precision: 0.8902 - recall: 0.6822 - auc: 0.9417 - prc: 0.7733 - val_loss: 0.0027 - val_tp: 49.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.8909 - val_recall: 0.7903 - val_auc: 0.9191 - val_prc: 0.7902
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0035 - tp: 214.0000 - fp: 29.0000 - tn: 181926.0000 - fn: 107.0000 - accuracy: 0.9993 - precision: 0.8807 - recall: 0.6667 - auc: 0.9495 - prc: 0.7681 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9191 - val_prc: 0.7915
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0037 - tp: 215.0000 - fp: 35.0000 - tn: 181920.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8600 - recall: 0.6698 - auc: 0.9417 - prc: 0.7392 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9191 - val_prc: 0.7886
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 201.0000 - fp: 28.0000 - tn: 181927.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8777 - recall: 0.6262 - auc: 0.9400 - prc: 0.7576 - val_loss: 0.0027 - val_tp: 49.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.8909 - val_recall: 0.7903 - val_auc: 0.9190 - val_prc: 0.7870
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0035 - tp: 209.0000 - fp: 29.0000 - tn: 181926.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8782 - recall: 0.6511 - auc: 0.9432 - prc: 0.7774 - val_loss: 0.0027 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.8889 - val_recall: 0.7742 - val_auc: 0.9191 - val_prc: 0.7903
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0036 - tp: 210.0000 - fp: 38.0000 - tn: 181917.0000 - fn: 111.0000 - accuracy: 0.9992 - precision: 0.8468 - recall: 0.6542 - auc: 0.9479 - prc: 0.7673 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7896
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0034 - tp: 217.0000 - fp: 33.0000 - tn: 181922.0000 - fn: 104.0000 - accuracy: 0.9992 - precision: 0.8680 - recall: 0.6760 - auc: 0.9479 - prc: 0.7892 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7883
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0035 - tp: 219.0000 - fp: 24.0000 - tn: 181931.0000 - fn: 102.0000 - accuracy: 0.9993 - precision: 0.9012 - recall: 0.6822 - auc: 0.9385 - prc: 0.7714 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9191 - val_prc: 0.7878
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0033 - tp: 211.0000 - fp: 31.0000 - tn: 181924.0000 - fn: 110.0000 - accuracy: 0.9992 - precision: 0.8719 - recall: 0.6573 - auc: 0.9479 - prc: 0.7827 - val_loss: 0.0027 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.8868 - val_recall: 0.7581 - val_auc: 0.9191 - val_prc: 0.7867
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0033 - tp: 216.0000 - fp: 34.0000 - tn: 181921.0000 - fn: 105.0000 - accuracy: 0.9992 - precision: 0.8640 - recall: 0.6729 - auc: 0.9464 - prc: 0.7880 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 5.0000 - val_tn: 45502.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9000 - val_recall: 0.7258 - val_auc: 0.9190 - val_prc: 0.7862
Restoring model weights from the end of the best epoch.
Epoch 00036: early stopping

Sprawdź historię szkoleń

W tej sekcji utworzysz wykresy dokładności i straty modelu w zbiorze uczącym i walidacyjnym. Są one przydatne do sprawdzenia, czy nie ma zbyt dużego dopasowania, o którym możesz dowiedzieć się więcej w samouczku Overfit i underfit .

Dodatkowo możesz tworzyć te wykresy dla dowolnych metryk utworzonych powyżej. Jako przykład podano fałszywe negatywy.

def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

png

Oceń dane

Możesz użyć macierzy pomyłek, aby podsumować etykiety rzeczywiste i przewidywane, gdzie oś X to etykieta przewidywana, a oś Y to etykieta rzeczywista:

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Oceń swój model na testowym zbiorze danych i wyświetl wyniki dla metryk utworzonych powyżej:

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.003672640072181821
tp :  80.0
fp :  8.0
tn :  56845.0
fn :  29.0
accuracy :  0.9993504285812378
precision :  0.9090909361839294
recall :  0.7339449524879456
auc :  0.9263221025466919
prc :  0.8151081800460815

Legitimate Transactions Detected (True Negatives):  56845
Legitimate Transactions Incorrectly Detected (False Positives):  8
Fraudulent Transactions Missed (False Negatives):  29
Fraudulent Transactions Detected (True Positives):  80
Total Fraudulent Transactions:  109

png

Gdyby model przewidział wszystko idealnie, byłaby to macierz diagonalna, w której wartości poza główną przekątną, wskazujące na nieprawidłowe przewidywania, byłyby równe zeru. W tym przypadku macierz pokazuje, że masz stosunkowo mało fałszywych alarmów, co oznacza, że ​​było stosunkowo niewiele legalnych transakcji, które zostały nieprawidłowo oznaczone. Jednak prawdopodobnie chciałbyś mieć jeszcze mniej wyników fałszywie negatywnych, pomimo kosztów zwiększenia liczby fałszywych trafień. Ten kompromis może być lepszy, ponieważ fałszywe negatywy umożliwiłyby przeprowadzenie nieuczciwych transakcji, podczas gdy fałszywe alarmy mogą spowodować wysłanie wiadomości e-mail do klienta z prośbą o zweryfikowanie aktywności karty.

Wykreśl ROC

Teraz wykreśl ROC . Ten wykres jest przydatny, ponieważ pokazuje na pierwszy rzut oka zakres wydajności, jaki model może osiągnąć, po prostu dostrajając próg wyjściowy.

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb90c3ccdd0>

png

Wykreśl AUPRC

Teraz wykreśl AUPRC . Obszar pod interpolowaną krzywą precyzja-odwołanie, uzyskany przez wykreślenie (odwołanie, precyzja) punktów dla różnych wartości progu klasyfikacji. W zależności od sposobu obliczenia, PR AUC może odpowiadać średniej precyzji modelu.

def plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb90c2a7e90>

png

Wygląda na to, że precyzja jest stosunkowo wysoka, ale przywołanie i obszar pod krzywą ROC (AUC) nie są tak wysokie, jak byś chciał. Klasyfikatory często napotykają wyzwania, próbując zmaksymalizować zarówno precyzję, jak i przywołanie, co jest szczególnie ważne podczas pracy z niezrównoważonymi zestawami danych. Ważne jest, aby rozważyć koszty różnego rodzaju błędów w kontekście problemu, na którym Ci zależy. W tym przykładzie wynik fałszywie negatywny (przeoczenie transakcji oszukańczej) może wiązać się z kosztami finansowymi, natomiast wynik fałszywie pozytywny (transakcja jest nieprawidłowo oznaczona jako fałszywa) może zmniejszyć zadowolenie użytkownika.

Wagi klas

Oblicz wagi klas

Celem jest identyfikacja fałszywych transakcji, ale nie masz zbyt wielu pozytywnych próbek, z którymi możesz pracować, więc chciałbyś, aby klasyfikator mocno ważył kilka dostępnych przykładów. Możesz to zrobić, przekazując wagi Keras dla każdej klasy za pomocą parametru. Spowoduje to, że model „zwróci większą uwagę” na przykłady z niedostatecznie reprezentowanej klasy.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44

Trenuj model z wagami klasowymi

Teraz spróbuj ponownie wytrenować i ocenić model z wagami klas, aby zobaczyć, jak wpływa to na prognozy.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
90/90 [==============================] - 3s 15ms/step - loss: 2.1749 - tp: 134.0000 - fp: 159.0000 - tn: 238649.0000 - fn: 296.0000 - accuracy: 0.9981 - precision: 0.4573 - recall: 0.3116 - auc: 0.7937 - prc: 0.2873 - val_loss: 0.0077 - val_tp: 41.0000 - val_fp: 6.0000 - val_tn: 45501.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8723 - val_recall: 0.6613 - val_auc: 0.9251 - val_prc: 0.6597
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.7878 - tp: 196.0000 - fp: 507.0000 - tn: 181448.0000 - fn: 125.0000 - accuracy: 0.9965 - precision: 0.2788 - recall: 0.6106 - auc: 0.8995 - prc: 0.4334 - val_loss: 0.0111 - val_tp: 49.0000 - val_fp: 15.0000 - val_tn: 45492.0000 - val_fn: 13.0000 - val_accuracy: 0.9994 - val_precision: 0.7656 - val_recall: 0.7903 - val_auc: 0.9516 - val_prc: 0.7065
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4861 - tp: 242.0000 - fp: 908.0000 - tn: 181047.0000 - fn: 79.0000 - accuracy: 0.9946 - precision: 0.2104 - recall: 0.7539 - auc: 0.9341 - prc: 0.5350 - val_loss: 0.0150 - val_tp: 50.0000 - val_fp: 34.0000 - val_tn: 45473.0000 - val_fn: 12.0000 - val_accuracy: 0.9990 - val_precision: 0.5952 - val_recall: 0.8065 - val_auc: 0.9577 - val_prc: 0.7125
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3944 - tp: 262.0000 - fp: 1604.0000 - tn: 180351.0000 - fn: 59.0000 - accuracy: 0.9909 - precision: 0.1404 - recall: 0.8162 - auc: 0.9392 - prc: 0.5218 - val_loss: 0.0214 - val_tp: 50.0000 - val_fp: 76.0000 - val_tn: 45431.0000 - val_fn: 12.0000 - val_accuracy: 0.9981 - val_precision: 0.3968 - val_recall: 0.8065 - val_auc: 0.9679 - val_prc: 0.7281
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3370 - tp: 262.0000 - fp: 2487.0000 - tn: 179468.0000 - fn: 59.0000 - accuracy: 0.9860 - precision: 0.0953 - recall: 0.8162 - auc: 0.9527 - prc: 0.4744 - val_loss: 0.0300 - val_tp: 51.0000 - val_fp: 191.0000 - val_tn: 45316.0000 - val_fn: 11.0000 - val_accuracy: 0.9956 - val_precision: 0.2107 - val_recall: 0.8226 - val_auc: 0.9731 - val_prc: 0.7339
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3008 - tp: 271.0000 - fp: 3478.0000 - tn: 178477.0000 - fn: 50.0000 - accuracy: 0.9806 - precision: 0.0723 - recall: 0.8442 - auc: 0.9549 - prc: 0.4444 - val_loss: 0.0398 - val_tp: 52.0000 - val_fp: 328.0000 - val_tn: 45179.0000 - val_fn: 10.0000 - val_accuracy: 0.9926 - val_precision: 0.1368 - val_recall: 0.8387 - val_auc: 0.9767 - val_prc: 0.6973
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2619 - tp: 280.0000 - fp: 4550.0000 - tn: 177405.0000 - fn: 41.0000 - accuracy: 0.9748 - precision: 0.0580 - recall: 0.8723 - auc: 0.9591 - prc: 0.3881 - val_loss: 0.0491 - val_tp: 53.0000 - val_fp: 449.0000 - val_tn: 45058.0000 - val_fn: 9.0000 - val_accuracy: 0.9899 - val_precision: 0.1056 - val_recall: 0.8548 - val_auc: 0.9793 - val_prc: 0.6283
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2498 - tp: 280.0000 - fp: 5328.0000 - tn: 176627.0000 - fn: 41.0000 - accuracy: 0.9705 - precision: 0.0499 - recall: 0.8723 - auc: 0.9644 - prc: 0.3410 - val_loss: 0.0580 - val_tp: 53.0000 - val_fp: 539.0000 - val_tn: 44968.0000 - val_fn: 9.0000 - val_accuracy: 0.9880 - val_precision: 0.0895 - val_recall: 0.8548 - val_auc: 0.9797 - val_prc: 0.6020
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2438 - tp: 282.0000 - fp: 5982.0000 - tn: 175973.0000 - fn: 39.0000 - accuracy: 0.9670 - precision: 0.0450 - recall: 0.8785 - auc: 0.9650 - prc: 0.3212 - val_loss: 0.0652 - val_tp: 52.0000 - val_fp: 604.0000 - val_tn: 44903.0000 - val_fn: 10.0000 - val_accuracy: 0.9865 - val_precision: 0.0793 - val_recall: 0.8387 - val_auc: 0.9810 - val_prc: 0.5585
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2616 - tp: 285.0000 - fp: 6410.0000 - tn: 175545.0000 - fn: 36.0000 - accuracy: 0.9646 - precision: 0.0426 - recall: 0.8879 - auc: 0.9569 - prc: 0.2913 - val_loss: 0.0683 - val_tp: 53.0000 - val_fp: 640.0000 - val_tn: 44867.0000 - val_fn: 9.0000 - val_accuracy: 0.9858 - val_precision: 0.0765 - val_recall: 0.8548 - val_auc: 0.9819 - val_prc: 0.5515
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.1943 - tp: 288.0000 - fp: 6518.0000 - tn: 175437.0000 - fn: 33.0000 - accuracy: 0.9641 - precision: 0.0423 - recall: 0.8972 - auc: 0.9766 - prc: 0.2935 - val_loss: 0.0728 - val_tp: 53.0000 - val_fp: 698.0000 - val_tn: 44809.0000 - val_fn: 9.0000 - val_accuracy: 0.9845 - val_precision: 0.0706 - val_recall: 0.8548 - val_auc: 0.9827 - val_prc: 0.5057
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2090 - tp: 290.0000 - fp: 6828.0000 - tn: 175127.0000 - fn: 31.0000 - accuracy: 0.9624 - precision: 0.0407 - recall: 0.9034 - auc: 0.9717 - prc: 0.2726 - val_loss: 0.0739 - val_tp: 53.0000 - val_fp: 709.0000 - val_tn: 44798.0000 - val_fn: 9.0000 - val_accuracy: 0.9842 - val_precision: 0.0696 - val_recall: 0.8548 - val_auc: 0.9840 - val_prc: 0.4996
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2204 - tp: 289.0000 - fp: 6583.0000 - tn: 175372.0000 - fn: 32.0000 - accuracy: 0.9637 - precision: 0.0421 - recall: 0.9003 - auc: 0.9690 - prc: 0.2758 - val_loss: 0.0738 - val_tp: 53.0000 - val_fp: 702.0000 - val_tn: 44805.0000 - val_fn: 9.0000 - val_accuracy: 0.9844 - val_precision: 0.0702 - val_recall: 0.8548 - val_auc: 0.9843 - val_prc: 0.4996
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2008 - tp: 292.0000 - fp: 6452.0000 - tn: 175503.0000 - fn: 29.0000 - accuracy: 0.9644 - precision: 0.0433 - recall: 0.9097 - auc: 0.9733 - prc: 0.2783 - val_loss: 0.0748 - val_tp: 53.0000 - val_fp: 702.0000 - val_tn: 44805.0000 - val_fn: 9.0000 - val_accuracy: 0.9844 - val_precision: 0.0702 - val_recall: 0.8548 - val_auc: 0.9851 - val_prc: 0.4883
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.1883 - tp: 290.0000 - fp: 7049.0000 - tn: 174906.0000 - fn: 31.0000 - accuracy: 0.9612 - precision: 0.0395 - recall: 0.9034 - auc: 0.9775 - prc: 0.2634 - val_loss: 0.0800 - val_tp: 53.0000 - val_fp: 769.0000 - val_tn: 44738.0000 - val_fn: 9.0000 - val_accuracy: 0.9829 - val_precision: 0.0645 - val_recall: 0.8548 - val_auc: 0.9854 - val_prc: 0.4571
Restoring model weights from the end of the best epoch.
Epoch 00015: early stopping

Sprawdź historię szkoleń

plot_metrics(weighted_history)

png

Oceń dane

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.03160524740815163
tp :  94.0
fp :  231.0
tn :  56622.0
fn :  15.0
accuracy :  0.9956813454627991
precision :  0.2892307639122009
recall :  0.8623853325843811
auc :  0.9544669985771179
prc :  0.7459506392478943

Legitimate Transactions Detected (True Negatives):  56622
Legitimate Transactions Incorrectly Detected (False Positives):  231
Fraudulent Transactions Missed (False Negatives):  15
Fraudulent Transactions Detected (True Positives):  94
Total Fraudulent Transactions:  109

png

Tutaj widać, że przy wagach klas dokładność i precyzja są niższe, ponieważ jest więcej fałszywych trafień, ale odwrotnie, przypominanie i AUC są wyższe, ponieważ model znalazł również więcej prawdziwych trafień. Pomimo niższej dokładności model ten ma większą przypomnienie (i identyfikuje więcej nieuczciwych transakcji). Oczywiście oba rodzaje błędów wiążą się z pewnym kosztem (nie chciałbyś też nękać użytkowników, oznaczając zbyt wiele legalnych transakcji jako fałszywych). Uważnie rozważ kompromisy między tymi różnymi typami błędów w swojej aplikacji.

Wykreśl ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb5044af7d0>

png

Wykreśl AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb90c19d6d0>

png

Nadpróbkowanie

Oversample z klasy mniejszości

Pokrewnym podejściem byłoby ponowne próbkowanie zbioru danych przez nadpróbkowanie klasy mniejszości.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Korzystanie z NumPy

Możesz ręcznie zrównoważyć zbiór danych, wybierając odpowiednią liczbę losowych wskaźników z pozytywnych przykładów:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181955, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363910, 29)

Korzystanie z tf.data

Jeśli używasz tf.data najłatwiejszym sposobem tworzenia zrównoważonych przykładów jest rozpoczęcie od positive i negative zestawu danych, a następnie scalenie ich. Więcej przykładów znajdziesz w przewodniku tf.data .

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Każdy zestaw danych zawiera pary (feature, label) :

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [-1.33936943  3.11520371 -4.01205511  3.90726777  1.17349309 -1.75281185
 -2.63199526 -1.24346046 -4.61025869 -5.          5.         -5.
 -3.14523198 -5.         -0.87500779 -1.94202387 -1.96916724  0.32933292
 -2.48801402  1.24071362 -0.72319929  0.3052153  -2.19534696 -0.83443772
  0.73356809  0.81995722  1.96819284  1.90706374 -1.45216499]

Label:  1

Połącz je razem, używając experimental.sample_from_datasets :

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.50830078125

Aby użyć tego zbioru danych, potrzebujesz liczby kroków na epokę.

Definicja „epoki” w tym przypadku jest mniej jasna. Załóżmy, że jest to liczba partii wymagana do jednorazowego wyświetlenia każdego negatywnego przykładu:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

Trenuj na nadpróbkowanych danych

Teraz spróbuj wytrenować model z ponownie próbkowanym zestawem danych zamiast używać wag klas, aby zobaczyć porównanie tych metod.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
278/278 [==============================] - 10s 29ms/step - loss: 0.4602 - tp: 246611.0000 - fp: 96452.0000 - tn: 245127.0000 - fn: 38116.0000 - accuracy: 0.7851 - precision: 0.7189 - recall: 0.8661 - auc: 0.9039 - prc: 0.9207 - val_loss: 0.2878 - val_tp: 54.0000 - val_fp: 1400.0000 - val_tn: 44107.0000 - val_fn: 8.0000 - val_accuracy: 0.9691 - val_precision: 0.0371 - val_recall: 0.8710 - val_auc: 0.9552 - val_prc: 0.7466
Epoch 2/100
278/278 [==============================] - 7s 27ms/step - loss: 0.2089 - tp: 260504.0000 - fp: 18142.0000 - tn: 266849.0000 - fn: 23849.0000 - accuracy: 0.9262 - precision: 0.9349 - recall: 0.9161 - auc: 0.9737 - prc: 0.9795 - val_loss: 0.1302 - val_tp: 53.0000 - val_fp: 613.0000 - val_tn: 44894.0000 - val_fn: 9.0000 - val_accuracy: 0.9864 - val_precision: 0.0796 - val_recall: 0.8548 - val_auc: 0.9621 - val_prc: 0.7513
Epoch 3/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1511 - tp: 264877.0000 - fp: 9863.0000 - tn: 274507.0000 - fn: 20097.0000 - accuracy: 0.9474 - precision: 0.9641 - recall: 0.9295 - auc: 0.9861 - prc: 0.9882 - val_loss: 0.0848 - val_tp: 53.0000 - val_fp: 544.0000 - val_tn: 44963.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.0888 - val_recall: 0.8548 - val_auc: 0.9636 - val_prc: 0.7553
Epoch 4/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1289 - tp: 266348.0000 - fp: 8114.0000 - tn: 276909.0000 - fn: 17973.0000 - accuracy: 0.9542 - precision: 0.9704 - recall: 0.9368 - auc: 0.9899 - prc: 0.9911 - val_loss: 0.0688 - val_tp: 53.0000 - val_fp: 528.0000 - val_tn: 44979.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.0912 - val_recall: 0.8548 - val_auc: 0.9640 - val_prc: 0.7179
Epoch 5/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1156 - tp: 268867.0000 - fp: 7660.0000 - tn: 276576.0000 - fn: 16241.0000 - accuracy: 0.9580 - precision: 0.9723 - recall: 0.9430 - auc: 0.9922 - prc: 0.9928 - val_loss: 0.0598 - val_tp: 53.0000 - val_fp: 543.0000 - val_tn: 44964.0000 - val_fn: 9.0000 - val_accuracy: 0.9879 - val_precision: 0.0889 - val_recall: 0.8548 - val_auc: 0.9619 - val_prc: 0.6723
Epoch 6/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1060 - tp: 269835.0000 - fp: 6989.0000 - tn: 277623.0000 - fn: 14897.0000 - accuracy: 0.9616 - precision: 0.9748 - recall: 0.9477 - auc: 0.9937 - prc: 0.9939 - val_loss: 0.0537 - val_tp: 53.0000 - val_fp: 528.0000 - val_tn: 44979.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.0912 - val_recall: 0.8548 - val_auc: 0.9633 - val_prc: 0.6831
Epoch 7/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0987 - tp: 270645.0000 - fp: 6666.0000 - tn: 277841.0000 - fn: 14192.0000 - accuracy: 0.9634 - precision: 0.9760 - recall: 0.9502 - auc: 0.9947 - prc: 0.9947 - val_loss: 0.0496 - val_tp: 53.0000 - val_fp: 527.0000 - val_tn: 44980.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.0914 - val_recall: 0.8548 - val_auc: 0.9598 - val_prc: 0.6619
Epoch 8/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0933 - tp: 271300.0000 - fp: 6455.0000 - tn: 278026.0000 - fn: 13563.0000 - accuracy: 0.9648 - precision: 0.9768 - recall: 0.9524 - auc: 0.9953 - prc: 0.9952 - val_loss: 0.0457 - val_tp: 53.0000 - val_fp: 516.0000 - val_tn: 44991.0000 - val_fn: 9.0000 - val_accuracy: 0.9885 - val_precision: 0.0931 - val_recall: 0.8548 - val_auc: 0.9609 - val_prc: 0.6647
Epoch 9/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0891 - tp: 271980.0000 - fp: 6358.0000 - tn: 278122.0000 - fn: 12884.0000 - accuracy: 0.9662 - precision: 0.9772 - recall: 0.9548 - auc: 0.9957 - prc: 0.9955 - val_loss: 0.0429 - val_tp: 53.0000 - val_fp: 482.0000 - val_tn: 45025.0000 - val_fn: 9.0000 - val_accuracy: 0.9892 - val_precision: 0.0991 - val_recall: 0.8548 - val_auc: 0.9617 - val_prc: 0.6651
Epoch 10/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0853 - tp: 271960.0000 - fp: 6166.0000 - tn: 278831.0000 - fn: 12387.0000 - accuracy: 0.9674 - precision: 0.9778 - recall: 0.9564 - auc: 0.9960 - prc: 0.9958 - val_loss: 0.0392 - val_tp: 53.0000 - val_fp: 435.0000 - val_tn: 45072.0000 - val_fn: 9.0000 - val_accuracy: 0.9903 - val_precision: 0.1086 - val_recall: 0.8548 - val_auc: 0.9627 - val_prc: 0.6659
Epoch 11/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0817 - tp: 272529.0000 - fp: 5943.0000 - tn: 279154.0000 - fn: 11718.0000 - accuracy: 0.9690 - precision: 0.9787 - recall: 0.9588 - auc: 0.9964 - prc: 0.9961 - val_loss: 0.0374 - val_tp: 53.0000 - val_fp: 421.0000 - val_tn: 45086.0000 - val_fn: 9.0000 - val_accuracy: 0.9906 - val_precision: 0.1118 - val_recall: 0.8548 - val_auc: 0.9629 - val_prc: 0.6464
Epoch 12/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0796 - tp: 273405.0000 - fp: 5994.0000 - tn: 278444.0000 - fn: 11501.0000 - accuracy: 0.9693 - precision: 0.9785 - recall: 0.9596 - auc: 0.9965 - prc: 0.9962 - val_loss: 0.0361 - val_tp: 54.0000 - val_fp: 407.0000 - val_tn: 45100.0000 - val_fn: 8.0000 - val_accuracy: 0.9909 - val_precision: 0.1171 - val_recall: 0.8710 - val_auc: 0.9629 - val_prc: 0.6475
Epoch 13/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0772 - tp: 272567.0000 - fp: 5767.0000 - tn: 279618.0000 - fn: 11392.0000 - accuracy: 0.9699 - precision: 0.9793 - recall: 0.9599 - auc: 0.9968 - prc: 0.9964 - val_loss: 0.0344 - val_tp: 54.0000 - val_fp: 392.0000 - val_tn: 45115.0000 - val_fn: 8.0000 - val_accuracy: 0.9912 - val_precision: 0.1211 - val_recall: 0.8710 - val_auc: 0.9563 - val_prc: 0.6478
Restoring model weights from the end of the best epoch.
Epoch 00013: early stopping

Gdyby proces uczenia uwzględniał cały zestaw danych przy każdej aktualizacji gradientu, to nadpróbkowanie byłoby zasadniczo identyczne z wagą klasy.

Ale podczas uczenia modelu wsadowo, tak jak tutaj, nadpróbkowane dane zapewniają gładszy sygnał gradientu: zamiast każdego pozytywnego przykładu wyświetlanego w jednej partii z dużą wagą, są one wyświetlane w wielu różnych partiach za każdym razem z mała waga.

Ten gładszy sygnał gradientu ułatwia trenowanie modelu.

Sprawdź historię szkoleń

Zauważ, że rozkłady metryk będą się tutaj różnić, ponieważ dane uczące mają zupełnie inny rozkład niż dane walidacyjne i testowe.

plot_metrics(resampled_history)

png

Ponowne szkolenie

Ponieważ trening jest łatwiejszy na zrównoważonych danych, powyższa procedura treningowa może szybko przesadzić.

Więc tf.keras.callbacks.EarlyStopping epoki, aby dać tf.keras.callbacks.EarlyStopping lepszą kontrolę nad tym, kiedy przestać trenować.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 3s 69ms/step - loss: 1.0991 - tp: 11894.0000 - fp: 13394.0000 - tn: 52609.0000 - fn: 8632.0000 - accuracy: 0.7454 - precision: 0.4703 - recall: 0.5795 - auc: 0.8205 - prc: 0.6241 - val_loss: 0.8598 - val_tp: 56.0000 - val_fp: 30246.0000 - val_tn: 15261.0000 - val_fn: 6.0000 - val_accuracy: 0.3361 - val_precision: 0.0018 - val_recall: 0.9032 - val_auc: 0.8506 - val_prc: 0.1076
Epoch 2/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.7095 - tp: 16028.0000 - fp: 12452.0000 - tn: 8068.0000 - fn: 4412.0000 - accuracy: 0.5883 - precision: 0.5628 - recall: 0.7841 - auc: 0.7405 - prc: 0.8152 - val_loss: 0.8063 - val_tp: 59.0000 - val_fp: 27532.0000 - val_tn: 17975.0000 - val_fn: 3.0000 - val_accuracy: 0.3958 - val_precision: 0.0021 - val_recall: 0.9516 - val_auc: 0.9097 - val_prc: 0.4958
Epoch 3/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.5807 - tp: 17682.0000 - fp: 11457.0000 - tn: 8997.0000 - fn: 2824.0000 - accuracy: 0.6513 - precision: 0.6068 - recall: 0.8623 - auc: 0.8341 - prc: 0.8831 - val_loss: 0.7329 - val_tp: 58.0000 - val_fp: 23131.0000 - val_tn: 22376.0000 - val_fn: 4.0000 - val_accuracy: 0.4923 - val_precision: 0.0025 - val_recall: 0.9355 - val_auc: 0.9333 - val_prc: 0.6207
Epoch 4/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.5082 - tp: 18269.0000 - fp: 9998.0000 - tn: 10312.0000 - fn: 2381.0000 - accuracy: 0.6978 - precision: 0.6463 - recall: 0.8847 - auc: 0.8729 - prc: 0.9133 - val_loss: 0.6624 - val_tp: 58.0000 - val_fp: 18369.0000 - val_tn: 27138.0000 - val_fn: 4.0000 - val_accuracy: 0.5968 - val_precision: 0.0031 - val_recall: 0.9355 - val_auc: 0.9426 - val_prc: 0.6590
Epoch 5/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.4708 - tp: 18236.0000 - fp: 8857.0000 - tn: 11578.0000 - fn: 2289.0000 - accuracy: 0.7279 - precision: 0.6731 - recall: 0.8885 - auc: 0.8899 - prc: 0.9244 - val_loss: 0.6013 - val_tp: 59.0000 - val_fp: 13858.0000 - val_tn: 31649.0000 - val_fn: 3.0000 - val_accuracy: 0.6958 - val_precision: 0.0042 - val_recall: 0.9516 - val_auc: 0.9492 - val_prc: 0.6761
Epoch 6/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.4312 - tp: 18413.0000 - fp: 7647.0000 - tn: 12816.0000 - fn: 2084.0000 - accuracy: 0.7624 - precision: 0.7066 - recall: 0.8983 - auc: 0.9077 - prc: 0.9358 - val_loss: 0.5480 - val_tp: 59.0000 - val_fp: 10200.0000 - val_tn: 35307.0000 - val_fn: 3.0000 - val_accuracy: 0.7761 - val_precision: 0.0058 - val_recall: 0.9516 - val_auc: 0.9535 - val_prc: 0.6934
Epoch 7/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.4040 - tp: 18252.0000 - fp: 6667.0000 - tn: 13961.0000 - fn: 2080.0000 - accuracy: 0.7865 - precision: 0.7325 - recall: 0.8977 - auc: 0.9167 - prc: 0.9421 - val_loss: 0.5002 - val_tp: 58.0000 - val_fp: 7300.0000 - val_tn: 38207.0000 - val_fn: 4.0000 - val_accuracy: 0.8397 - val_precision: 0.0079 - val_recall: 0.9355 - val_auc: 0.9568 - val_prc: 0.7066
Epoch 8/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.3761 - tp: 18491.0000 - fp: 5548.0000 - tn: 14922.0000 - fn: 1999.0000 - accuracy: 0.8157 - precision: 0.7692 - recall: 0.9024 - auc: 0.9266 - prc: 0.9496 - val_loss: 0.4588 - val_tp: 57.0000 - val_fp: 5228.0000 - val_tn: 40279.0000 - val_fn: 5.0000 - val_accuracy: 0.8852 - val_precision: 0.0108 - val_recall: 0.9194 - val_auc: 0.9593 - val_prc: 0.7155
Epoch 9/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3496 - tp: 18503.0000 - fp: 4750.0000 - tn: 15786.0000 - fn: 1921.0000 - accuracy: 0.8371 - precision: 0.7957 - recall: 0.9059 - auc: 0.9358 - prc: 0.9551 - val_loss: 0.4219 - val_tp: 57.0000 - val_fp: 3803.0000 - val_tn: 41704.0000 - val_fn: 5.0000 - val_accuracy: 0.9164 - val_precision: 0.0148 - val_recall: 0.9194 - val_auc: 0.9610 - val_prc: 0.7346
Epoch 10/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3301 - tp: 18571.0000 - fp: 4205.0000 - tn: 16321.0000 - fn: 1863.0000 - accuracy: 0.8519 - precision: 0.8154 - recall: 0.9088 - auc: 0.9423 - prc: 0.9596 - val_loss: 0.3889 - val_tp: 56.0000 - val_fp: 2848.0000 - val_tn: 42659.0000 - val_fn: 6.0000 - val_accuracy: 0.9374 - val_precision: 0.0193 - val_recall: 0.9032 - val_auc: 0.9615 - val_prc: 0.7390
Epoch 11/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3116 - tp: 18608.0000 - fp: 3632.0000 - tn: 16870.0000 - fn: 1850.0000 - accuracy: 0.8662 - precision: 0.8367 - recall: 0.9096 - auc: 0.9479 - prc: 0.9632 - val_loss: 0.3587 - val_tp: 54.0000 - val_fp: 2242.0000 - val_tn: 43265.0000 - val_fn: 8.0000 - val_accuracy: 0.9506 - val_precision: 0.0235 - val_recall: 0.8710 - val_auc: 0.9605 - val_prc: 0.7414
Epoch 12/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2980 - tp: 18692.0000 - fp: 3218.0000 - tn: 17110.0000 - fn: 1940.0000 - accuracy: 0.8741 - precision: 0.8531 - recall: 0.9060 - auc: 0.9502 - prc: 0.9649 - val_loss: 0.3320 - val_tp: 54.0000 - val_fp: 1833.0000 - val_tn: 43674.0000 - val_fn: 8.0000 - val_accuracy: 0.9596 - val_precision: 0.0286 - val_recall: 0.8710 - val_auc: 0.9580 - val_prc: 0.7436
Epoch 13/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2827 - tp: 18691.0000 - fp: 2746.0000 - tn: 17656.0000 - fn: 1867.0000 - accuracy: 0.8874 - precision: 0.8719 - recall: 0.9092 - auc: 0.9555 - prc: 0.9680 - val_loss: 0.3080 - val_tp: 54.0000 - val_fp: 1564.0000 - val_tn: 43943.0000 - val_fn: 8.0000 - val_accuracy: 0.9655 - val_precision: 0.0334 - val_recall: 0.8710 - val_auc: 0.9562 - val_prc: 0.7460
Epoch 14/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2700 - tp: 18659.0000 - fp: 2414.0000 - tn: 18008.0000 - fn: 1879.0000 - accuracy: 0.8952 - precision: 0.8854 - recall: 0.9085 - auc: 0.9581 - prc: 0.9697 - val_loss: 0.2863 - val_tp: 54.0000 - val_fp: 1369.0000 - val_tn: 44138.0000 - val_fn: 8.0000 - val_accuracy: 0.9698 - val_precision: 0.0379 - val_recall: 0.8710 - val_auc: 0.9549 - val_prc: 0.7486
Epoch 15/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2568 - tp: 18693.0000 - fp: 2148.0000 - tn: 18285.0000 - fn: 1834.0000 - accuracy: 0.9028 - precision: 0.8969 - recall: 0.9107 - auc: 0.9627 - prc: 0.9725 - val_loss: 0.2659 - val_tp: 54.0000 - val_fp: 1162.0000 - val_tn: 44345.0000 - val_fn: 8.0000 - val_accuracy: 0.9743 - val_precision: 0.0444 - val_recall: 0.8710 - val_auc: 0.9546 - val_prc: 0.7519
Epoch 16/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2476 - tp: 18763.0000 - fp: 1975.0000 - tn: 18456.0000 - fn: 1766.0000 - accuracy: 0.9087 - precision: 0.9048 - recall: 0.9140 - auc: 0.9655 - prc: 0.9742 - val_loss: 0.2474 - val_tp: 54.0000 - val_fp: 994.0000 - val_tn: 44513.0000 - val_fn: 8.0000 - val_accuracy: 0.9780 - val_precision: 0.0515 - val_recall: 0.8710 - val_auc: 0.9549 - val_prc: 0.7541
Epoch 17/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2380 - tp: 18891.0000 - fp: 1692.0000 - tn: 18558.0000 - fn: 1819.0000 - accuracy: 0.9143 - precision: 0.9178 - recall: 0.9122 - auc: 0.9667 - prc: 0.9753 - val_loss: 0.2312 - val_tp: 54.0000 - val_fp: 886.0000 - val_tn: 44621.0000 - val_fn: 8.0000 - val_accuracy: 0.9804 - val_precision: 0.0574 - val_recall: 0.8710 - val_auc: 0.9552 - val_prc: 0.7548
Epoch 18/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2293 - tp: 18588.0000 - fp: 1562.0000 - tn: 19042.0000 - fn: 1768.0000 - accuracy: 0.9187 - precision: 0.9225 - recall: 0.9131 - auc: 0.9689 - prc: 0.9762 - val_loss: 0.2163 - val_tp: 54.0000 - val_fp: 813.0000 - val_tn: 44694.0000 - val_fn: 8.0000 - val_accuracy: 0.9820 - val_precision: 0.0623 - val_recall: 0.8710 - val_auc: 0.9556 - val_prc: 0.7600
Epoch 19/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2220 - tp: 18694.0000 - fp: 1437.0000 - tn: 19042.0000 - fn: 1787.0000 - accuracy: 0.9213 - precision: 0.9286 - recall: 0.9127 - auc: 0.9705 - prc: 0.9774 - val_loss: 0.2030 - val_tp: 54.0000 - val_fp: 749.0000 - val_tn: 44758.0000 - val_fn: 8.0000 - val_accuracy: 0.9834 - val_precision: 0.0672 - val_recall: 0.8710 - val_auc: 0.9561 - val_prc: 0.7483
Epoch 20/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2126 - tp: 18608.0000 - fp: 1300.0000 - tn: 19345.0000 - fn: 1707.0000 - accuracy: 0.9266 - precision: 0.9347 - recall: 0.9160 - auc: 0.9729 - prc: 0.9789 - val_loss: 0.1905 - val_tp: 54.0000 - val_fp: 707.0000 - val_tn: 44800.0000 - val_fn: 8.0000 - val_accuracy: 0.9843 - val_precision: 0.0710 - val_recall: 0.8710 - val_auc: 0.9562 - val_prc: 0.7488
Epoch 21/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2057 - tp: 18731.0000 - fp: 1135.0000 - tn: 19379.0000 - fn: 1715.0000 - accuracy: 0.9304 - precision: 0.9429 - recall: 0.9161 - auc: 0.9743 - prc: 0.9798 - val_loss: 0.1800 - val_tp: 54.0000 - val_fp: 693.0000 - val_tn: 44814.0000 - val_fn: 8.0000 - val_accuracy: 0.9846 - val_precision: 0.0723 - val_recall: 0.8710 - val_auc: 0.9572 - val_prc: 0.7497
Epoch 22/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2003 - tp: 18762.0000 - fp: 1143.0000 - tn: 19358.0000 - fn: 1697.0000 - accuracy: 0.9307 - precision: 0.9426 - recall: 0.9171 - auc: 0.9756 - prc: 0.9809 - val_loss: 0.1701 - val_tp: 54.0000 - val_fp: 671.0000 - val_tn: 44836.0000 - val_fn: 8.0000 - val_accuracy: 0.9851 - val_precision: 0.0745 - val_recall: 0.8710 - val_auc: 0.9580 - val_prc: 0.7492
Epoch 23/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1939 - tp: 18770.0000 - fp: 1102.0000 - tn: 19405.0000 - fn: 1683.0000 - accuracy: 0.9320 - precision: 0.9445 - recall: 0.9177 - auc: 0.9773 - prc: 0.9818 - val_loss: 0.1611 - val_tp: 54.0000 - val_fp: 650.0000 - val_tn: 44857.0000 - val_fn: 8.0000 - val_accuracy: 0.9856 - val_precision: 0.0767 - val_recall: 0.8710 - val_auc: 0.9587 - val_prc: 0.7492
Epoch 24/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1891 - tp: 18890.0000 - fp: 1023.0000 - tn: 19387.0000 - fn: 1660.0000 - accuracy: 0.9345 - precision: 0.9486 - recall: 0.9192 - auc: 0.9784 - prc: 0.9828 - val_loss: 0.1529 - val_tp: 54.0000 - val_fp: 643.0000 - val_tn: 44864.0000 - val_fn: 8.0000 - val_accuracy: 0.9857 - val_precision: 0.0775 - val_recall: 0.8710 - val_auc: 0.9597 - val_prc: 0.7491
Epoch 25/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1839 - tp: 18838.0000 - fp: 958.0000 - tn: 19513.0000 - fn: 1651.0000 - accuracy: 0.9363 - precision: 0.9516 - recall: 0.9194 - auc: 0.9787 - prc: 0.9831 - val_loss: 0.1457 - val_tp: 54.0000 - val_fp: 630.0000 - val_tn: 44877.0000 - val_fn: 8.0000 - val_accuracy: 0.9860 - val_precision: 0.0789 - val_recall: 0.8710 - val_auc: 0.9606 - val_prc: 0.7495
Epoch 26/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1755 - tp: 18866.0000 - fp: 945.0000 - tn: 19579.0000 - fn: 1570.0000 - accuracy: 0.9386 - precision: 0.9523 - recall: 0.9232 - auc: 0.9815 - prc: 0.9849 - val_loss: 0.1395 - val_tp: 54.0000 - val_fp: 624.0000 - val_tn: 44883.0000 - val_fn: 8.0000 - val_accuracy: 0.9861 - val_precision: 0.0796 - val_recall: 0.8710 - val_auc: 0.9615 - val_prc: 0.7511
Epoch 27/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1735 - tp: 18768.0000 - fp: 897.0000 - tn: 19673.0000 - fn: 1622.0000 - accuracy: 0.9385 - precision: 0.9544 - recall: 0.9205 - auc: 0.9814 - prc: 0.9848 - val_loss: 0.1333 - val_tp: 54.0000 - val_fp: 615.0000 - val_tn: 44892.0000 - val_fn: 8.0000 - val_accuracy: 0.9863 - val_precision: 0.0807 - val_recall: 0.8710 - val_auc: 0.9615 - val_prc: 0.7509
Epoch 28/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.1708 - tp: 18982.0000 - fp: 861.0000 - tn: 19500.0000 - fn: 1617.0000 - accuracy: 0.9395 - precision: 0.9566 - recall: 0.9215 - auc: 0.9822 - prc: 0.9854 - val_loss: 0.1277 - val_tp: 53.0000 - val_fp: 607.0000 - val_tn: 44900.0000 - val_fn: 9.0000 - val_accuracy: 0.9865 - val_precision: 0.0803 - val_recall: 0.8548 - val_auc: 0.9618 - val_prc: 0.7511
Restoring model weights from the end of the best epoch.
Epoch 00028: early stopping

Sprawdź ponownie historię szkoleń

plot_metrics(resampled_history)

png

Oceń dane

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.21696887910366058
tp :  94.0
fp :  1000.0
tn :  55853.0
fn :  15.0
accuracy :  0.9821811318397522
precision :  0.0859232172369957
recall :  0.8623853325843811
auc :  0.9408414959907532
prc :  0.7525754570960999

Legitimate Transactions Detected (True Negatives):  55853
Legitimate Transactions Incorrectly Detected (False Positives):  1000
Fraudulent Transactions Missed (False Negatives):  15
Fraudulent Transactions Detected (True Positives):  94
Total Fraudulent Transactions:  109

png

Wykreśl ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb5040c2e10>

png

Wykreśl AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fb5043e6cd0>

png

Zastosowanie tego samouczka do swojego problemu

Niezrównoważona klasyfikacja danych jest z natury trudnym zadaniem, ponieważ jest tak mało próbek, z których można się uczyć. Zawsze powinieneś zacząć od danych i zrobić wszystko, co w Twojej mocy, aby zebrać jak najwięcej próbek i zastanowić się, jakie funkcje mogą być istotne, aby model mógł jak najlepiej wykorzystać twoją klasę mniejszości. W pewnym momencie Twój model może mieć trudności z ulepszeniem i uzyskaniem pożądanych wyników, dlatego ważne jest, aby pamiętać o kontekście problemu i kompromisach między różnymi rodzajami błędów.