過学習と学習不足について知る

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

いつものように、この例のプログラムはtf.keras APIを使用します。詳しくはTensorFlowのKeras guideを参照してください。

これまでの例、つまり、映画レビューの分類と燃費の推定では、検証用データでのモデルの正解率が、数エポックでピークを迎え、その後低下するという現象が見られました。

言い換えると、モデルが訓練用データを過学習したと考えられます。過学習への対処の仕方を学ぶことは重要です。訓練用データセットで高い正解率を達成することは難しくありませんが、我々は、(これまで見たこともない)テスト用データに汎化したモデルを開発したいのです。

過学習の反対語は学習不足(underfitting)です。学習不足は、モデルがテストデータに対してまだ改善の余地がある場合に発生します。学習不足の原因は様々です。モデルが十分強力でないとか、正則化のしすぎだとか、単に訓練時間が短すぎるといった理由があります。学習不足は、訓練用データの中の関連したパターンを学習しきっていないということを意味します。

モデルの訓練をやりすぎると、モデルは過学習を始め、訓練用データの中のパターンで、テストデータには一般的ではないパターンを学習します。我々は、過学習と学習不足の中間を目指す必要があります。これから見ていくように、ちょうどよいエポック数だけ訓練を行うというのは必要なスキルなのです。

過学習を防止するための、最良の解決策は、より多くの訓練用データを使うことです。多くのデータで訓練を行えば行うほど、モデルは自然により汎化していく様になります。これが不可能な場合、次善の策は正則化のようなテクニックを使うことです。正則化は、モデルに保存される情報の量とタイプに制約を課すものです。ネットワークが少数のパターンしか記憶できなければ、最適化プロセスにより、最も主要なパターンのみを学習することになり、より汎化される可能性が高くなります。

このノートブックでは、重みの正則化とドロップアウトという、よく使われる2つの正則化テクニックをご紹介します。これらを使って、IMDBの映画レビューを分類するノートブックの改善を図ります。

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
2.3.0

IMDBデータセットのダウンロード

以前のノートブックで使用したエンベディングの代わりに、ここでは文をマルチホットエンコードします。このモデルは、訓練用データセットをすぐに過学習します。このモデルを使って、過学習がいつ起きるかということと、どうやって過学習と戦うかをデモします。

リストをマルチホットエンコードすると言うのは、0と1のベクトルにするということです。具体的にいうと、例えば[3, 5]というシーケンスを、インデックス3と5の値が1で、それ以外がすべて0の、10,000次元のベクトルに変換するということを意味します。

NUM_WORDS = 10000

(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)

def multi_hot_sequences(sequences, dimension):
    # 形状が (len(sequences), dimension)ですべて0の行列を作る
    results = np.zeros((len(sequences), dimension))
    for i, word_indices in enumerate(sequences):
        results[i, word_indices] = 1.0  # 特定のインデックスに対してresults[i] を1に設定する
    return results


train_data = multi_hot_sequences(train_data, dimension=NUM_WORDS)
test_data = multi_hot_sequences(test_data, dimension=NUM_WORDS)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17465344/17464789 [==============================] - 0s 0us/step

結果として得られるマルチホットベクトルの1つを見てみましょう。単語のインデックスは頻度順にソートされています。このため、インデックスが0に近いほど1が多く出現するはずです。分布を見てみましょう。

plt.plot(train_data[0])
[<matplotlib.lines.Line2D at 0x7f045af1a390>]

png

過学習のデモ

過学習を防止するための最も単純な方法は、モデルのサイズ、すなわち、モデル内の学習可能なパラメータの数を小さくすることです(学習パラメータの数は、層の数と層ごとのユニット数で決まります)。ディープラーニングでは、モデルの学習可能なパラメータ数を、しばしばモデルの「キャパシティ」と呼びます。直感的に考えれば、パラメータ数の多いモデルほど「記憶容量」が大きくなり、訓練用のサンプルとその目的変数の間の辞書のようなマッピングをたやすく学習することができます。このマッピングには汎化能力がまったくなく、これまで見たことが無いデータを使って予測をする際には役に立ちません。

ディープラーニングのモデルは訓練用データに適応しやすいけれど、本当のチャレレンジは汎化であって適応ではないということを、肝に銘じておく必要があります。

一方、ネットワークの記憶容量が限られている場合、前述のようなマッピングを簡単に学習することはできません。損失を減らすためには、より予測能力が高い圧縮された表現を学習しなければなりません。同時に、モデルを小さくしすぎると、訓練用データに適応するのが難しくなります。「多すぎる容量」と「容量不足」の間にちょうどよい容量があるのです。

