過学習と学習不足について知る

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

いつものように、この例のプログラムはtf.keras APIを使用します。詳しくはTensorFlowのKeras guideを参照してください。

これまでの例、つまり、映画レビューの分類と燃費の推定では、検証用データでのモデルの正解率が、数エポックでピークを迎え、その後低下するという現象が見られました。

言い換えると、モデルが訓練用データを過学習したと考えられます。過学習への対処の仕方を学ぶことは重要です。訓練用データセットで高い正解率を達成することは難しくありませんが、我々は、(これまで見たこともない)テスト用データに汎化したモデルを開発したいのです。

過学習の反対語は学習不足(underfitting)です。学習不足は、モデルがテストデータに対してまだ改善の余地がある場合に発生します。学習不足の原因は様々です。モデルが十分強力でないとか、正則化のしすぎだとか、単に訓練時間が短すぎるといった理由があります。学習不足は、訓練用データの中の関連したパターンを学習しきっていないということを意味します。

モデルの訓練をやりすぎると、モデルは過学習を始め、訓練用データの中のパターンで、テストデータには一般的ではないパターンを学習します。我々は、過学習と学習不足の中間を目指す必要があります。これから見ていくように、ちょうどよいエポック数だけ訓練を行うというのは必要なスキルなのです。

過学習を防止するための、最良の解決策は、より多くの訓練用データを使うことです。多くのデータで訓練を行えば行うほど、モデルは自然により汎化していく様になります。これが不可能な場合、次善の策は正則化のようなテクニックを使うことです。正則化は、モデルに保存される情報の量とタイプに制約を課すものです。ネットワークが少数のパターンしか記憶できなければ、最適化プロセスにより、最も主要なパターンのみを学習することになり、より汎化される可能性が高くなります。

このノートブックでは、重みの正則化とドロップアウトという、よく使われる2つの正則化テクニックをご紹介します。これらを使って、IMDBの映画レビューを分類するノートブックの改善を図ります。

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
2.4.1

IMDBデータセットのダウンロード

以前のノートブックで使用したエンベディングの代わりに、ここでは文をマルチホットエンコードします。このモデルは、訓練用データセットをすぐに過学習します。このモデルを使って、過学習がいつ起きるかということと、どうやって過学習と戦うかをデモします。

リストをマルチホットエンコードすると言うのは、0と1のベクトルにするということです。具体的にいうと、例えば[3, 5]というシーケンスを、インデックス3と5の値が1で、それ以外がすべて0の、10,000次元のベクトルに変換するということを意味します。

NUM_WORDS = 10000

(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)

def multi_hot_sequences(sequences, dimension):
    # 形状が (len(sequences), dimension)ですべて0の行列を作る
    results = np.zeros((len(sequences), dimension))
    for i, word_indices in enumerate(sequences):
        results[i, word_indices] = 1.0  # 特定のインデックスに対してresults[i] を1に設定する
    return results


train_data = multi_hot_sequences(train_data, dimension=NUM_WORDS)
test_data = multi_hot_sequences(test_data, dimension=NUM_WORDS)
<string>:6: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/datasets/imdb.py:159: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/datasets/imdb.py:160: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])

結果として得られるマルチホットベクトルの1つを見てみましょう。単語のインデックスは頻度順にソートされています。このため、インデックスが0に近いほど1が多く出現するはずです。分布を見てみましょう。

plt.plot(train_data[0])
[<matplotlib.lines.Line2D at 0x7f0986e77f98>]

png

過学習のデモ

過学習を防止するための最も単純な方法は、モデルのサイズ、すなわち、モデル内の学習可能なパラメータの数を小さくすることです(学習パラメータの数は、層の数と層ごとのユニット数で決まります)。ディープラーニングでは、モデルの学習可能なパラメータ数を、しばしばモデルの「キャパシティ」と呼びます。直感的に考えれば、パラメータ数の多いモデルほど「記憶容量」が大きくなり、訓練用のサンプルとその目的変数の間の辞書のようなマッピングをたやすく学習することができます。このマッピングには汎化能力がまったくなく、これまで見たことが無いデータを使って予測をする際には役に立ちません。

ディープラーニングのモデルは訓練用データに適応しやすいけれど、本当のチャレレンジは汎化であって適応ではないということを、肝に銘じておく必要があります。

一方、ネットワークの記憶容量が限られている場合、前述のようなマッピングを簡単に学習することはできません。損失を減らすためには、より予測能力が高い圧縮された表現を学習しなければなりません。同時に、モデルを小さくしすぎると、訓練用データに適応するのが難しくなります。「多すぎる容量」と「容量不足」の間にちょうどよい容量があるのです。

