![]() |
![]() |
![]() |
![]() |
いつものように、この例のプログラムはtf.keras
APIを使用します。詳しくはTensorFlowのKeras guideを参照してください。
これまでの例、つまり、映画レビューの分類と燃費の推定では、検証用データでのモデルの正解率が、数エポックでピークを迎え、その後低下するという現象が見られました。
言い換えると、モデルが訓練用データを過学習したと考えられます。過学習への対処の仕方を学ぶことは重要です。訓練用データセットで高い正解率を達成することは難しくありませんが、我々は、(これまで見たこともない)テスト用データに汎化したモデルを開発したいのです。
過学習の反対語は学習不足(underfitting)です。学習不足は、モデルがテストデータに対してまだ改善の余地がある場合に発生します。学習不足の原因は様々です。モデルが十分強力でないとか、正則化のしすぎだとか、単に訓練時間が短すぎるといった理由があります。学習不足は、訓練用データの中の関連したパターンを学習しきっていないということを意味します。
モデルの訓練をやりすぎると、モデルは過学習を始め、訓練用データの中のパターンで、テストデータには一般的ではないパターンを学習します。我々は、過学習と学習不足の中間を目指す必要があります。これから見ていくように、ちょうどよいエポック数だけ訓練を行うというのは必要なスキルなのです。
過学習を防止するための、最良の解決策は、より多くの訓練用データを使うことです。多くのデータで訓練を行えば行うほど、モデルは自然により汎化していく様になります。これが不可能な場合、次善の策は正則化のようなテクニックを使うことです。正則化は、モデルに保存される情報の量とタイプに制約を課すものです。ネットワークが少数のパターンしか記憶できなければ、最適化プロセスにより、最も主要なパターンのみを学習することになり、より汎化される可能性が高くなります。
このノートブックでは、重みの正則化とドロップアウトという、よく使われる2つの正則化テクニックをご紹介します。これらを使って、IMDBの映画レビューを分類するノートブックの改善を図ります。
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
2.4.1
IMDBデータセットのダウンロード
以前のノートブックで使用したエンベディングの代わりに、ここでは文をマルチホットエンコードします。このモデルは、訓練用データセットをすぐに過学習します。このモデルを使って、過学習がいつ起きるかということと、どうやって過学習と戦うかをデモします。
リストをマルチホットエンコードすると言うのは、0と1のベクトルにするということです。具体的にいうと、例えば[3, 5]
というシーケンスを、インデックス3と5の値が1で、それ以外がすべて0の、10,000次元のベクトルに変換するということを意味します。
NUM_WORDS = 10000
(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)
def multi_hot_sequences(sequences, dimension):
# 形状が (len(sequences), dimension)ですべて0の行列を作る
results = np.zeros((len(sequences), dimension))
for i, word_indices in enumerate(sequences):
results[i, word_indices] = 1.0 # 特定のインデックスに対してresults[i] を1に設定する
return results
train_data = multi_hot_sequences(train_data, dimension=NUM_WORDS)
test_data = multi_hot_sequences(test_data, dimension=NUM_WORDS)
<string>:6: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/datasets/imdb.py:159: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx]) /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/datasets/imdb.py:160: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
結果として得られるマルチホットベクトルの1つを見てみましょう。単語のインデックスは頻度順にソートされています。このため、インデックスが0に近いほど1が多く出現するはずです。分布を見てみましょう。
plt.plot(train_data[0])
[<matplotlib.lines.Line2D at 0x7f0986e77f98>]
過学習のデモ
過学習を防止するための最も単純な方法は、モデルのサイズ、すなわち、モデル内の学習可能なパラメータの数を小さくすることです(学習パラメータの数は、層の数と層ごとのユニット数で決まります)。ディープラーニングでは、モデルの学習可能なパラメータ数を、しばしばモデルの「キャパシティ」と呼びます。直感的に考えれば、パラメータ数の多いモデルほど「記憶容量」が大きくなり、訓練用のサンプルとその目的変数の間の辞書のようなマッピングをたやすく学習することができます。このマッピングには汎化能力がまったくなく、これまで見たことが無いデータを使って予測をする際には役に立ちません。
ディープラーニングのモデルは訓練用データに適応しやすいけれど、本当のチャレレンジは汎化であって適応ではないということを、肝に銘じておく必要があります。
一方、ネットワークの記憶容量が限られている場合、前述のようなマッピングを簡単に学習することはできません。損失を減らすためには、より予測能力が高い圧縮された表現を学習しなければなりません。同時に、モデルを小さくしすぎると、訓練用データに適応するのが難しくなります。「多すぎる容量」と「容量不足」の間にちょうどよい容量があるのです。
残念ながら、(層の数や、層ごとの大きさといった)モデルの適切なサイズやアーキテクチャを決める魔法の方程式はありません。一連の異なるアーキテクチャを使って実験を行う必要があります。
適切なモデルのサイズを見つけるには、比較的少ない層の数とパラメータから始めるのがベストです。それから、検証用データでの損失値の改善が見られなくなるまで、徐々に層の大きさを増やしたり、新たな層を加えたりします。映画レビューの分類ネットワークでこれを試してみましょう。
比較基準として、Dense
層だけを使ったシンプルなモデルを構築し、その後、それより小さいバージョンと大きいバージョンを作って比較します。
比較基準を作る
baseline_model = keras.Sequential([
# `.summary` を見るために`input_shape`が必要
keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
keras.layers.Dense(16, activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
baseline_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 160016 _________________________________________________________________ dense_1 (Dense) (None, 16) 272 _________________________________________________________________ dense_2 (Dense) (None, 1) 17 ================================================================= Total params: 160,305 Trainable params: 160,305 Non-trainable params: 0 _________________________________________________________________
baseline_history = baseline_model.fit(train_data,
train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
Epoch 1/20 49/49 - 3s - loss: 0.5014 - accuracy: 0.7904 - binary_crossentropy: 0.5014 - val_loss: 0.3434 - val_accuracy: 0.8751 - val_binary_crossentropy: 0.3434 Epoch 2/20 49/49 - 1s - loss: 0.2533 - accuracy: 0.9106 - binary_crossentropy: 0.2533 - val_loss: 0.2832 - val_accuracy: 0.8877 - val_binary_crossentropy: 0.2832 Epoch 3/20 49/49 - 1s - loss: 0.1815 - accuracy: 0.9362 - binary_crossentropy: 0.1815 - val_loss: 0.2885 - val_accuracy: 0.8864 - val_binary_crossentropy: 0.2885 Epoch 4/20 49/49 - 1s - loss: 0.1420 - accuracy: 0.9520 - binary_crossentropy: 0.1420 - val_loss: 0.3159 - val_accuracy: 0.8772 - val_binary_crossentropy: 0.3159 Epoch 5/20 49/49 - 1s - loss: 0.1141 - accuracy: 0.9634 - binary_crossentropy: 0.1141 - val_loss: 0.3404 - val_accuracy: 0.8749 - val_binary_crossentropy: 0.3404 Epoch 6/20 49/49 - 1s - loss: 0.0922 - accuracy: 0.9711 - binary_crossentropy: 0.0922 - val_loss: 0.3755 - val_accuracy: 0.8692 - val_binary_crossentropy: 0.3755 Epoch 7/20 49/49 - 1s - loss: 0.0729 - accuracy: 0.9800 - binary_crossentropy: 0.0729 - val_loss: 0.4150 - val_accuracy: 0.8656 - val_binary_crossentropy: 0.4150 Epoch 8/20 49/49 - 1s - loss: 0.0562 - accuracy: 0.9867 - binary_crossentropy: 0.0562 - val_loss: 0.4619 - val_accuracy: 0.8627 - val_binary_crossentropy: 0.4619 Epoch 9/20 49/49 - 1s - loss: 0.0440 - accuracy: 0.9911 - binary_crossentropy: 0.0440 - val_loss: 0.5012 - val_accuracy: 0.8588 - val_binary_crossentropy: 0.5012 Epoch 10/20 49/49 - 1s - loss: 0.0329 - accuracy: 0.9944 - binary_crossentropy: 0.0329 - val_loss: 0.5493 - val_accuracy: 0.8589 - val_binary_crossentropy: 0.5493 Epoch 11/20 49/49 - 1s - loss: 0.0240 - accuracy: 0.9968 - binary_crossentropy: 0.0240 - val_loss: 0.5859 - val_accuracy: 0.8568 - val_binary_crossentropy: 0.5859 Epoch 12/20 49/49 - 1s - loss: 0.0171 - accuracy: 0.9987 - binary_crossentropy: 0.0171 - val_loss: 0.6230 - val_accuracy: 0.8561 - val_binary_crossentropy: 0.6230 Epoch 13/20 49/49 - 1s - loss: 0.0126 - accuracy: 0.9990 - binary_crossentropy: 0.0126 - val_loss: 0.6698 - val_accuracy: 0.8544 - val_binary_crossentropy: 0.6698 Epoch 14/20 49/49 - 1s - loss: 0.0090 - accuracy: 0.9996 - binary_crossentropy: 0.0090 - val_loss: 0.6986 - val_accuracy: 0.8534 - val_binary_crossentropy: 0.6986 Epoch 15/20 49/49 - 1s - loss: 0.0068 - accuracy: 0.9998 - binary_crossentropy: 0.0068 - val_loss: 0.7289 - val_accuracy: 0.8534 - val_binary_crossentropy: 0.7289 Epoch 16/20 49/49 - 1s - loss: 0.0052 - accuracy: 0.9999 - binary_crossentropy: 0.0052 - val_loss: 0.7601 - val_accuracy: 0.8535 - val_binary_crossentropy: 0.7601 Epoch 17/20 49/49 - 1s - loss: 0.0041 - accuracy: 0.9999 - binary_crossentropy: 0.0041 - val_loss: 0.7884 - val_accuracy: 0.8531 - val_binary_crossentropy: 0.7884 Epoch 18/20 49/49 - 1s - loss: 0.0033 - accuracy: 1.0000 - binary_crossentropy: 0.0033 - val_loss: 0.8111 - val_accuracy: 0.8538 - val_binary_crossentropy: 0.8111 Epoch 19/20 49/49 - 1s - loss: 0.0027 - accuracy: 1.0000 - binary_crossentropy: 0.0027 - val_loss: 0.8327 - val_accuracy: 0.8528 - val_binary_crossentropy: 0.8327 Epoch 20/20 49/49 - 1s - loss: 0.0023 - accuracy: 1.0000 - binary_crossentropy: 0.0023 - val_loss: 0.8519 - val_accuracy: 0.8529 - val_binary_crossentropy: 0.8519
より小さいモデルの構築
今作成したばかりの比較基準となるモデルに比べて隠れユニット数が少ないモデルを作りましょう。
smaller_model = keras.Sequential([
keras.layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
keras.layers.Dense(4, activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
smaller_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
smaller_model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_3 (Dense) (None, 4) 40004 _________________________________________________________________ dense_4 (Dense) (None, 4) 20 _________________________________________________________________ dense_5 (Dense) (None, 1) 5 ================================================================= Total params: 40,029 Trainable params: 40,029 Non-trainable params: 0 _________________________________________________________________
同じデータを使って訓練します。
smaller_history = smaller_model.fit(train_data,
train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
Epoch 1/20 49/49 - 3s - loss: 0.6453 - accuracy: 0.6029 - binary_crossentropy: 0.6453 - val_loss: 0.5700 - val_accuracy: 0.7488 - val_binary_crossentropy: 0.5700 Epoch 2/20 49/49 - 1s - loss: 0.4972 - accuracy: 0.8303 - binary_crossentropy: 0.4972 - val_loss: 0.4315 - val_accuracy: 0.8690 - val_binary_crossentropy: 0.4315 Epoch 3/20 49/49 - 1s - loss: 0.3058 - accuracy: 0.9089 - binary_crossentropy: 0.3058 - val_loss: 0.3050 - val_accuracy: 0.8854 - val_binary_crossentropy: 0.3050 Epoch 4/20 49/49 - 1s - loss: 0.2190 - accuracy: 0.9283 - binary_crossentropy: 0.2190 - val_loss: 0.2898 - val_accuracy: 0.8840 - val_binary_crossentropy: 0.2898 Epoch 5/20 49/49 - 1s - loss: 0.1813 - accuracy: 0.9408 - binary_crossentropy: 0.1813 - val_loss: 0.2831 - val_accuracy: 0.8875 - val_binary_crossentropy: 0.2831 Epoch 6/20 49/49 - 1s - loss: 0.1567 - accuracy: 0.9493 - binary_crossentropy: 0.1567 - val_loss: 0.2903 - val_accuracy: 0.8857 - val_binary_crossentropy: 0.2903 Epoch 7/20 49/49 - 1s - loss: 0.1382 - accuracy: 0.9564 - binary_crossentropy: 0.1382 - val_loss: 0.3012 - val_accuracy: 0.8827 - val_binary_crossentropy: 0.3012 Epoch 8/20 49/49 - 1s - loss: 0.1235 - accuracy: 0.9615 - binary_crossentropy: 0.1235 - val_loss: 0.3142 - val_accuracy: 0.8794 - val_binary_crossentropy: 0.3142 Epoch 9/20 49/49 - 1s - loss: 0.1107 - accuracy: 0.9665 - binary_crossentropy: 0.1107 - val_loss: 0.3301 - val_accuracy: 0.8775 - val_binary_crossentropy: 0.3301 Epoch 10/20 49/49 - 1s - loss: 0.0996 - accuracy: 0.9708 - binary_crossentropy: 0.0996 - val_loss: 0.3497 - val_accuracy: 0.8737 - val_binary_crossentropy: 0.3497 Epoch 11/20 49/49 - 1s - loss: 0.0907 - accuracy: 0.9743 - binary_crossentropy: 0.0907 - val_loss: 0.3632 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.3632 Epoch 12/20 49/49 - 1s - loss: 0.0811 - accuracy: 0.9788 - binary_crossentropy: 0.0811 - val_loss: 0.3820 - val_accuracy: 0.8701 - val_binary_crossentropy: 0.3820 Epoch 13/20 49/49 - 1s - loss: 0.0731 - accuracy: 0.9820 - binary_crossentropy: 0.0731 - val_loss: 0.4033 - val_accuracy: 0.8680 - val_binary_crossentropy: 0.4033 Epoch 14/20 49/49 - 1s - loss: 0.0665 - accuracy: 0.9840 - binary_crossentropy: 0.0665 - val_loss: 0.4229 - val_accuracy: 0.8649 - val_binary_crossentropy: 0.4229 Epoch 15/20 49/49 - 1s - loss: 0.0601 - accuracy: 0.9866 - binary_crossentropy: 0.0601 - val_loss: 0.4441 - val_accuracy: 0.8647 - val_binary_crossentropy: 0.4441 Epoch 16/20 49/49 - 1s - loss: 0.0541 - accuracy: 0.9891 - binary_crossentropy: 0.0541 - val_loss: 0.4643 - val_accuracy: 0.8632 - val_binary_crossentropy: 0.4643 Epoch 17/20 49/49 - 1s - loss: 0.0484 - accuracy: 0.9907 - binary_crossentropy: 0.0484 - val_loss: 0.4871 - val_accuracy: 0.8619 - val_binary_crossentropy: 0.4871 Epoch 18/20 49/49 - 1s - loss: 0.0438 - accuracy: 0.9926 - binary_crossentropy: 0.0438 - val_loss: 0.5109 - val_accuracy: 0.8602 - val_binary_crossentropy: 0.5109 Epoch 19/20 49/49 - 1s - loss: 0.0392 - accuracy: 0.9938 - binary_crossentropy: 0.0392 - val_loss: 0.5339 - val_accuracy: 0.8592 - val_binary_crossentropy: 0.5339 Epoch 20/20 49/49 - 1s - loss: 0.0353 - accuracy: 0.9949 - binary_crossentropy: 0.0353 - val_loss: 0.5553 - val_accuracy: 0.8588 - val_binary_crossentropy: 0.5553
より大きなモデルの構築
練習として、より大きなモデルを作成し、どれほど急速に過学習が起きるかを見ることもできます。次はこのベンチマークに、この問題が必要とするよりはるかに容量の大きなネットワークを追加しましょう。
bigger_model = keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
keras.layers.Dense(512, activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
bigger_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy'])
bigger_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_6 (Dense) (None, 512) 5120512 _________________________________________________________________ dense_7 (Dense) (None, 512) 262656 _________________________________________________________________ dense_8 (Dense) (None, 1) 513 ================================================================= Total params: 5,383,681 Trainable params: 5,383,681 Non-trainable params: 0 _________________________________________________________________
このモデルもまた同じデータを使って訓練します。
bigger_history = bigger_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
Epoch 1/20 49/49 - 2s - loss: 0.3516 - accuracy: 0.8427 - binary_crossentropy: 0.3516 - val_loss: 0.2824 - val_accuracy: 0.8848 - val_binary_crossentropy: 0.2824 Epoch 2/20 49/49 - 1s - loss: 0.1353 - accuracy: 0.9509 - binary_crossentropy: 0.1353 - val_loss: 0.3334 - val_accuracy: 0.8749 - val_binary_crossentropy: 0.3334 Epoch 3/20 49/49 - 1s - loss: 0.0400 - accuracy: 0.9890 - binary_crossentropy: 0.0400 - val_loss: 0.4529 - val_accuracy: 0.8689 - val_binary_crossentropy: 0.4529 Epoch 4/20 49/49 - 1s - loss: 0.0055 - accuracy: 0.9991 - binary_crossentropy: 0.0055 - val_loss: 0.6050 - val_accuracy: 0.8695 - val_binary_crossentropy: 0.6050 Epoch 5/20 49/49 - 1s - loss: 0.0011 - accuracy: 0.9999 - binary_crossentropy: 0.0011 - val_loss: 0.6894 - val_accuracy: 0.8695 - val_binary_crossentropy: 0.6894 Epoch 6/20 49/49 - 1s - loss: 2.4456e-04 - accuracy: 1.0000 - binary_crossentropy: 2.4456e-04 - val_loss: 0.7290 - val_accuracy: 0.8713 - val_binary_crossentropy: 0.7290 Epoch 7/20 49/49 - 1s - loss: 1.2892e-04 - accuracy: 1.0000 - binary_crossentropy: 1.2892e-04 - val_loss: 0.7603 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.7603 Epoch 8/20 49/49 - 1s - loss: 8.9720e-05 - accuracy: 1.0000 - binary_crossentropy: 8.9720e-05 - val_loss: 0.7870 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.7870 Epoch 9/20 49/49 - 1s - loss: 6.7607e-05 - accuracy: 1.0000 - binary_crossentropy: 6.7607e-05 - val_loss: 0.8070 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8070 Epoch 10/20 49/49 - 1s - loss: 5.2862e-05 - accuracy: 1.0000 - binary_crossentropy: 5.2862e-05 - val_loss: 0.8263 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8263 Epoch 11/20 49/49 - 1s - loss: 4.2534e-05 - accuracy: 1.0000 - binary_crossentropy: 4.2534e-05 - val_loss: 0.8430 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.8430 Epoch 12/20 49/49 - 1s - loss: 3.4808e-05 - accuracy: 1.0000 - binary_crossentropy: 3.4808e-05 - val_loss: 0.8585 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.8585 Epoch 13/20 49/49 - 1s - loss: 2.8948e-05 - accuracy: 1.0000 - binary_crossentropy: 2.8948e-05 - val_loss: 0.8725 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.8725 Epoch 14/20 49/49 - 1s - loss: 2.4367e-05 - accuracy: 1.0000 - binary_crossentropy: 2.4367e-05 - val_loss: 0.8868 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.8868 Epoch 15/20 49/49 - 1s - loss: 2.0663e-05 - accuracy: 1.0000 - binary_crossentropy: 2.0663e-05 - val_loss: 0.8995 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.8995 Epoch 16/20 49/49 - 1s - loss: 1.7690e-05 - accuracy: 1.0000 - binary_crossentropy: 1.7690e-05 - val_loss: 0.9123 - val_accuracy: 0.8723 - val_binary_crossentropy: 0.9123 Epoch 17/20 49/49 - 1s - loss: 1.5249e-05 - accuracy: 1.0000 - binary_crossentropy: 1.5249e-05 - val_loss: 0.9250 - val_accuracy: 0.8721 - val_binary_crossentropy: 0.9250 Epoch 18/20 49/49 - 1s - loss: 1.3218e-05 - accuracy: 1.0000 - binary_crossentropy: 1.3218e-05 - val_loss: 0.9364 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9364 Epoch 19/20 49/49 - 1s - loss: 1.1524e-05 - accuracy: 1.0000 - binary_crossentropy: 1.1524e-05 - val_loss: 0.9473 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.9473 Epoch 20/20 49/49 - 1s - loss: 1.0092e-05 - accuracy: 1.0000 - binary_crossentropy: 1.0092e-05 - val_loss: 0.9599 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9599
訓練時と検証時の損失をグラフにする
実線は訓練用データセットの損失、破線は検証用データセットでの損失です(検証用データでの損失が小さい方が良いモデルです)。これをみると、小さいネットワークのほうが比較基準のモデルよりも過学習が始まるのが遅いことがわかります(4エポックではなく6エポック後)。また、過学習が始まっても性能の低下がよりゆっくりしています。
def plot_history(histories, key='binary_crossentropy'):
plt.figure(figsize=(16,10))
for name, history in histories:
val = plt.plot(history.epoch, history.history['val_'+key],
'--', label=name.title()+' Val')
plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
label=name.title()+' Train')
plt.xlabel('Epochs')
plt.ylabel(key.replace('_',' ').title())
plt.legend()
plt.xlim([0,max(history.epoch)])
plot_history([('baseline', baseline_history),
('smaller', smaller_history),
('bigger', bigger_history)])
より大きなネットワークでは、すぐに、1エポックで過学習が始まり、その度合も強いことに注目してください。ネットワークの容量が大きいほど訓練用データをモデル化するスピードが早くなり(結果として訓練時の損失値が小さくなり)ますが、より過学習しやすく(結果として訓練時の損失値と検証時の損失値が大きく乖離しやすく)なります。
過学習防止の戦略
重みの正則化を加える
「オッカムの剃刀」の原則をご存知でしょうか。何かの説明が2つあるとすると、最も正しいと考えられる説明は、仮定の数が最も少ない「一番単純な」説明だというものです。この原則は、ニューラルネットワークを使って学習されたモデルにも当てはまります。ある訓練用データとネットワーク構造があって、そのデータを説明できる重みの集合が複数ある時(つまり、複数のモデルがある時)、単純なモデルのほうが複雑なものよりも過学習しにくいのです。
ここで言う「単純なモデル」とは、パラメータ値の分布のエントロピーが小さいもの(あるいは、上記で見たように、そもそもパラメータの数が少ないもの)です。したがって、過学習を緩和するための一般的な手法は、重みが小さい値のみをとることで、重み値の分布がより整然となる(正則)様に制約を与えるものです。これを「重みの正則化」と呼ばれ、ネットワークの損失関数に、重みの大きさに関連するコストを加えることで行われます。このコストには2つの種類があります。
L1正則化 重み係数の絶対値に比例するコストを加える(重みの「L1ノルム」と呼ばれる)。
L2正則化 重み係数の二乗に比例するコストを加える(重み係数の二乗「L2ノルム」と呼ばれる)。L2正則化はニューラルネットワーク用語では重み減衰(Weight Decay)と呼ばれる。呼び方が違うので混乱しないように。重み減衰は数学的にはL2正則化と同義である。
L1正則化は重みパラメータの一部を0にすることでモデルを疎にする効果があります。L2正則化は重みパラメータにペナルティを加えますがモデルを疎にすることはありません。これは、L2正則化のほうが一般的である理由の一つです。
tf.keras
では、重みの正則化をするために、重み正則化のインスタンスをキーワード引数として層に加えます。ここでは、L2正則化を追加してみましょう。
l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu', input_shape=(NUM_WORDS,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
Epoch 1/20 49/49 - 3s - loss: 0.5289 - accuracy: 0.8047 - binary_crossentropy: 0.4893 - val_loss: 0.3837 - val_accuracy: 0.8756 - val_binary_crossentropy: 0.3419 Epoch 2/20 49/49 - 1s - loss: 0.3051 - accuracy: 0.9083 - binary_crossentropy: 0.2583 - val_loss: 0.3365 - val_accuracy: 0.8869 - val_binary_crossentropy: 0.2859 Epoch 3/20 49/49 - 1s - loss: 0.2537 - accuracy: 0.9305 - binary_crossentropy: 0.2002 - val_loss: 0.3385 - val_accuracy: 0.8869 - val_binary_crossentropy: 0.2831 Epoch 4/20 49/49 - 1s - loss: 0.2305 - accuracy: 0.9404 - binary_crossentropy: 0.1733 - val_loss: 0.3526 - val_accuracy: 0.8822 - val_binary_crossentropy: 0.2940 Epoch 5/20 49/49 - 1s - loss: 0.2170 - accuracy: 0.9460 - binary_crossentropy: 0.1573 - val_loss: 0.3696 - val_accuracy: 0.8779 - val_binary_crossentropy: 0.3089 Epoch 6/20 49/49 - 1s - loss: 0.2063 - accuracy: 0.9504 - binary_crossentropy: 0.1447 - val_loss: 0.3854 - val_accuracy: 0.8737 - val_binary_crossentropy: 0.3233 Epoch 7/20 49/49 - 1s - loss: 0.1981 - accuracy: 0.9544 - binary_crossentropy: 0.1351 - val_loss: 0.4099 - val_accuracy: 0.8688 - val_binary_crossentropy: 0.3462 Epoch 8/20 49/49 - 1s - loss: 0.1947 - accuracy: 0.9551 - binary_crossentropy: 0.1303 - val_loss: 0.4068 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.3420 Epoch 9/20 49/49 - 1s - loss: 0.1878 - accuracy: 0.9574 - binary_crossentropy: 0.1222 - val_loss: 0.4351 - val_accuracy: 0.8651 - val_binary_crossentropy: 0.3691 Epoch 10/20 49/49 - 1s - loss: 0.1856 - accuracy: 0.9575 - binary_crossentropy: 0.1191 - val_loss: 0.4317 - val_accuracy: 0.8682 - val_binary_crossentropy: 0.3647 Epoch 11/20 49/49 - 1s - loss: 0.1827 - accuracy: 0.9604 - binary_crossentropy: 0.1154 - val_loss: 0.4397 - val_accuracy: 0.8658 - val_binary_crossentropy: 0.3719 Epoch 12/20 49/49 - 1s - loss: 0.1763 - accuracy: 0.9630 - binary_crossentropy: 0.1081 - val_loss: 0.4567 - val_accuracy: 0.8628 - val_binary_crossentropy: 0.3884 Epoch 13/20 49/49 - 1s - loss: 0.1673 - accuracy: 0.9677 - binary_crossentropy: 0.0991 - val_loss: 0.4607 - val_accuracy: 0.8634 - val_binary_crossentropy: 0.3929 Epoch 14/20 49/49 - 1s - loss: 0.1697 - accuracy: 0.9649 - binary_crossentropy: 0.1016 - val_loss: 0.4721 - val_accuracy: 0.8635 - val_binary_crossentropy: 0.4030 Epoch 15/20 49/49 - 1s - loss: 0.1682 - accuracy: 0.9657 - binary_crossentropy: 0.0983 - val_loss: 0.4864 - val_accuracy: 0.8592 - val_binary_crossentropy: 0.4162 Epoch 16/20 49/49 - 1s - loss: 0.1578 - accuracy: 0.9719 - binary_crossentropy: 0.0878 - val_loss: 0.4864 - val_accuracy: 0.8607 - val_binary_crossentropy: 0.4170 Epoch 17/20 49/49 - 1s - loss: 0.1524 - accuracy: 0.9739 - binary_crossentropy: 0.0833 - val_loss: 0.4982 - val_accuracy: 0.8572 - val_binary_crossentropy: 0.4293 Epoch 18/20 49/49 - 1s - loss: 0.1525 - accuracy: 0.9734 - binary_crossentropy: 0.0836 - val_loss: 0.5049 - val_accuracy: 0.8597 - val_binary_crossentropy: 0.4356 Epoch 19/20 49/49 - 1s - loss: 0.1493 - accuracy: 0.9744 - binary_crossentropy: 0.0798 - val_loss: 0.5190 - val_accuracy: 0.8560 - val_binary_crossentropy: 0.4494 Epoch 20/20 49/49 - 1s - loss: 0.1495 - accuracy: 0.9750 - binary_crossentropy: 0.0794 - val_loss: 0.5253 - val_accuracy: 0.8570 - val_binary_crossentropy: 0.4546
l2(0.001)
というのは、層の重み行列の係数全てに対して0.001 * 重み係数の値 **2
をネットワークの損失値合計に加えることを意味します。このペナルティは訓練時のみに加えられるため、このネットワークの損失値は、訓練時にはテスト時に比べて大きくなることに注意してください。
L2正則化の影響を見てみましょう。
plot_history([('baseline', baseline_history),
('l2', l2_model_history)])
ご覧のように、L2正則化ありのモデルは比較基準のモデルに比べて過学習しにくくなっています。両方のモデルのパラメータ数は同じであるにもかかわらずです。
ドロップアウトを追加する
ドロップアウトは、ニューラルネットワークの正則化テクニックとして最もよく使われる手法の一つです。この手法は、トロント大学のヒントンと彼の学生が開発したものです。ドロップアウトは層に適用するもので、訓練時に層から出力された特徴量に対してランダムに「ドロップアウト(つまりゼロ化)」を行うものです。例えば、ある層が訓練時にある入力サンプルに対して、普通は[0.2, 0.5, 1.3, 0.8, 1.1]
というベクトルを出力するとします。ドロップアウトを適用すると、このベクトルは例えば[0, 0.5, 1.3, 0, 1.1]
のようにランダムに散らばったいくつかのゼロを含むようになります。「ドロップアウト率」はゼロ化される特徴の割合で、通常は0.2から0.5の間に設定します。テスト時は、どのユニットもドロップアウトされず、代わりに出力値がドロップアウト率と同じ比率でスケールダウンされます。これは、訓練時に比べてたくさんのユニットがアクティブであることに対してバランスをとるためです。
tf.keras
では、Dropout層を使ってドロップアウトをネットワークに導入できます。ドロップアウト層は、その直前の層の出力に対してドロップアウトを適用します。
それでは、IMDBネットワークに2つのドロップアウト層を追加しましょう。
dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
keras.layers.Dropout(0.5),
keras.layers.Dense(16, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid')
])
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy'])
dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
Epoch 1/20 49/49 - 2s - loss: 0.6125 - accuracy: 0.6561 - binary_crossentropy: 0.6125 - val_loss: 0.4507 - val_accuracy: 0.8590 - val_binary_crossentropy: 0.4507 Epoch 2/20 49/49 - 1s - loss: 0.4314 - accuracy: 0.8178 - binary_crossentropy: 0.4314 - val_loss: 0.3163 - val_accuracy: 0.8836 - val_binary_crossentropy: 0.3163 Epoch 3/20 49/49 - 1s - loss: 0.3310 - accuracy: 0.8755 - binary_crossentropy: 0.3310 - val_loss: 0.2793 - val_accuracy: 0.8897 - val_binary_crossentropy: 0.2793 Epoch 4/20 49/49 - 1s - loss: 0.2702 - accuracy: 0.9032 - binary_crossentropy: 0.2702 - val_loss: 0.2720 - val_accuracy: 0.8907 - val_binary_crossentropy: 0.2720 Epoch 5/20 49/49 - 1s - loss: 0.2305 - accuracy: 0.9210 - binary_crossentropy: 0.2305 - val_loss: 0.2759 - val_accuracy: 0.8892 - val_binary_crossentropy: 0.2759 Epoch 6/20 49/49 - 1s - loss: 0.1973 - accuracy: 0.9322 - binary_crossentropy: 0.1973 - val_loss: 0.2901 - val_accuracy: 0.8883 - val_binary_crossentropy: 0.2901 Epoch 7/20 49/49 - 1s - loss: 0.1736 - accuracy: 0.9423 - binary_crossentropy: 0.1736 - val_loss: 0.3052 - val_accuracy: 0.8841 - val_binary_crossentropy: 0.3052 Epoch 8/20 49/49 - 1s - loss: 0.1553 - accuracy: 0.9499 - binary_crossentropy: 0.1553 - val_loss: 0.3211 - val_accuracy: 0.8835 - val_binary_crossentropy: 0.3211 Epoch 9/20 49/49 - 1s - loss: 0.1377 - accuracy: 0.9540 - binary_crossentropy: 0.1377 - val_loss: 0.3438 - val_accuracy: 0.8830 - val_binary_crossentropy: 0.3438 Epoch 10/20 49/49 - 1s - loss: 0.1242 - accuracy: 0.9588 - binary_crossentropy: 0.1242 - val_loss: 0.3725 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3725 Epoch 11/20 49/49 - 1s - loss: 0.1106 - accuracy: 0.9635 - binary_crossentropy: 0.1106 - val_loss: 0.3966 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3966 Epoch 12/20 49/49 - 1s - loss: 0.0986 - accuracy: 0.9670 - binary_crossentropy: 0.0986 - val_loss: 0.4198 - val_accuracy: 0.8790 - val_binary_crossentropy: 0.4198 Epoch 13/20 49/49 - 1s - loss: 0.0907 - accuracy: 0.9694 - binary_crossentropy: 0.0907 - val_loss: 0.4350 - val_accuracy: 0.8745 - val_binary_crossentropy: 0.4350 Epoch 14/20 49/49 - 1s - loss: 0.0832 - accuracy: 0.9720 - binary_crossentropy: 0.0832 - val_loss: 0.4434 - val_accuracy: 0.8752 - val_binary_crossentropy: 0.4434 Epoch 15/20 49/49 - 1s - loss: 0.0767 - accuracy: 0.9735 - binary_crossentropy: 0.0767 - val_loss: 0.4830 - val_accuracy: 0.8769 - val_binary_crossentropy: 0.4830 Epoch 16/20 49/49 - 1s - loss: 0.0739 - accuracy: 0.9735 - binary_crossentropy: 0.0739 - val_loss: 0.4996 - val_accuracy: 0.8753 - val_binary_crossentropy: 0.4996 Epoch 17/20 49/49 - 1s - loss: 0.0690 - accuracy: 0.9750 - binary_crossentropy: 0.0690 - val_loss: 0.5080 - val_accuracy: 0.8742 - val_binary_crossentropy: 0.5080 Epoch 18/20 49/49 - 1s - loss: 0.0674 - accuracy: 0.9759 - binary_crossentropy: 0.0674 - val_loss: 0.5573 - val_accuracy: 0.8748 - val_binary_crossentropy: 0.5573 Epoch 19/20 49/49 - 1s - loss: 0.0623 - accuracy: 0.9780 - binary_crossentropy: 0.0623 - val_loss: 0.5402 - val_accuracy: 0.8746 - val_binary_crossentropy: 0.5402 Epoch 20/20 49/49 - 1s - loss: 0.0617 - accuracy: 0.9761 - binary_crossentropy: 0.0617 - val_loss: 0.5548 - val_accuracy: 0.8752 - val_binary_crossentropy: 0.5548
plot_history([('baseline', baseline_history),
('dropout', dpt_model_history)])
ドロップアウトを追加することで、比較対象モデルより明らかに改善が見られます。
まとめ:ニューラルネットワークにおいて過学習を防ぐ最も一般的な方法は次のとおりです。
- 訓練データを増やす
- ネットワークの容量をへらす
- 重みの正則化を行う
- ドロップアウトを追加する
このガイドで触れていない2つの重要なアプローチがあります。データ拡張とバッチ正規化です。
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.