残念ながら、(層の数や、層ごとの大きさといった)モデルの適切なサイズやアーキテクチャを決める魔法の方程式はありません。一連の異なるアーキテクチャを使って実験を行う必要があります。

適切なモデルのサイズを見つけるには、比較的少ない層の数とパラメータから始めるのがベストです。それから、検証用データでの損失値の改善が見られなくなるまで、徐々に層の大きさを増やしたり、新たな層を加えたりします。映画レビューの分類ネットワークでこれを試してみましょう。

比較基準として、Dense層だけを使ったシンプルなモデルを構築し、その後、それより小さいバージョンと大きいバージョンを作って比較します。

比較基準を作る

baseline_model = keras.Sequential([
    # `.summary` を見るために`input_shape`が必要 
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

baseline_model.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy', 'binary_crossentropy'])

baseline_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                160016    
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,305
Trainable params: 160,305
Non-trainable params: 0
_________________________________________________________________

baseline_history = baseline_model.fit(train_data,
                                      train_labels,
                                      epochs=20,
                                      batch_size=512,
                                      validation_data=(test_data, test_labels),
                                      verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.4567 - accuracy: 0.8121 - binary_crossentropy: 0.4567 - val_loss: 0.3129 - val_accuracy: 0.8799 - val_binary_crossentropy: 0.3129
Epoch 2/20
49/49 - 1s - loss: 0.2287 - accuracy: 0.9170 - binary_crossentropy: 0.2287 - val_loss: 0.2820 - val_accuracy: 0.8876 - val_binary_crossentropy: 0.2820
Epoch 3/20
49/49 - 1s - loss: 0.1672 - accuracy: 0.9408 - binary_crossentropy: 0.1672 - val_loss: 0.3008 - val_accuracy: 0.8807 - val_binary_crossentropy: 0.3008
Epoch 4/20
49/49 - 1s - loss: 0.1329 - accuracy: 0.9550 - binary_crossentropy: 0.1329 - val_loss: 0.3331 - val_accuracy: 0.8756 - val_binary_crossentropy: 0.3331
Epoch 5/20
49/49 - 1s - loss: 0.1092 - accuracy: 0.9641 - binary_crossentropy: 0.1092 - val_loss: 0.3744 - val_accuracy: 0.8684 - val_binary_crossentropy: 0.3744
Epoch 6/20
49/49 - 1s - loss: 0.0888 - accuracy: 0.9712 - binary_crossentropy: 0.0888 - val_loss: 0.4092 - val_accuracy: 0.8630 - val_binary_crossentropy: 0.4092
Epoch 7/20
49/49 - 1s - loss: 0.0720 - accuracy: 0.9788 - binary_crossentropy: 0.0720 - val_loss: 0.4563 - val_accuracy: 0.8593 - val_binary_crossentropy: 0.4563
Epoch 8/20
49/49 - 1s - loss: 0.0578 - accuracy: 0.9846 - binary_crossentropy: 0.0578 - val_loss: 0.4950 - val_accuracy: 0.8608 - val_binary_crossentropy: 0.4950
Epoch 9/20
49/49 - 1s - loss: 0.0459 - accuracy: 0.9900 - binary_crossentropy: 0.0459 - val_loss: 0.5479 - val_accuracy: 0.8564 - val_binary_crossentropy: 0.5479
Epoch 10/20
49/49 - 1s - loss: 0.0386 - accuracy: 0.9916 - binary_crossentropy: 0.0386 - val_loss: 0.5939 - val_accuracy: 0.8543 - val_binary_crossentropy: 0.5939
Epoch 11/20
49/49 - 1s - loss: 0.0287 - accuracy: 0.9951 - binary_crossentropy: 0.0287 - val_loss: 0.6396 - val_accuracy: 0.8538 - val_binary_crossentropy: 0.6396
Epoch 12/20
49/49 - 1s - loss: 0.0222 - accuracy: 0.9970 - binary_crossentropy: 0.0222 - val_loss: 0.6908 - val_accuracy: 0.8518 - val_binary_crossentropy: 0.6908
Epoch 13/20
49/49 - 1s - loss: 0.0166 - accuracy: 0.9986 - binary_crossentropy: 0.0166 - val_loss: 0.7289 - val_accuracy: 0.8528 - val_binary_crossentropy: 0.7289
Epoch 14/20
49/49 - 1s - loss: 0.0123 - accuracy: 0.9991 - binary_crossentropy: 0.0123 - val_loss: 0.7756 - val_accuracy: 0.8508 - val_binary_crossentropy: 0.7756
Epoch 15/20
49/49 - 1s - loss: 0.0091 - accuracy: 0.9995 - binary_crossentropy: 0.0091 - val_loss: 0.8145 - val_accuracy: 0.8510 - val_binary_crossentropy: 0.8145
Epoch 16/20
49/49 - 1s - loss: 0.0069 - accuracy: 0.9999 - binary_crossentropy: 0.0069 - val_loss: 0.8510 - val_accuracy: 0.8502 - val_binary_crossentropy: 0.8510
Epoch 17/20
49/49 - 1s - loss: 0.0056 - accuracy: 1.0000 - binary_crossentropy: 0.0056 - val_loss: 0.8908 - val_accuracy: 0.8496 - val_binary_crossentropy: 0.8908
Epoch 18/20
49/49 - 1s - loss: 0.0046 - accuracy: 1.0000 - binary_crossentropy: 0.0046 - val_loss: 0.9217 - val_accuracy: 0.8496 - val_binary_crossentropy: 0.9217
Epoch 19/20
49/49 - 1s - loss: 0.0037 - accuracy: 1.0000 - binary_crossentropy: 0.0037 - val_loss: 0.9498 - val_accuracy: 0.8497 - val_binary_crossentropy: 0.9498
Epoch 20/20
49/49 - 1s - loss: 0.0031 - accuracy: 1.0000 - binary_crossentropy: 0.0031 - val_loss: 0.9770 - val_accuracy: 0.8497 - val_binary_crossentropy: 0.9770

より小さいモデルの構築

今作成したばかりの比較基準となるモデルに比べて隠れユニット数が少ないモデルを作りましょう。

smaller_model = keras.Sequential([
    keras.layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(4, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

smaller_model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy', 'binary_crossentropy'])

smaller_model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4)                 40004     
_________________________________________________________________
dense_4 (Dense)              (None, 4)                 20        
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 5         
=================================================================
Total params: 40,029
Trainable params: 40,029
Non-trainable params: 0
_________________________________________________________________

同じデータを使って訓練します。

smaller_history = smaller_model.fit(train_data,
                                    train_labels,
                                    epochs=20,
                                    batch_size=512,
                                    validation_data=(test_data, test_labels),
                                    verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.6152 - accuracy: 0.6698 - binary_crossentropy: 0.6152 - val_loss: 0.5395 - val_accuracy: 0.7858 - val_binary_crossentropy: 0.5395
Epoch 2/20
49/49 - 1s - loss: 0.4887 - accuracy: 0.8514 - binary_crossentropy: 0.4887 - val_loss: 0.4880 - val_accuracy: 0.8325 - val_binary_crossentropy: 0.4880
Epoch 3/20
49/49 - 1s - loss: 0.4361 - accuracy: 0.8972 - binary_crossentropy: 0.4361 - val_loss: 0.4660 - val_accuracy: 0.8502 - val_binary_crossentropy: 0.4660
Epoch 4/20
49/49 - 1s - loss: 0.4016 - accuracy: 0.9202 - binary_crossentropy: 0.4016 - val_loss: 0.4513 - val_accuracy: 0.8670 - val_binary_crossentropy: 0.4513
Epoch 5/20
49/49 - 1s - loss: 0.3750 - accuracy: 0.9352 - binary_crossentropy: 0.3750 - val_loss: 0.4456 - val_accuracy: 0.8670 - val_binary_crossentropy: 0.4456
Epoch 6/20
49/49 - 1s - loss: 0.3523 - accuracy: 0.9462 - binary_crossentropy: 0.3523 - val_loss: 0.4415 - val_accuracy: 0.8686 - val_binary_crossentropy: 0.4415
Epoch 7/20
49/49 - 1s - loss: 0.3328 - accuracy: 0.9553 - binary_crossentropy: 0.3328 - val_loss: 0.4484 - val_accuracy: 0.8626 - val_binary_crossentropy: 0.4484
Epoch 8/20
49/49 - 1s - loss: 0.3152 - accuracy: 0.9605 - binary_crossentropy: 0.3152 - val_loss: 0.4415 - val_accuracy: 0.8676 - val_binary_crossentropy: 0.4415
Epoch 9/20
49/49 - 1s - loss: 0.2992 - accuracy: 0.9658 - binary_crossentropy: 0.2992 - val_loss: 0.4402 - val_accuracy: 0.8683 - val_binary_crossentropy: 0.4402
Epoch 10/20
49/49 - 1s - loss: 0.2851 - accuracy: 0.9702 - binary_crossentropy: 0.2851 - val_loss: 0.4474 - val_accuracy: 0.8660 - val_binary_crossentropy: 0.4474
Epoch 11/20
49/49 - 1s - loss: 0.2715 - accuracy: 0.9730 - binary_crossentropy: 0.2715 - val_loss: 0.4585 - val_accuracy: 0.8636 - val_binary_crossentropy: 0.4585
Epoch 12/20
49/49 - 1s - loss: 0.2594 - accuracy: 0.9759 - binary_crossentropy: 0.2594 - val_loss: 0.4703 - val_accuracy: 0.8604 - val_binary_crossentropy: 0.4703
Epoch 13/20
49/49 - 1s - loss: 0.2480 - accuracy: 0.9782 - binary_crossentropy: 0.2480 - val_loss: 0.4673 - val_accuracy: 0.8622 - val_binary_crossentropy: 0.4673
Epoch 14/20
49/49 - 1s - loss: 0.2374 - accuracy: 0.9798 - binary_crossentropy: 0.2374 - val_loss: 0.4703 - val_accuracy: 0.8623 - val_binary_crossentropy: 0.4703
Epoch 15/20
49/49 - 1s - loss: 0.2277 - accuracy: 0.9817 - binary_crossentropy: 0.2277 - val_loss: 0.4831 - val_accuracy: 0.8609 - val_binary_crossentropy: 0.4831
Epoch 16/20
49/49 - 1s - loss: 0.2193 - accuracy: 0.9824 - binary_crossentropy: 0.2193 - val_loss: 0.5015 - val_accuracy: 0.8594 - val_binary_crossentropy: 0.5015
Epoch 17/20
49/49 - 1s - loss: 0.2107 - accuracy: 0.9835 - binary_crossentropy: 0.2107 - val_loss: 0.5039 - val_accuracy: 0.8592 - val_binary_crossentropy: 0.5039
Epoch 18/20
49/49 - 1s - loss: 0.2031 - accuracy: 0.9839 - binary_crossentropy: 0.2031 - val_loss: 0.5137 - val_accuracy: 0.8583 - val_binary_crossentropy: 0.5137
Epoch 19/20
49/49 - 1s - loss: 0.1955 - accuracy: 0.9850 - binary_crossentropy: 0.1955 - val_loss: 0.5143 - val_accuracy: 0.8575 - val_binary_crossentropy: 0.5143
Epoch 20/20
49/49 - 1s - loss: 0.1892 - accuracy: 0.9850 - binary_crossentropy: 0.1892 - val_loss: 0.5450 - val_accuracy: 0.8560 - val_binary_crossentropy: 0.5450

より大きなモデルの構築

練習として、より大きなモデルを作成し、どれほど急速に過学習が起きるかを見ることもできます。次はこのベンチマークに、この問題が必要とするよりはるかに容量の大きなネットワークを追加しましょう。

bigger_model = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

bigger_model.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     metrics=['accuracy','binary_crossentropy'])

bigger_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (None, 512)               5120512   
_________________________________________________________________
dense_7 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_8 (Dense)              (None, 1)                 513       
=================================================================
Total params: 5,383,681
Trainable params: 5,383,681
Non-trainable params: 0
_________________________________________________________________

このモデルもまた同じデータを使って訓練します。

bigger_history = bigger_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.3408 - accuracy: 0.8579 - binary_crossentropy: 0.3408 - val_loss: 0.2921 - val_accuracy: 0.8800 - val_binary_crossentropy: 0.2921
Epoch 2/20
49/49 - 1s - loss: 0.1394 - accuracy: 0.9494 - binary_crossentropy: 0.1394 - val_loss: 0.3408 - val_accuracy: 0.8747 - val_binary_crossentropy: 0.3408
Epoch 3/20
49/49 - 1s - loss: 0.0455 - accuracy: 0.9860 - binary_crossentropy: 0.0455 - val_loss: 0.4624 - val_accuracy: 0.8671 - val_binary_crossentropy: 0.4624
Epoch 4/20
49/49 - 1s - loss: 0.0064 - accuracy: 0.9992 - binary_crossentropy: 0.0064 - val_loss: 0.5888 - val_accuracy: 0.8685 - val_binary_crossentropy: 0.5888
Epoch 5/20
49/49 - 1s - loss: 8.7914e-04 - accuracy: 1.0000 - binary_crossentropy: 8.7914e-04 - val_loss: 0.6871 - val_accuracy: 0.8697 - val_binary_crossentropy: 0.6871
Epoch 6/20
49/49 - 1s - loss: 2.4132e-04 - accuracy: 1.0000 - binary_crossentropy: 2.4132e-04 - val_loss: 0.7388 - val_accuracy: 0.8697 - val_binary_crossentropy: 0.7388
Epoch 7/20
49/49 - 1s - loss: 1.4389e-04 - accuracy: 1.0000 - binary_crossentropy: 1.4389e-04 - val_loss: 0.7693 - val_accuracy: 0.8703 - val_binary_crossentropy: 0.7693
Epoch 8/20
49/49 - 1s - loss: 1.0339e-04 - accuracy: 1.0000 - binary_crossentropy: 1.0339e-04 - val_loss: 0.7931 - val_accuracy: 0.8701 - val_binary_crossentropy: 0.7931
Epoch 9/20
49/49 - 1s - loss: 7.8790e-05 - accuracy: 1.0000 - binary_crossentropy: 7.8790e-05 - val_loss: 0.8119 - val_accuracy: 0.8704 - val_binary_crossentropy: 0.8119
Epoch 10/20
49/49 - 1s - loss: 6.2591e-05 - accuracy: 1.0000 - binary_crossentropy: 6.2591e-05 - val_loss: 0.8282 - val_accuracy: 0.8709 - val_binary_crossentropy: 0.8282
Epoch 11/20
49/49 - 1s - loss: 5.1024e-05 - accuracy: 1.0000 - binary_crossentropy: 5.1024e-05 - val_loss: 0.8430 - val_accuracy: 0.8707 - val_binary_crossentropy: 0.8430
Epoch 12/20
49/49 - 1s - loss: 4.2361e-05 - accuracy: 1.0000 - binary_crossentropy: 4.2361e-05 - val_loss: 0.8569 - val_accuracy: 0.8701 - val_binary_crossentropy: 0.8569
Epoch 13/20
49/49 - 1s - loss: 3.5866e-05 - accuracy: 1.0000 - binary_crossentropy: 3.5866e-05 - val_loss: 0.8688 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.8688
Epoch 14/20
49/49 - 1s - loss: 3.0626e-05 - accuracy: 1.0000 - binary_crossentropy: 3.0626e-05 - val_loss: 0.8802 - val_accuracy: 0.8703 - val_binary_crossentropy: 0.8802
Epoch 15/20
49/49 - 1s - loss: 2.6531e-05 - accuracy: 1.0000 - binary_crossentropy: 2.6531e-05 - val_loss: 0.8906 - val_accuracy: 0.8704 - val_binary_crossentropy: 0.8906
Epoch 16/20
49/49 - 1s - loss: 2.3148e-05 - accuracy: 1.0000 - binary_crossentropy: 2.3148e-05 - val_loss: 0.9005 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.9005
Epoch 17/20
49/49 - 1s - loss: 2.0370e-05 - accuracy: 1.0000 - binary_crossentropy: 2.0370e-05 - val_loss: 0.9099 - val_accuracy: 0.8704 - val_binary_crossentropy: 0.9099
Epoch 18/20
49/49 - 1s - loss: 1.8041e-05 - accuracy: 1.0000 - binary_crossentropy: 1.8041e-05 - val_loss: 0.9186 - val_accuracy: 0.8704 - val_binary_crossentropy: 0.9186
Epoch 19/20
49/49 - 1s - loss: 1.6083e-05 - accuracy: 1.0000 - binary_crossentropy: 1.6083e-05 - val_loss: 0.9265 - val_accuracy: 0.8703 - val_binary_crossentropy: 0.9265
Epoch 20/20
49/49 - 1s - loss: 1.4427e-05 - accuracy: 1.0000 - binary_crossentropy: 1.4427e-05 - val_loss: 0.9351 - val_accuracy: 0.8706 - val_binary_crossentropy: 0.9351

訓練時と検証時の損失をグラフにする

実線は訓練用データセットの損失、破線は検証用データセットでの損失です(検証用データでの損失が小さい方が良いモデルです)。これをみると、小さいネットワークのほうが比較基準のモデルよりも過学習が始まるのが遅いことがわかります(4エポックではなく6エポック後)。また、過学習が始まっても性能の低下がよりゆっくりしています。

def plot_history(histories, key='binary_crossentropy'):
  plt.figure(figsize=(16,10))
    
  for name, history in histories:
    val = plt.plot(history.epoch, history.history['val_'+key],
                   '--', label=name.title()+' Val')
    plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
             label=name.title()+' Train')

  plt.xlabel('Epochs')
  plt.ylabel(key.replace('_',' ').title())
  plt.legend()

  plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
              ('smaller', smaller_history),
              ('bigger', bigger_history)])

png

より大きなネットワークでは、すぐに、1エポックで過学習が始まり、その度合も強いことに注目してください。ネットワークの容量が大きいほど訓練用データをモデル化するスピードが早くなり(結果として訓練時の損失値が小さくなり)ますが、より過学習しやすく(結果として訓練時の損失値と検証時の損失値が大きく乖離しやすく)なります。

過学習防止の戦略

重みの正則化を加える

「オッカムの剃刀」の原則をご存知でしょうか。何かの説明が2つあるとすると、最も正しいと考えられる説明は、仮定の数が最も少ない「一番単純な」説明だというものです。この原則は、ニューラルネットワークを使って学習されたモデルにも当てはまります。ある訓練用データとネットワーク構造があって、そのデータを説明できる重みの集合が複数ある時(つまり、複数のモデルがある時)、単純なモデルのほうが複雑なものよりも過学習しにくいのです。

ここで言う「単純なモデル」とは、パラメータ値の分布のエントロピーが小さいもの(あるいは、上記で見たように、そもそもパラメータの数が少ないもの)です。したがって、過学習を緩和するための一般的な手法は、重みが小さい値のみをとることで、重み値の分布がより整然となる(正則)様に制約を与えるものです。これを「重みの正則化」と呼ばれ、ネットワークの損失関数に、重みの大きさに関連するコストを加えることで行われます。このコストには2つの種類があります。

  • L1正則化 重み係数の絶対値に比例するコストを加える(重みの「L1ノルム」と呼ばれる)。

  • L2正則化 重み係数の二乗に比例するコストを加える(重み係数の二乗「L2ノルム」と呼ばれる)。L2正則化はニューラルネットワーク用語では重み減衰(Weight Decay)と呼ばれる。呼び方が違うので混乱しないように。重み減衰は数学的にはL2正則化と同義である。

L1正則化は重みパラメータの一部を0にすることでモデルを疎にする効果があります。L2正則化は重みパラメータにペナルティを加えますがモデルを疎にすることはありません。これは、L2正則化のほうが一般的である理由の一つです。

tf.kerasでは、重みの正則化をするために、重み正則化のインスタンスをキーワード引数として層に加えます。ここでは、L2正則化を追加してみましょう。

l2_model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])