残念ながら、(層の数や、層ごとの大きさといった)モデルの適切なサイズやアーキテクチャを決める魔法の方程式はありません。一連の異なるアーキテクチャを使って実験を行う必要があります。

適切なモデルのサイズを見つけるには、比較的少ない層の数とパラメータから始めるのがベストです。それから、検証用データでの損失値の改善が見られなくなるまで、徐々に層の大きさを増やしたり、新たな層を加えたりします。映画レビューの分類ネットワークでこれを試してみましょう。

比較基準として、Dense層だけを使ったシンプルなモデルを構築し、その後、それより小さいバージョンと大きいバージョンを作って比較します。

比較基準を作る

baseline_model = keras.Sequential([
    # `.summary` を見るために`input_shape`が必要 
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

baseline_model.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy', 'binary_crossentropy'])

baseline_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                160016    
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,305
Trainable params: 160,305
Non-trainable params: 0
_________________________________________________________________
baseline_history = baseline_model.fit(train_data,
                                      train_labels,
                                      epochs=20,
                                      batch_size=512,
                                      validation_data=(test_data, test_labels),
                                      verbose=2)
Epoch 1/20
49/49 - 3s - loss: 0.5014 - accuracy: 0.7904 - binary_crossentropy: 0.5014 - val_loss: 0.3434 - val_accuracy: 0.8751 - val_binary_crossentropy: 0.3434
Epoch 2/20
49/49 - 1s - loss: 0.2533 - accuracy: 0.9106 - binary_crossentropy: 0.2533 - val_loss: 0.2832 - val_accuracy: 0.8877 - val_binary_crossentropy: 0.2832
Epoch 3/20
49/49 - 1s - loss: 0.1815 - accuracy: 0.9362 - binary_crossentropy: 0.1815 - val_loss: 0.2885 - val_accuracy: 0.8864 - val_binary_crossentropy: 0.2885
Epoch 4/20
49/49 - 1s - loss: 0.1420 - accuracy: 0.9520 - binary_crossentropy: 0.1420 - val_loss: 0.3159 - val_accuracy: 0.8772 - val_binary_crossentropy: 0.3159
Epoch 5/20
49/49 - 1s - loss: 0.1141 - accuracy: 0.9634 - binary_crossentropy: 0.1141 - val_loss: 0.3404 - val_accuracy: 0.8749 - val_binary_crossentropy: 0.3404
Epoch 6/20
49/49 - 1s - loss: 0.0922 - accuracy: 0.9711 - binary_crossentropy: 0.0922 - val_loss: 0.3755 - val_accuracy: 0.8692 - val_binary_crossentropy: 0.3755
Epoch 7/20
49/49 - 1s - loss: 0.0729 - accuracy: 0.9800 - binary_crossentropy: 0.0729 - val_loss: 0.4150 - val_accuracy: 0.8656 - val_binary_crossentropy: 0.4150
Epoch 8/20
49/49 - 1s - loss: 0.0562 - accuracy: 0.9867 - binary_crossentropy: 0.0562 - val_loss: 0.4619 - val_accuracy: 0.8627 - val_binary_crossentropy: 0.4619
Epoch 9/20
49/49 - 1s - loss: 0.0440 - accuracy: 0.9911 - binary_crossentropy: 0.0440 - val_loss: 0.5012 - val_accuracy: 0.8588 - val_binary_crossentropy: 0.5012
Epoch 10/20
49/49 - 1s - loss: 0.0329 - accuracy: 0.9944 - binary_crossentropy: 0.0329 - val_loss: 0.5493 - val_accuracy: 0.8589 - val_binary_crossentropy: 0.5493
Epoch 11/20
49/49 - 1s - loss: 0.0240 - accuracy: 0.9968 - binary_crossentropy: 0.0240 - val_loss: 0.5859 - val_accuracy: 0.8568 - val_binary_crossentropy: 0.5859
Epoch 12/20
49/49 - 1s - loss: 0.0171 - accuracy: 0.9987 - binary_crossentropy: 0.0171 - val_loss: 0.6230 - val_accuracy: 0.8561 - val_binary_crossentropy: 0.6230
Epoch 13/20
49/49 - 1s - loss: 0.0126 - accuracy: 0.9990 - binary_crossentropy: 0.0126 - val_loss: 0.6698 - val_accuracy: 0.8544 - val_binary_crossentropy: 0.6698
Epoch 14/20
49/49 - 1s - loss: 0.0090 - accuracy: 0.9996 - binary_crossentropy: 0.0090 - val_loss: 0.6986 - val_accuracy: 0.8534 - val_binary_crossentropy: 0.6986
Epoch 15/20
49/49 - 1s - loss: 0.0068 - accuracy: 0.9998 - binary_crossentropy: 0.0068 - val_loss: 0.7289 - val_accuracy: 0.8534 - val_binary_crossentropy: 0.7289
Epoch 16/20
49/49 - 1s - loss: 0.0052 - accuracy: 0.9999 - binary_crossentropy: 0.0052 - val_loss: 0.7601 - val_accuracy: 0.8535 - val_binary_crossentropy: 0.7601
Epoch 17/20
49/49 - 1s - loss: 0.0041 - accuracy: 0.9999 - binary_crossentropy: 0.0041 - val_loss: 0.7884 - val_accuracy: 0.8531 - val_binary_crossentropy: 0.7884
Epoch 18/20
49/49 - 1s - loss: 0.0033 - accuracy: 1.0000 - binary_crossentropy: 0.0033 - val_loss: 0.8111 - val_accuracy: 0.8538 - val_binary_crossentropy: 0.8111
Epoch 19/20
49/49 - 1s - loss: 0.0027 - accuracy: 1.0000 - binary_crossentropy: 0.0027 - val_loss: 0.8327 - val_accuracy: 0.8528 - val_binary_crossentropy: 0.8327
Epoch 20/20
49/49 - 1s - loss: 0.0023 - accuracy: 1.0000 - binary_crossentropy: 0.0023 - val_loss: 0.8519 - val_accuracy: 0.8529 - val_binary_crossentropy: 0.8519

より小さいモデルの構築

今作成したばかりの比較基準となるモデルに比べて隠れユニット数が少ないモデルを作りましょう。

smaller_model = keras.Sequential([
    keras.layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(4, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

smaller_model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy', 'binary_crossentropy'])

smaller_model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4)                 40004     
_________________________________________________________________
dense_4 (Dense)              (None, 4)                 20        
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 5         
=================================================================
Total params: 40,029
Trainable params: 40,029
Non-trainable params: 0
_________________________________________________________________

同じデータを使って訓練します。

smaller_history = smaller_model.fit(train_data,
                                    train_labels,
                                    epochs=20,
                                    batch_size=512,
                                    validation_data=(test_data, test_labels),
                                    verbose=2)
Epoch 1/20
49/49 - 3s - loss: 0.6453 - accuracy: 0.6029 - binary_crossentropy: 0.6453 - val_loss: 0.5700 - val_accuracy: 0.7488 - val_binary_crossentropy: 0.5700
Epoch 2/20
49/49 - 1s - loss: 0.4972 - accuracy: 0.8303 - binary_crossentropy: 0.4972 - val_loss: 0.4315 - val_accuracy: 0.8690 - val_binary_crossentropy: 0.4315
Epoch 3/20
49/49 - 1s - loss: 0.3058 - accuracy: 0.9089 - binary_crossentropy: 0.3058 - val_loss: 0.3050 - val_accuracy: 0.8854 - val_binary_crossentropy: 0.3050
Epoch 4/20
49/49 - 1s - loss: 0.2190 - accuracy: 0.9283 - binary_crossentropy: 0.2190 - val_loss: 0.2898 - val_accuracy: 0.8840 - val_binary_crossentropy: 0.2898
Epoch 5/20
49/49 - 1s - loss: 0.1813 - accuracy: 0.9408 - binary_crossentropy: 0.1813 - val_loss: 0.2831 - val_accuracy: 0.8875 - val_binary_crossentropy: 0.2831
Epoch 6/20
49/49 - 1s - loss: 0.1567 - accuracy: 0.9493 - binary_crossentropy: 0.1567 - val_loss: 0.2903 - val_accuracy: 0.8857 - val_binary_crossentropy: 0.2903
Epoch 7/20
49/49 - 1s - loss: 0.1382 - accuracy: 0.9564 - binary_crossentropy: 0.1382 - val_loss: 0.3012 - val_accuracy: 0.8827 - val_binary_crossentropy: 0.3012
Epoch 8/20
49/49 - 1s - loss: 0.1235 - accuracy: 0.9615 - binary_crossentropy: 0.1235 - val_loss: 0.3142 - val_accuracy: 0.8794 - val_binary_crossentropy: 0.3142
Epoch 9/20
49/49 - 1s - loss: 0.1107 - accuracy: 0.9665 - binary_crossentropy: 0.1107 - val_loss: 0.3301 - val_accuracy: 0.8775 - val_binary_crossentropy: 0.3301
Epoch 10/20
49/49 - 1s - loss: 0.0996 - accuracy: 0.9708 - binary_crossentropy: 0.0996 - val_loss: 0.3497 - val_accuracy: 0.8737 - val_binary_crossentropy: 0.3497
Epoch 11/20
49/49 - 1s - loss: 0.0907 - accuracy: 0.9743 - binary_crossentropy: 0.0907 - val_loss: 0.3632 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.3632
Epoch 12/20
49/49 - 1s - loss: 0.0811 - accuracy: 0.9788 - binary_crossentropy: 0.0811 - val_loss: 0.3820 - val_accuracy: 0.8701 - val_binary_crossentropy: 0.3820
Epoch 13/20
49/49 - 1s - loss: 0.0731 - accuracy: 0.9820 - binary_crossentropy: 0.0731 - val_loss: 0.4033 - val_accuracy: 0.8680 - val_binary_crossentropy: 0.4033
Epoch 14/20
49/49 - 1s - loss: 0.0665 - accuracy: 0.9840 - binary_crossentropy: 0.0665 - val_loss: 0.4229 - val_accuracy: 0.8649 - val_binary_crossentropy: 0.4229
Epoch 15/20
49/49 - 1s - loss: 0.0601 - accuracy: 0.9866 - binary_crossentropy: 0.0601 - val_loss: 0.4441 - val_accuracy: 0.8647 - val_binary_crossentropy: 0.4441
Epoch 16/20
49/49 - 1s - loss: 0.0541 - accuracy: 0.9891 - binary_crossentropy: 0.0541 - val_loss: 0.4643 - val_accuracy: 0.8632 - val_binary_crossentropy: 0.4643
Epoch 17/20
49/49 - 1s - loss: 0.0484 - accuracy: 0.9907 - binary_crossentropy: 0.0484 - val_loss: 0.4871 - val_accuracy: 0.8619 - val_binary_crossentropy: 0.4871
Epoch 18/20
49/49 - 1s - loss: 0.0438 - accuracy: 0.9926 - binary_crossentropy: 0.0438 - val_loss: 0.5109 - val_accuracy: 0.8602 - val_binary_crossentropy: 0.5109
Epoch 19/20
49/49 - 1s - loss: 0.0392 - accuracy: 0.9938 - binary_crossentropy: 0.0392 - val_loss: 0.5339 - val_accuracy: 0.8592 - val_binary_crossentropy: 0.5339
Epoch 20/20
49/49 - 1s - loss: 0.0353 - accuracy: 0.9949 - binary_crossentropy: 0.0353 - val_loss: 0.5553 - val_accuracy: 0.8588 - val_binary_crossentropy: 0.5553

より大きなモデルの構築

練習として、より大きなモデルを作成し、どれほど急速に過学習が起きるかを見ることもできます。次はこのベンチマークに、この問題が必要とするよりはるかに容量の大きなネットワークを追加しましょう。

bigger_model = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

bigger_model.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     metrics=['accuracy','binary_crossentropy'])

bigger_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (None, 512)               5120512   
_________________________________________________________________
dense_7 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_8 (Dense)              (None, 1)                 513       
=================================================================
Total params: 5,383,681
Trainable params: 5,383,681
Non-trainable params: 0
_________________________________________________________________

このモデルもまた同じデータを使って訓練します。

bigger_history = bigger_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.3516 - accuracy: 0.8427 - binary_crossentropy: 0.3516 - val_loss: 0.2824 - val_accuracy: 0.8848 - val_binary_crossentropy: 0.2824
Epoch 2/20
49/49 - 1s - loss: 0.1353 - accuracy: 0.9509 - binary_crossentropy: 0.1353 - val_loss: 0.3334 - val_accuracy: 0.8749 - val_binary_crossentropy: 0.3334
Epoch 3/20
49/49 - 1s - loss: 0.0400 - accuracy: 0.9890 - binary_crossentropy: 0.0400 - val_loss: 0.4529 - val_accuracy: 0.8689 - val_binary_crossentropy: 0.4529
Epoch 4/20
49/49 - 1s - loss: 0.0055 - accuracy: 0.9991 - binary_crossentropy: 0.0055 - val_loss: 0.6050 - val_accuracy: 0.8695 - val_binary_crossentropy: 0.6050
Epoch 5/20
49/49 - 1s - loss: 0.0011 - accuracy: 0.9999 - binary_crossentropy: 0.0011 - val_loss: 0.6894 - val_accuracy: 0.8695 - val_binary_crossentropy: 0.6894
Epoch 6/20
49/49 - 1s - loss: 2.4456e-04 - accuracy: 1.0000 - binary_crossentropy: 2.4456e-04 - val_loss: 0.7290 - val_accuracy: 0.8713 - val_binary_crossentropy: 0.7290
Epoch 7/20
49/49 - 1s - loss: 1.2892e-04 - accuracy: 1.0000 - binary_crossentropy: 1.2892e-04 - val_loss: 0.7603 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.7603
Epoch 8/20
49/49 - 1s - loss: 8.9720e-05 - accuracy: 1.0000 - binary_crossentropy: 8.9720e-05 - val_loss: 0.7870 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.7870
Epoch 9/20
49/49 - 1s - loss: 6.7607e-05 - accuracy: 1.0000 - binary_crossentropy: 6.7607e-05 - val_loss: 0.8070 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8070
Epoch 10/20
49/49 - 1s - loss: 5.2862e-05 - accuracy: 1.0000 - binary_crossentropy: 5.2862e-05 - val_loss: 0.8263 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8263
Epoch 11/20
49/49 - 1s - loss: 4.2534e-05 - accuracy: 1.0000 - binary_crossentropy: 4.2534e-05 - val_loss: 0.8430 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.8430
Epoch 12/20
49/49 - 1s - loss: 3.4808e-05 - accuracy: 1.0000 - binary_crossentropy: 3.4808e-05 - val_loss: 0.8585 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.8585
Epoch 13/20
49/49 - 1s - loss: 2.8948e-05 - accuracy: 1.0000 - binary_crossentropy: 2.8948e-05 - val_loss: 0.8725 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.8725
Epoch 14/20
49/49 - 1s - loss: 2.4367e-05 - accuracy: 1.0000 - binary_crossentropy: 2.4367e-05 - val_loss: 0.8868 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.8868
Epoch 15/20
49/49 - 1s - loss: 2.0663e-05 - accuracy: 1.0000 - binary_crossentropy: 2.0663e-05 - val_loss: 0.8995 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8995
Epoch 16/20
49/49 - 1s - loss: 1.7690e-05 - accuracy: 1.0000 - binary_crossentropy: 1.7690e-05 - val_loss: 0.9123 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.9123
Epoch 17/20
49/49 - 1s - loss: 1.5249e-05 - accuracy: 1.0000 - binary_crossentropy: 1.5249e-05 - val_loss: 0.9250 - val_accuracy: 0.8721 - val_binary_crossentropy: 0.9250
Epoch 18/20
49/49 - 1s - loss: 1.3218e-05 - accuracy: 1.0000 - binary_crossentropy: 1.3218e-05 - val_loss: 0.9364 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9364
Epoch 19/20
49/49 - 1s - loss: 1.1524e-05 - accuracy: 1.0000 - binary_crossentropy: 1.1524e-05 - val_loss: 0.9473 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.9473
Epoch 20/20
49/49 - 1s - loss: 1.0092e-05 - accuracy: 1.0000 - binary_crossentropy: 1.0092e-05 - val_loss: 0.9599 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9599

訓練時と検証時の損失をグラフにする

実線は訓練用データセットの損失、破線は検証用データセットでの損失です(検証用データでの損失が小さい方が良いモデルです)。これをみると、小さいネットワークのほうが比較基準のモデルよりも過学習が始まるのが遅いことがわかります(4エポックではなく6エポック後)。また、過学習が始まっても性能の低下がよりゆっくりしています。

def plot_history(histories, key='binary_crossentropy'):
  plt.figure(figsize=(16,10))

  for name, history in histories:
    val = plt.plot(history.epoch, history.history['val_'+key],
                   '--', label=name.title()+' Val')
    plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
             label=name.title()+' Train')

  plt.xlabel('Epochs')
  plt.ylabel(key.replace('_',' ').title())
  plt.legend()

  plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
              ('smaller', smaller_history),
              ('bigger', bigger_history)])

png

より大きなネットワークでは、すぐに、1エポックで過学習が始まり、その度合も強いことに注目してください。ネットワークの容量が大きいほど訓練用データをモデル化するスピードが早くなり(結果として訓練時の損失値が小さくなり)ますが、より過学習しやすく(結果として訓練時の損失値と検証時の損失値が大きく乖離しやすく)なります。

過学習防止の戦略

重みの正則化を加える

「オッカムの剃刀」の原則をご存知でしょうか。何かの説明が2つあるとすると、最も正しいと考えられる説明は、仮定の数が最も少ない「一番単純な」説明だというものです。この原則は、ニューラルネットワークを使って学習されたモデルにも当てはまります。ある訓練用データとネットワーク構造があって、そのデータを説明できる重みの集合が複数ある時(つまり、複数のモデルがある時)、単純なモデルのほうが複雑なものよりも過学習しにくいのです。

ここで言う「単純なモデル」とは、パラメータ値の分布のエントロピーが小さいもの(あるいは、上記で見たように、そもそもパラメータの数が少ないもの)です。したがって、過学習を緩和するための一般的な手法は、重みが小さい値のみをとることで、重み値の分布がより整然となる(正則)様に制約を与えるものです。これを「重みの正則化」と呼ばれ、ネットワークの損失関数に、重みの大きさに関連するコストを加えることで行われます。このコストには2つの種類があります。

  • L1正則化 重み係数の絶対値に比例するコストを加える(重みの「L1ノルム」と呼ばれる)。

  • L2正則化 重み係数の二乗に比例するコストを加える(重み係数の二乗「L2ノルム」と呼ばれる)。L2正則化はニューラルネットワーク用語では重み減衰(Weight Decay)と呼ばれる。呼び方が違うので混乱しないように。重み減衰は数学的にはL2正則化と同義である。