l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.5246 - accuracy: 0.8037 - binary_crossentropy: 0.4841 - val_loss: 0.3787 - val_accuracy: 0.8750 - val_binary_crossentropy: 0.3357
Epoch 2/20
49/49 - 1s - loss: 0.3018 - accuracy: 0.9074 - binary_crossentropy: 0.2548 - val_loss: 0.3335 - val_accuracy: 0.8874 - val_binary_crossentropy: 0.2837
Epoch 3/20
49/49 - 1s - loss: 0.2499 - accuracy: 0.9305 - binary_crossentropy: 0.1980 - val_loss: 0.3392 - val_accuracy: 0.8852 - val_binary_crossentropy: 0.2856
Epoch 4/20
49/49 - 1s - loss: 0.2261 - accuracy: 0.9422 - binary_crossentropy: 0.1710 - val_loss: 0.3508 - val_accuracy: 0.8828 - val_binary_crossentropy: 0.2944
Epoch 5/20
49/49 - 1s - loss: 0.2116 - accuracy: 0.9475 - binary_crossentropy: 0.1541 - val_loss: 0.3660 - val_accuracy: 0.8783 - val_binary_crossentropy: 0.3077
Epoch 6/20
49/49 - 1s - loss: 0.2012 - accuracy: 0.9525 - binary_crossentropy: 0.1419 - val_loss: 0.3820 - val_accuracy: 0.8748 - val_binary_crossentropy: 0.3221
Epoch 7/20
49/49 - 1s - loss: 0.1941 - accuracy: 0.9541 - binary_crossentropy: 0.1333 - val_loss: 0.3986 - val_accuracy: 0.8731 - val_binary_crossentropy: 0.3372
Epoch 8/20
49/49 - 1s - loss: 0.1857 - accuracy: 0.9584 - binary_crossentropy: 0.1237 - val_loss: 0.4137 - val_accuracy: 0.8697 - val_binary_crossentropy: 0.3514
Epoch 9/20
49/49 - 1s - loss: 0.1814 - accuracy: 0.9601 - binary_crossentropy: 0.1186 - val_loss: 0.4293 - val_accuracy: 0.8663 - val_binary_crossentropy: 0.3659
Epoch 10/20
49/49 - 1s - loss: 0.1786 - accuracy: 0.9610 - binary_crossentropy: 0.1144 - val_loss: 0.4422 - val_accuracy: 0.8663 - val_binary_crossentropy: 0.3776
Epoch 11/20
49/49 - 1s - loss: 0.1701 - accuracy: 0.9651 - binary_crossentropy: 0.1053 - val_loss: 0.4562 - val_accuracy: 0.8657 - val_binary_crossentropy: 0.3913
Epoch 12/20
49/49 - 1s - loss: 0.1682 - accuracy: 0.9655 - binary_crossentropy: 0.1029 - val_loss: 0.4710 - val_accuracy: 0.8627 - val_binary_crossentropy: 0.4053
Epoch 13/20
49/49 - 1s - loss: 0.1675 - accuracy: 0.9653 - binary_crossentropy: 0.1015 - val_loss: 0.5021 - val_accuracy: 0.8555 - val_binary_crossentropy: 0.4356
Epoch 14/20
49/49 - 1s - loss: 0.1658 - accuracy: 0.9656 - binary_crossentropy: 0.0984 - val_loss: 0.5031 - val_accuracy: 0.8593 - val_binary_crossentropy: 0.4354
Epoch 15/20
49/49 - 1s - loss: 0.1604 - accuracy: 0.9685 - binary_crossentropy: 0.0925 - val_loss: 0.5150 - val_accuracy: 0.8596 - val_binary_crossentropy: 0.4468
Epoch 16/20
49/49 - 1s - loss: 0.1581 - accuracy: 0.9695 - binary_crossentropy: 0.0897 - val_loss: 0.5274 - val_accuracy: 0.8580 - val_binary_crossentropy: 0.4589
Epoch 17/20
49/49 - 1s - loss: 0.1520 - accuracy: 0.9724 - binary_crossentropy: 0.0835 - val_loss: 0.5428 - val_accuracy: 0.8575 - val_binary_crossentropy: 0.4745
Epoch 18/20
49/49 - 1s - loss: 0.1521 - accuracy: 0.9717 - binary_crossentropy: 0.0833 - val_loss: 0.5467 - val_accuracy: 0.8573 - val_binary_crossentropy: 0.4775
Epoch 19/20
49/49 - 1s - loss: 0.1451 - accuracy: 0.9747 - binary_crossentropy: 0.0761 - val_loss: 0.5622 - val_accuracy: 0.8565 - val_binary_crossentropy: 0.4930
Epoch 20/20
49/49 - 1s - loss: 0.1421 - accuracy: 0.9767 - binary_crossentropy: 0.0726 - val_loss: 0.5815 - val_accuracy: 0.8528 - val_binary_crossentropy: 0.5120