L1正則化は重みパラメータの一部を0にすることでモデルを疎にする効果があります。L2正則化は重みパラメータにペナルティを加えますがモデルを疎にすることはありません。これは、L2正則化のほうが一般的である理由の一つです。

tf.kerasでは、重みの正則化をするために、重み正則化のインスタンスをキーワード引数として層に加えます。ここでは、L2正則化を追加してみましょう。

l2_model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])

l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)
Epoch 1/20
49/49 - 3s - loss: 0.5289 - accuracy: 0.8047 - binary_crossentropy: 0.4893 - val_loss: 0.3837 - val_accuracy: 0.8756 - val_binary_crossentropy: 0.3419
Epoch 2/20
49/49 - 1s - loss: 0.3051 - accuracy: 0.9083 - binary_crossentropy: 0.2583 - val_loss: 0.3365 - val_accuracy: 0.8869 - val_binary_crossentropy: 0.2859
Epoch 3/20
49/49 - 1s - loss: 0.2537 - accuracy: 0.9305 - binary_crossentropy: 0.2002 - val_loss: 0.3385 - val_accuracy: 0.8869 - val_binary_crossentropy: 0.2831
Epoch 4/20
49/49 - 1s - loss: 0.2305 - accuracy: 0.9404 - binary_crossentropy: 0.1733 - val_loss: 0.3526 - val_accuracy: 0.8822 - val_binary_crossentropy: 0.2940
Epoch 5/20
49/49 - 1s - loss: 0.2170 - accuracy: 0.9460 - binary_crossentropy: 0.1573 - val_loss: 0.3696 - val_accuracy: 0.8779 - val_binary_crossentropy: 0.3089
Epoch 6/20
49/49 - 1s - loss: 0.2063 - accuracy: 0.9504 - binary_crossentropy: 0.1447 - val_loss: 0.3854 - val_accuracy: 0.8737 - val_binary_crossentropy: 0.3233
Epoch 7/20
49/49 - 1s - loss: 0.1981 - accuracy: 0.9544 - binary_crossentropy: 0.1351 - val_loss: 0.4099 - val_accuracy: 0.8688 - val_binary_crossentropy: 0.3462
Epoch 8/20
49/49 - 1s - loss: 0.1947 - accuracy: 0.9551 - binary_crossentropy: 0.1303 - val_loss: 0.4068 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.3420
Epoch 9/20
49/49 - 1s - loss: 0.1878 - accuracy: 0.9574 - binary_crossentropy: 0.1222 - val_loss: 0.4351 - val_accuracy: 0.8651 - val_binary_crossentropy: 0.3691
Epoch 10/20
49/49 - 1s - loss: 0.1856 - accuracy: 0.9575 - binary_crossentropy: 0.1191 - val_loss: 0.4317 - val_accuracy: 0.8682 - val_binary_crossentropy: 0.3647
Epoch 11/20
49/49 - 1s - loss: 0.1827 - accuracy: 0.9604 - binary_crossentropy: 0.1154 - val_loss: 0.4397 - val_accuracy: 0.8658 - val_binary_crossentropy: 0.3719
Epoch 12/20
49/49 - 1s - loss: 0.1763 - accuracy: 0.9630 - binary_crossentropy: 0.1081 - val_loss: 0.4567 - val_accuracy: 0.8628 - val_binary_crossentropy: 0.3884
Epoch 13/20
49/49 - 1s - loss: 0.1673 - accuracy: 0.9677 - binary_crossentropy: 0.0991 - val_loss: 0.4607 - val_accuracy: 0.8634 - val_binary_crossentropy: 0.3929
Epoch 14/20
49/49 - 1s - loss: 0.1697 - accuracy: 0.9649 - binary_crossentropy: 0.1016 - val_loss: 0.4721 - val_accuracy: 0.8635 - val_binary_crossentropy: 0.4030
Epoch 15/20
49/49 - 1s - loss: 0.1682 - accuracy: 0.9657 - binary_crossentropy: 0.0983 - val_loss: 0.4864 - val_accuracy: 0.8592 - val_binary_crossentropy: 0.4162
Epoch 16/20
49/49 - 1s - loss: 0.1578 - accuracy: 0.9719 - binary_crossentropy: 0.0878 - val_loss: 0.4864 - val_accuracy: 0.8607 - val_binary_crossentropy: 0.4170
Epoch 17/20
49/49 - 1s - loss: 0.1524 - accuracy: 0.9739 - binary_crossentropy: 0.0833 - val_loss: 0.4982 - val_accuracy: 0.8572 - val_binary_crossentropy: 0.4293
Epoch 18/20
49/49 - 1s - loss: 0.1525 - accuracy: 0.9734 - binary_crossentropy: 0.0836 - val_loss: 0.5049 - val_accuracy: 0.8597 - val_binary_crossentropy: 0.4356
Epoch 19/20
49/49 - 1s - loss: 0.1493 - accuracy: 0.9744 - binary_crossentropy: 0.0798 - val_loss: 0.5190 - val_accuracy: 0.8560 - val_binary_crossentropy: 0.4494
Epoch 20/20
49/49 - 1s - loss: 0.1495 - accuracy: 0.9750 - binary_crossentropy: 0.0794 - val_loss: 0.5253 - val_accuracy: 0.8570 - val_binary_crossentropy: 0.4546

l2(0.001)というのは、層の重み行列の係数全てに対して0.001 * 重み係数の値 **2をネットワークの損失値合計に加えることを意味します。このペナルティは訓練時のみに加えられるため、このネットワークの損失値は、訓練時にはテスト時に比べて大きくなることに注意してください。

L2正則化の影響を見てみましょう。

plot_history([('baseline', baseline_history),
              ('l2', l2_model_history)])

png

ご覧のように、L2正則化ありのモデルは比較基準のモデルに比べて過学習しにくくなっています。両方のモデルのパラメータ数は同じであるにもかかわらずです。

ドロップアウトを追加する

ドロップアウトは、ニューラルネットワークの正則化テクニックとして最もよく使われる手法の一つです。この手法は、トロント大学のヒントンと彼の学生が開発したものです。ドロップアウトは層に適用するもので、訓練時に層から出力された特徴量に対してランダムに「ドロップアウト(つまりゼロ化)」を行うものです。例えば、ある層が訓練時にある入力サンプルに対して、普通は[0.2, 0.5, 1.3, 0.8, 1.1] というベクトルを出力するとします。ドロップアウトを適用すると、このベクトルは例えば[0, 0.5, 1.3, 0, 1.1]のようにランダムに散らばったいくつかのゼロを含むようになります。「ドロップアウト率」はゼロ化される特徴の割合で、通常は0.2から0.5の間に設定します。テスト時は、どのユニットもドロップアウトされず、代わりに出力値がドロップアウト率と同じ比率でスケールダウンされます。これは、訓練時に比べてたくさんのユニットがアクティブであることに対してバランスをとるためです。