l2(0.001)というのは、層の重み行列の係数全てに対して0.001 * 重み係数の値 **2をネットワークの損失値合計に加えることを意味します。このペナルティは訓練時のみに加えられるため、このネットワークの損失値は、訓練時にはテスト時に比べて大きくなることに注意してください。

L2正則化の影響を見てみましょう。

plot_history([('baseline', baseline_history),
              ('l2', l2_model_history)])

png

ご覧のように、L2正則化ありのモデルは比較基準のモデルに比べて過学習しにくくなっています。両方のモデルのパラメータ数は同じであるにもかかわらずです。

ドロップアウトを追加する

ドロップアウトは、ニューラルネットワークの正則化テクニックとして最もよく使われる手法の一つです。この手法は、トロント大学のヒントンと彼の学生が開発したものです。ドロップアウトは層に適用するもので、訓練時に層から出力された特徴量に対してランダムに「ドロップアウト(つまりゼロ化)」を行うものです。例えば、ある層が訓練時にある入力サンプルに対して、普通は[0.2, 0.5, 1.3, 0.8, 1.1] というベクトルを出力するとします。ドロップアウトを適用すると、このベクトルは例えば[0, 0.5, 1.3, 0, 1.1]のようにランダムに散らばったいくつかのゼロを含むようになります。「ドロップアウト率」はゼロ化される特徴の割合で、通常は0.2から0.5の間に設定します。テスト時は、どのユニットもドロップアウトされず、代わりに出力値がドロップアウト率と同じ比率でスケールダウンされます。これは、訓練時に比べてたくさんのユニットがアクティブであることに対してバランスをとるためです。