tf.kerasでは、Dropout層を使ってドロップアウトをネットワークに導入できます。ドロップアウト層は、その直前の層の出力に対してドロップアウトを適用します。

それでは、IMDBネットワークに2つのドロップアウト層を追加しましょう。

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='sigmoid')
])

dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])

dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.6125 - accuracy: 0.6561 - binary_crossentropy: 0.6125 - val_loss: 0.4507 - val_accuracy: 0.8590 - val_binary_crossentropy: 0.4507
Epoch 2/20
49/49 - 1s - loss: 0.4314 - accuracy: 0.8178 - binary_crossentropy: 0.4314 - val_loss: 0.3163 - val_accuracy: 0.8836 - val_binary_crossentropy: 0.3163
Epoch 3/20
49/49 - 1s - loss: 0.3310 - accuracy: 0.8755 - binary_crossentropy: 0.3310 - val_loss: 0.2793 - val_accuracy: 0.8897 - val_binary_crossentropy: 0.2793
Epoch 4/20
49/49 - 1s - loss: 0.2702 - accuracy: 0.9032 - binary_crossentropy: 0.2702 - val_loss: 0.2720 - val_accuracy: 0.8907 - val_binary_crossentropy: 0.2720
Epoch 5/20
49/49 - 1s - loss: 0.2305 - accuracy: 0.9210 - binary_crossentropy: 0.2305 - val_loss: 0.2759 - val_accuracy: 0.8892 - val_binary_crossentropy: 0.2759
Epoch 6/20
49/49 - 1s - loss: 0.1973 - accuracy: 0.9322 - binary_crossentropy: 0.1973 - val_loss: 0.2901 - val_accuracy: 0.8883 - val_binary_crossentropy: 0.2901
Epoch 7/20
49/49 - 1s - loss: 0.1736 - accuracy: 0.9423 - binary_crossentropy: 0.1736 - val_loss: 0.3052 - val_accuracy: 0.8841 - val_binary_crossentropy: 0.3052
Epoch 8/20
49/49 - 1s - loss: 0.1553 - accuracy: 0.9499 - binary_crossentropy: 0.1553 - val_loss: 0.3211 - val_accuracy: 0.8835 - val_binary_crossentropy: 0.3211
Epoch 9/20
49/49 - 1s - loss: 0.1377 - accuracy: 0.9540 - binary_crossentropy: 0.1377 - val_loss: 0.3438 - val_accuracy: 0.8830 - val_binary_crossentropy: 0.3438
Epoch 10/20
49/49 - 1s - loss: 0.1242 - accuracy: 0.9588 - binary_crossentropy: 0.1242 - val_loss: 0.3725 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3725
Epoch 11/20
49/49 - 1s - loss: 0.1106 - accuracy: 0.9635 - binary_crossentropy: 0.1106 - val_loss: 0.3966 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3966
Epoch 12/20
49/49 - 1s - loss: 0.0986 - accuracy: 0.9670 - binary_crossentropy: 0.0986 - val_loss: 0.4198 - val_accuracy: 0.8790 - val_binary_crossentropy: 0.4198
Epoch 13/20
49/49 - 1s - loss: 0.0907 - accuracy: 0.9694 - binary_crossentropy: 0.0907 - val_loss: 0.4350 - val_accuracy: 0.8745 - val_binary_crossentropy: 0.4350
Epoch 14/20
49/49 - 1s - loss: 0.0832 - accuracy: 0.9720 - binary_crossentropy: 0.0832 - val_loss: 0.4434 - val_accuracy: 0.8752 - val_binary_crossentropy: 0.4434
Epoch 15/20
49/49 - 1s - loss: 0.0767 - accuracy: 0.9735 - binary_crossentropy: 0.0767 - val_loss: 0.4830 - val_accuracy: 0.8769 - val_binary_crossentropy: 0.4830
Epoch 16/20
49/49 - 1s - loss: 0.0739 - accuracy: 0.9735 - binary_crossentropy: 0.0739 - val_loss: 0.4996 - val_accuracy: 0.8753 - val_binary_crossentropy: 0.4996
Epoch 17/20
49/49 - 1s - loss: 0.0690 - accuracy: 0.9750 - binary_crossentropy: 0.0690 - val_loss: 0.5080 - val_accuracy: 0.8742 - val_binary_crossentropy: 0.5080
Epoch 18/20
49/49 - 1s - loss: 0.0674 - accuracy: 0.9759 - binary_crossentropy: 0.0674 - val_loss: 0.5573 - val_accuracy: 0.8748 - val_binary_crossentropy: 0.5573
Epoch 19/20
49/49 - 1s - loss: 0.0623 - accuracy: 0.9780 - binary_crossentropy: 0.0623 - val_loss: 0.5402 - val_accuracy: 0.8746 - val_binary_crossentropy: 0.5402
Epoch 20/20
49/49 - 1s - loss: 0.0617 - accuracy: 0.9761 - binary_crossentropy: 0.0617 - val_loss: 0.5548 - val_accuracy: 0.8752 - val_binary_crossentropy: 0.5548
plot_history([('baseline', baseline_history),
              ('dropout', dpt_model_history)])

png

ドロップアウトを追加することで、比較対象モデルより明らかに改善が見られます。

まとめ:ニューラルネットワークにおいて過学習を防ぐ最も一般的な方法は次のとおりです。

  • 訓練データを増やす
  • ネットワークの容量をへらす
  • 重みの正則化を行う
  • ドロップアウトを追加する

このガイドで触れていない2つの重要なアプローチがあります。データ拡張とバッチ正規化です。

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.