tf.kerasでは、Dropout層を使ってドロップアウトをネットワークに導入できます。ドロップアウト層は、その直前の層の出力に対してドロップアウトを適用します。

それでは、IMDBネットワークに2つのドロップアウト層を追加しましょう。

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='sigmoid')
])

dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])

dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.6479 - accuracy: 0.6178 - binary_crossentropy: 0.6479 - val_loss: 0.5360 - val_accuracy: 0.8455 - val_binary_crossentropy: 0.5360
Epoch 2/20
49/49 - 1s - loss: 0.4987 - accuracy: 0.7832 - binary_crossentropy: 0.4987 - val_loss: 0.3744 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3744
Epoch 3/20
49/49 - 1s - loss: 0.3833 - accuracy: 0.8589 - binary_crossentropy: 0.3833 - val_loss: 0.3040 - val_accuracy: 0.8865 - val_binary_crossentropy: 0.3040
Epoch 4/20
49/49 - 1s - loss: 0.3126 - accuracy: 0.8928 - binary_crossentropy: 0.3126 - val_loss: 0.2813 - val_accuracy: 0.8881 - val_binary_crossentropy: 0.2813
Epoch 5/20
49/49 - 1s - loss: 0.2639 - accuracy: 0.9144 - binary_crossentropy: 0.2639 - val_loss: 0.2771 - val_accuracy: 0.8902 - val_binary_crossentropy: 0.2771
Epoch 6/20
49/49 - 1s - loss: 0.2312 - accuracy: 0.9280 - binary_crossentropy: 0.2312 - val_loss: 0.2829 - val_accuracy: 0.8879 - val_binary_crossentropy: 0.2829
Epoch 7/20
49/49 - 1s - loss: 0.2045 - accuracy: 0.9370 - binary_crossentropy: 0.2045 - val_loss: 0.2973 - val_accuracy: 0.8864 - val_binary_crossentropy: 0.2973
Epoch 8/20
49/49 - 1s - loss: 0.1888 - accuracy: 0.9432 - binary_crossentropy: 0.1888 - val_loss: 0.3168 - val_accuracy: 0.8848 - val_binary_crossentropy: 0.3168
Epoch 9/20
49/49 - 1s - loss: 0.1694 - accuracy: 0.9484 - binary_crossentropy: 0.1694 - val_loss: 0.3214 - val_accuracy: 0.8818 - val_binary_crossentropy: 0.3214
Epoch 10/20
49/49 - 1s - loss: 0.1569 - accuracy: 0.9540 - binary_crossentropy: 0.1569 - val_loss: 0.3447 - val_accuracy: 0.8826 - val_binary_crossentropy: 0.3447
Epoch 11/20
49/49 - 1s - loss: 0.1403 - accuracy: 0.9581 - binary_crossentropy: 0.1403 - val_loss: 0.3523 - val_accuracy: 0.8797 - val_binary_crossentropy: 0.3523
Epoch 12/20
49/49 - 1s - loss: 0.1310 - accuracy: 0.9619 - binary_crossentropy: 0.1310 - val_loss: 0.3879 - val_accuracy: 0.8788 - val_binary_crossentropy: 0.3879
Epoch 13/20
49/49 - 1s - loss: 0.1182 - accuracy: 0.9650 - binary_crossentropy: 0.1182 - val_loss: 0.3959 - val_accuracy: 0.8786 - val_binary_crossentropy: 0.3959
Epoch 14/20
49/49 - 1s - loss: 0.1108 - accuracy: 0.9689 - binary_crossentropy: 0.1108 - val_loss: 0.4203 - val_accuracy: 0.8785 - val_binary_crossentropy: 0.4203
Epoch 15/20
49/49 - 1s - loss: 0.1024 - accuracy: 0.9701 - binary_crossentropy: 0.1024 - val_loss: 0.4584 - val_accuracy: 0.8756 - val_binary_crossentropy: 0.4584
Epoch 16/20
49/49 - 1s - loss: 0.0958 - accuracy: 0.9731 - binary_crossentropy: 0.0958 - val_loss: 0.4551 - val_accuracy: 0.8769 - val_binary_crossentropy: 0.4551
Epoch 17/20
49/49 - 1s - loss: 0.0924 - accuracy: 0.9726 - binary_crossentropy: 0.0924 - val_loss: 0.4788 - val_accuracy: 0.8755 - val_binary_crossentropy: 0.4788
Epoch 18/20
49/49 - 1s - loss: 0.0893 - accuracy: 0.9727 - binary_crossentropy: 0.0893 - val_loss: 0.4946 - val_accuracy: 0.8766 - val_binary_crossentropy: 0.4946
Epoch 19/20
49/49 - 1s - loss: 0.0801 - accuracy: 0.9766 - binary_crossentropy: 0.0801 - val_loss: 0.5307 - val_accuracy: 0.8765 - val_binary_crossentropy: 0.5307
Epoch 20/20
49/49 - 1s - loss: 0.0799 - accuracy: 0.9762 - binary_crossentropy: 0.0799 - val_loss: 0.5597 - val_accuracy: 0.8744 - val_binary_crossentropy: 0.5597

plot_history([('baseline', baseline_history),
              ('dropout', dpt_model_history)])

png

ドロップアウトを追加することで、比較対象モデルより明らかに改善が見られます。

まとめ:ニューラルネットワークにおいて過学習を防ぐ最も一般的な方法は次のとおりです。

  • 訓練データを増やす
  • ネットワークの容量をへらす
  • 重みの正則化を行う
  • ドロップアウトを追加する

このガイドで触れていない2つの重要なアプローチがあります。データ拡張とバッチ正規化です。


#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.