過学習と学習不足について知る

View on TensorFlow.org Run in Google Colab View source on GitHub

いつものように、この例のプログラムはtf.keras APIを使用します。詳しくはTensorFlowのKeras guideを参照してください。

これまでの例、つまり、映画レビューの分類と燃費の推定では、検証用データでのモデルの正解率が、数エポックでピークを迎え、その後低下するという現象が見られました。

言い換えると、モデルが訓練用データを過学習したと考えられます。過学習への対処の仕方を学ぶことは重要です。訓練用データセットで高い正解率を達成することは難しくありませんが、我々は、(これまで見たこともない)テスト用データに汎化したモデルを開発したいのです。

過学習の反対語は学習不足(underfitting)です。学習不足は、モデルがテストデータに対してまだ改善の余地がある場合に発生します。学習不足の原因は様々です。モデルが十分強力でないとか、正則化のしすぎだとか、単に訓練時間が短すぎるといった理由があります。学習不足は、訓練用データの中の関連したパターンを学習しきっていないということを意味します。

モデルの訓練をやりすぎると、モデルは過学習を始め、訓練用データの中のパターンで、テストデータには一般的ではないパターンを学習します。我々は、過学習と学習不足の中間を目指す必要があります。これから見ていくように、ちょうどよいエポック数だけ訓練を行うというのは必要なスキルなのです。

過学習を防止するための、最良の解決策は、より多くの訓練用データを使うことです。多くのデータで訓練を行えば行うほど、モデルは自然により汎化していく様になります。これが不可能な場合、次善の策は正則化のようなテクニックを使うことです。正則化は、モデルに保存される情報の量とタイプに制約を課すものです。ネットワークが少数のパターンしか記憶できなければ、最適化プロセスにより、最も主要なパターンのみを学習することになり、より汎化される可能性が高くなります。

このノートブックでは、重みの正則化とドロップアウトという、よく使われる2つの正則化テクニックをご紹介します。これらを使って、IMDBの映画レビューを分類するノートブックの改善を図ります。

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
2.1.0

IMDBデータセットのダウンロード

以前のノートブックで使用したエンベディングの代わりに、ここでは文をマルチホットエンコードします。このモデルは、訓練用データセットをすぐに過学習します。このモデルを使って、過学習がいつ起きるかということと、どうやって過学習と戦うかをデモします。

リストをマルチホットエンコードすると言うのは、0と1のベクトルにするということです。具体的にいうと、例えば[3, 5]というシーケンスを、インデックス3と5の値が1で、それ以外がすべて0の、10,000次元のベクトルに変換するということを意味します。

NUM_WORDS = 10000

(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)

def multi_hot_sequences(sequences, dimension):
    # 形状が (len(sequences), dimension)ですべて0の行列を作る
    results = np.zeros((len(sequences), dimension))
    for i, word_indices in enumerate(sequences):
        results[i, word_indices] = 1.0  # 特定のインデックスに対してresults[i] を1に設定する
    return results


train_data = multi_hot_sequences(train_data, dimension=NUM_WORDS)
test_data = multi_hot_sequences(test_data, dimension=NUM_WORDS)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17465344/17464789 [==============================] - 0s 0us/step

結果として得られるマルチホットベクトルの1つを見てみましょう。単語のインデックスは頻度順にソートされています。このため、インデックスが0に近いほど1が多く出現するはずです。分布を見てみましょう。

plt.plot(train_data[0])
[<matplotlib.lines.Line2D at 0x7f5807e5e048>]

png

過学習のデモ

過学習を防止するための最も単純な方法は、モデルのサイズ、すなわち、モデル内の学習可能なパラメータの数を小さくすることです(学習パラメータの数は、層の数と層ごとのユニット数で決まります)。ディープラーニングでは、モデルの学習可能なパラメータ数を、しばしばモデルの「キャパシティ」と呼びます。直感的に考えれば、パラメータ数の多いモデルほど「記憶容量」が大きくなり、訓練用のサンプルとその目的変数の間の辞書のようなマッピングをたやすく学習することができます。このマッピングには汎化能力がまったくなく、これまで見たことが無いデータを使って予測をする際には役に立ちません。

ディープラーニングのモデルは訓練用データに適応しやすいけれど、本当のチャレレンジは汎化であって適応ではないということを、肝に銘じておく必要があります。

一方、ネットワークの記憶容量が限られている場合、前述のようなマッピングを簡単に学習することはできません。損失を減らすためには、より予測能力が高い圧縮された表現を学習しなければなりません。同時に、モデルを小さくしすぎると、訓練用データに適応するのが難しくなります。「多すぎる容量」と「容量不足」の間にちょうどよい容量があるのです。

残念ながら、(層の数や、層ごとの大きさといった)モデルの適切なサイズやアーキテクチャを決める魔法の方程式はありません。一連の異なるアーキテクチャを使って実験を行う必要があります。

適切なモデルのサイズを見つけるには、比較的少ない層の数とパラメータから始めるのがベストです。それから、検証用データでの損失値の改善が見られなくなるまで、徐々に層の大きさを増やしたり、新たな層を加えたりします。映画レビューの分類ネットワークでこれを試してみましょう。

比較基準として、Dense層だけを使ったシンプルなモデルを構築し、その後、それより小さいバージョンと大きいバージョンを作って比較します。

比較基準を作る

baseline_model = keras.Sequential([
    # `.summary` を見るために`input_shape`が必要 
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

baseline_model.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy', 'binary_crossentropy'])

baseline_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                160016    
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,305
Trainable params: 160,305
Non-trainable params: 0
_________________________________________________________________
baseline_history = baseline_model.fit(train_data,
                                      train_labels,
                                      epochs=20,
                                      batch_size=512,
                                      validation_data=(test_data, test_labels),
                                      verbose=2)
Train on 25000 samples, validate on 25000 samples
Epoch 1/20
25000/25000 - 2s - loss: 0.5325 - accuracy: 0.7494 - binary_crossentropy: 0.5325 - val_loss: 0.3694 - val_accuracy: 0.8734 - val_binary_crossentropy: 0.3694
Epoch 2/20
25000/25000 - 2s - loss: 0.2631 - accuracy: 0.9084 - binary_crossentropy: 0.2631 - val_loss: 0.2871 - val_accuracy: 0.8864 - val_binary_crossentropy: 0.2871
Epoch 3/20
25000/25000 - 2s - loss: 0.1868 - accuracy: 0.9359 - binary_crossentropy: 0.1868 - val_loss: 0.2918 - val_accuracy: 0.8850 - val_binary_crossentropy: 0.2918
Epoch 4/20
25000/25000 - 2s - loss: 0.1495 - accuracy: 0.9492 - binary_crossentropy: 0.1495 - val_loss: 0.3111 - val_accuracy: 0.8799 - val_binary_crossentropy: 0.3111
Epoch 5/20
25000/25000 - 2s - loss: 0.1192 - accuracy: 0.9628 - binary_crossentropy: 0.1192 - val_loss: 0.3351 - val_accuracy: 0.8770 - val_binary_crossentropy: 0.3351
Epoch 6/20
25000/25000 - 2s - loss: 0.0979 - accuracy: 0.9707 - binary_crossentropy: 0.0979 - val_loss: 0.3702 - val_accuracy: 0.8718 - val_binary_crossentropy: 0.3702
Epoch 7/20
25000/25000 - 2s - loss: 0.0784 - accuracy: 0.9785 - binary_crossentropy: 0.0784 - val_loss: 0.4046 - val_accuracy: 0.8680 - val_binary_crossentropy: 0.4046
Epoch 8/20
25000/25000 - 2s - loss: 0.0613 - accuracy: 0.9859 - binary_crossentropy: 0.0613 - val_loss: 0.4464 - val_accuracy: 0.8679 - val_binary_crossentropy: 0.4464
Epoch 9/20
25000/25000 - 2s - loss: 0.0465 - accuracy: 0.9912 - binary_crossentropy: 0.0465 - val_loss: 0.4867 - val_accuracy: 0.8629 - val_binary_crossentropy: 0.4867
Epoch 10/20
25000/25000 - 2s - loss: 0.0345 - accuracy: 0.9946 - binary_crossentropy: 0.0345 - val_loss: 0.5367 - val_accuracy: 0.8615 - val_binary_crossentropy: 0.5367
Epoch 11/20
25000/25000 - 2s - loss: 0.0257 - accuracy: 0.9963 - binary_crossentropy: 0.0257 - val_loss: 0.5692 - val_accuracy: 0.8605 - val_binary_crossentropy: 0.5692
Epoch 12/20
25000/25000 - 2s - loss: 0.0188 - accuracy: 0.9979 - binary_crossentropy: 0.0188 - val_loss: 0.6034 - val_accuracy: 0.8580 - val_binary_crossentropy: 0.6034
Epoch 13/20
25000/25000 - 2s - loss: 0.0141 - accuracy: 0.9987 - binary_crossentropy: 0.0141 - val_loss: 0.6383 - val_accuracy: 0.8579 - val_binary_crossentropy: 0.6383
Epoch 14/20
25000/25000 - 2s - loss: 0.0108 - accuracy: 0.9992 - binary_crossentropy: 0.0108 - val_loss: 0.6699 - val_accuracy: 0.8580 - val_binary_crossentropy: 0.6699
Epoch 15/20
25000/25000 - 2s - loss: 0.0085 - accuracy: 0.9995 - binary_crossentropy: 0.0085 - val_loss: 0.6945 - val_accuracy: 0.8562 - val_binary_crossentropy: 0.6945
Epoch 16/20
25000/25000 - 2s - loss: 0.0068 - accuracy: 0.9995 - binary_crossentropy: 0.0068 - val_loss: 0.7197 - val_accuracy: 0.8566 - val_binary_crossentropy: 0.7197
Epoch 17/20
25000/25000 - 2s - loss: 0.0055 - accuracy: 0.9996 - binary_crossentropy: 0.0055 - val_loss: 0.7477 - val_accuracy: 0.8575 - val_binary_crossentropy: 0.7477
Epoch 18/20
25000/25000 - 2s - loss: 0.0045 - accuracy: 0.9996 - binary_crossentropy: 0.0045 - val_loss: 0.7686 - val_accuracy: 0.8566 - val_binary_crossentropy: 0.7686
Epoch 19/20
25000/25000 - 2s - loss: 0.0037 - accuracy: 0.9996 - binary_crossentropy: 0.0037 - val_loss: 0.7877 - val_accuracy: 0.8565 - val_binary_crossentropy: 0.7877
Epoch 20/20
25000/25000 - 2s - loss: 0.0030 - accuracy: 0.9998 - binary_crossentropy: 0.0030 - val_loss: 0.8059 - val_accuracy: 0.8561 - val_binary_crossentropy: 0.8059

より小さいモデルの構築

今作成したばかりの比較基準となるモデルに比べて隠れユニット数が少ないモデルを作りましょう。

smaller_model = keras.Sequential([
    keras.layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(4, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

smaller_model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy', 'binary_crossentropy'])

smaller_model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4)                 40004     
_________________________________________________________________
dense_4 (Dense)              (None, 4)                 20        
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 5         
=================================================================
Total params: 40,029
Trainable params: 40,029
Non-trainable params: 0
_________________________________________________________________

同じデータを使って訓練します。

smaller_history = smaller_model.fit(train_data,
                                    train_labels,
                                    epochs=20,
                                    batch_size=512,
                                    validation_data=(test_data, test_labels),
                                    verbose=2)
Train on 25000 samples, validate on 25000 samples
Epoch 1/20
25000/25000 - 2s - loss: 0.6501 - accuracy: 0.5962 - binary_crossentropy: 0.6501 - val_loss: 0.6008 - val_accuracy: 0.6915 - val_binary_crossentropy: 0.6008
Epoch 2/20
25000/25000 - 2s - loss: 0.5463 - accuracy: 0.7792 - binary_crossentropy: 0.5463 - val_loss: 0.5227 - val_accuracy: 0.7875 - val_binary_crossentropy: 0.5227
Epoch 3/20
25000/25000 - 2s - loss: 0.4742 - accuracy: 0.8597 - binary_crossentropy: 0.4742 - val_loss: 0.4771 - val_accuracy: 0.8495 - val_binary_crossentropy: 0.4771
Epoch 4/20
25000/25000 - 2s - loss: 0.4277 - accuracy: 0.8989 - binary_crossentropy: 0.4277 - val_loss: 0.4549 - val_accuracy: 0.8552 - val_binary_crossentropy: 0.4549
Epoch 5/20
25000/25000 - 2s - loss: 0.3939 - accuracy: 0.9197 - binary_crossentropy: 0.3939 - val_loss: 0.4385 - val_accuracy: 0.8682 - val_binary_crossentropy: 0.4385
Epoch 6/20
25000/25000 - 2s - loss: 0.3663 - accuracy: 0.9361 - binary_crossentropy: 0.3663 - val_loss: 0.4380 - val_accuracy: 0.8615 - val_binary_crossentropy: 0.4380
Epoch 7/20
25000/25000 - 2s - loss: 0.3422 - accuracy: 0.9480 - binary_crossentropy: 0.3422 - val_loss: 0.4365 - val_accuracy: 0.8621 - val_binary_crossentropy: 0.4365
Epoch 8/20
25000/25000 - 2s - loss: 0.3209 - accuracy: 0.9580 - binary_crossentropy: 0.3209 - val_loss: 0.4301 - val_accuracy: 0.8680 - val_binary_crossentropy: 0.4301
Epoch 9/20
25000/25000 - 2s - loss: 0.3019 - accuracy: 0.9641 - binary_crossentropy: 0.3019 - val_loss: 0.4292 - val_accuracy: 0.8695 - val_binary_crossentropy: 0.4292
Epoch 10/20
25000/25000 - 2s - loss: 0.2850 - accuracy: 0.9706 - binary_crossentropy: 0.2850 - val_loss: 0.4264 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.4264
Epoch 11/20
25000/25000 - 2s - loss: 0.2693 - accuracy: 0.9749 - binary_crossentropy: 0.2693 - val_loss: 0.4283 - val_accuracy: 0.8699 - val_binary_crossentropy: 0.4283
Epoch 12/20
25000/25000 - 2s - loss: 0.2554 - accuracy: 0.9779 - binary_crossentropy: 0.2554 - val_loss: 0.4373 - val_accuracy: 0.8676 - val_binary_crossentropy: 0.4373
Epoch 13/20
25000/25000 - 2s - loss: 0.2423 - accuracy: 0.9810 - binary_crossentropy: 0.2423 - val_loss: 0.4464 - val_accuracy: 0.8645 - val_binary_crossentropy: 0.4464
Epoch 14/20
25000/25000 - 2s - loss: 0.2304 - accuracy: 0.9840 - binary_crossentropy: 0.2304 - val_loss: 0.4520 - val_accuracy: 0.8636 - val_binary_crossentropy: 0.4520
Epoch 15/20
25000/25000 - 2s - loss: 0.2197 - accuracy: 0.9857 - binary_crossentropy: 0.2197 - val_loss: 0.4568 - val_accuracy: 0.8636 - val_binary_crossentropy: 0.4568
Epoch 16/20
25000/25000 - 2s - loss: 0.2097 - accuracy: 0.9870 - binary_crossentropy: 0.2097 - val_loss: 0.4739 - val_accuracy: 0.8604 - val_binary_crossentropy: 0.4739
Epoch 17/20
25000/25000 - 2s - loss: 0.2007 - accuracy: 0.9881 - binary_crossentropy: 0.2007 - val_loss: 0.4794 - val_accuracy: 0.8607 - val_binary_crossentropy: 0.4794
Epoch 18/20
25000/25000 - 2s - loss: 0.1924 - accuracy: 0.9887 - binary_crossentropy: 0.1924 - val_loss: 0.4953 - val_accuracy: 0.8589 - val_binary_crossentropy: 0.4953
Epoch 19/20
25000/25000 - 2s - loss: 0.1847 - accuracy: 0.9895 - binary_crossentropy: 0.1847 - val_loss: 0.4984 - val_accuracy: 0.8596 - val_binary_crossentropy: 0.4984
Epoch 20/20
25000/25000 - 2s - loss: 0.1776 - accuracy: 0.9899 - binary_crossentropy: 0.1776 - val_loss: 0.5054 - val_accuracy: 0.8589 - val_binary_crossentropy: 0.5054

より大きなモデルの構築

練習として、より大きなモデルを作成し、どれほど急速に過学習が起きるかを見ることもできます。次はこのベンチマークに、この問題が必要とするよりはるかに容量の大きなネットワークを追加しましょう。

bigger_model = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

bigger_model.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     metrics=['accuracy','binary_crossentropy'])

bigger_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (None, 512)               5120512   
_________________________________________________________________
dense_7 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_8 (Dense)              (None, 1)                 513       
=================================================================
Total params: 5,383,681
Trainable params: 5,383,681
Non-trainable params: 0
_________________________________________________________________

このモデルもまた同じデータを使って訓練します。

bigger_history = bigger_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Train on 25000 samples, validate on 25000 samples
Epoch 1/20
25000/25000 - 2s - loss: 0.3574 - accuracy: 0.8481 - binary_crossentropy: 0.3574 - val_loss: 0.2974 - val_accuracy: 0.8789 - val_binary_crossentropy: 0.2974
Epoch 2/20
25000/25000 - 2s - loss: 0.1443 - accuracy: 0.9480 - binary_crossentropy: 0.1443 - val_loss: 0.3261 - val_accuracy: 0.8739 - val_binary_crossentropy: 0.3261
Epoch 3/20
25000/25000 - 2s - loss: 0.0503 - accuracy: 0.9853 - binary_crossentropy: 0.0503 - val_loss: 0.4290 - val_accuracy: 0.8712 - val_binary_crossentropy: 0.4290
Epoch 4/20
25000/25000 - 2s - loss: 0.0088 - accuracy: 0.9984 - binary_crossentropy: 0.0088 - val_loss: 0.5607 - val_accuracy: 0.8705 - val_binary_crossentropy: 0.5607
Epoch 5/20
25000/25000 - 2s - loss: 0.0017 - accuracy: 1.0000 - binary_crossentropy: 0.0017 - val_loss: 0.6586 - val_accuracy: 0.8700 - val_binary_crossentropy: 0.6586
Epoch 6/20
25000/25000 - 2s - loss: 3.0520e-04 - accuracy: 1.0000 - binary_crossentropy: 3.0520e-04 - val_loss: 0.7091 - val_accuracy: 0.8708 - val_binary_crossentropy: 0.7091
Epoch 7/20
25000/25000 - 2s - loss: 1.6983e-04 - accuracy: 1.0000 - binary_crossentropy: 1.6983e-04 - val_loss: 0.7427 - val_accuracy: 0.8709 - val_binary_crossentropy: 0.7427
Epoch 8/20
25000/25000 - 2s - loss: 1.1796e-04 - accuracy: 1.0000 - binary_crossentropy: 1.1796e-04 - val_loss: 0.7677 - val_accuracy: 0.8713 - val_binary_crossentropy: 0.7677
Epoch 9/20
25000/25000 - 2s - loss: 8.8607e-05 - accuracy: 1.0000 - binary_crossentropy: 8.8607e-05 - val_loss: 0.7881 - val_accuracy: 0.8710 - val_binary_crossentropy: 0.7881
Epoch 10/20
25000/25000 - 2s - loss: 6.9536e-05 - accuracy: 1.0000 - binary_crossentropy: 6.9536e-05 - val_loss: 0.8052 - val_accuracy: 0.8714 - val_binary_crossentropy: 0.8052
Epoch 11/20
25000/25000 - 2s - loss: 5.6150e-05 - accuracy: 1.0000 - binary_crossentropy: 5.6150e-05 - val_loss: 0.8203 - val_accuracy: 0.8714 - val_binary_crossentropy: 0.8203
Epoch 12/20
25000/25000 - 2s - loss: 4.6331e-05 - accuracy: 1.0000 - binary_crossentropy: 4.6331e-05 - val_loss: 0.8338 - val_accuracy: 0.8717 - val_binary_crossentropy: 0.8338
Epoch 13/20
25000/25000 - 2s - loss: 3.8876e-05 - accuracy: 1.0000 - binary_crossentropy: 3.8876e-05 - val_loss: 0.8465 - val_accuracy: 0.8719 - val_binary_crossentropy: 0.8465
Epoch 14/20
25000/25000 - 2s - loss: 3.3118e-05 - accuracy: 1.0000 - binary_crossentropy: 3.3118e-05 - val_loss: 0.8581 - val_accuracy: 0.8719 - val_binary_crossentropy: 0.8581
Epoch 15/20
25000/25000 - 2s - loss: 2.8496e-05 - accuracy: 1.0000 - binary_crossentropy: 2.8496e-05 - val_loss: 0.8687 - val_accuracy: 0.8721 - val_binary_crossentropy: 0.8687
Epoch 16/20
25000/25000 - 2s - loss: 2.4789e-05 - accuracy: 1.0000 - binary_crossentropy: 2.4789e-05 - val_loss: 0.8789 - val_accuracy: 0.8720 - val_binary_crossentropy: 0.8789
Epoch 17/20
25000/25000 - 2s - loss: 2.1728e-05 - accuracy: 1.0000 - binary_crossentropy: 2.1728e-05 - val_loss: 0.8880 - val_accuracy: 0.8724 - val_binary_crossentropy: 0.8880
Epoch 18/20
25000/25000 - 2s - loss: 1.9195e-05 - accuracy: 1.0000 - binary_crossentropy: 1.9195e-05 - val_loss: 0.8972 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.8972
Epoch 19/20
25000/25000 - 2s - loss: 1.7054e-05 - accuracy: 1.0000 - binary_crossentropy: 1.7054e-05 - val_loss: 0.9055 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9055
Epoch 20/20
25000/25000 - 2s - loss: 1.5252e-05 - accuracy: 1.0000 - binary_crossentropy: 1.5252e-05 - val_loss: 0.9139 - val_accuracy: 0.8722 - val_binary_crossentropy: 0.9139

訓練時と検証時の損失をグラフにする

実線は訓練用データセットの損失、破線は検証用データセットでの損失です(検証用データでの損失が小さい方が良いモデルです)。これをみると、小さいネットワークのほうが比較基準のモデルよりも過学習が始まるのが遅いことがわかります(4エポックではなく6エポック後)。また、過学習が始まっても性能の低下がよりゆっくりしています。

def plot_history(histories, key='binary_crossentropy'):
  plt.figure(figsize=(16,10))
    
  for name, history in histories:
    val = plt.plot(history.epoch, history.history['val_'+key],
                   '--', label=name.title()+' Val')
    plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
             label=name.title()+' Train')

  plt.xlabel('Epochs')
  plt.ylabel(key.replace('_',' ').title())
  plt.legend()

  plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
              ('smaller', smaller_history),
              ('bigger', bigger_history)])

png

より大きなネットワークでは、すぐに、1エポックで過学習が始まり、その度合も強いことに注目してください。ネットワークの容量が大きいほど訓練用データをモデル化するスピードが早くなり(結果として訓練時の損失値が小さくなり)ますが、より過学習しやすく(結果として訓練時の損失値と検証時の損失値が大きく乖離しやすく)なります。

過学習防止の戦略

重みの正則化を加える

「オッカムの剃刀」の原則をご存知でしょうか。何かの説明が2つあるとすると、最も正しいと考えられる説明は、仮定の数が最も少ない「一番単純な」説明だというものです。この原則は、ニューラルネットワークを使って学習されたモデルにも当てはまります。ある訓練用データとネットワーク構造があって、そのデータを説明できる重みの集合が複数ある時(つまり、複数のモデルがある時)、単純なモデルのほうが複雑なものよりも過学習しにくいのです。

ここで言う「単純なモデル」とは、パラメータ値の分布のエントロピーが小さいもの(あるいは、上記で見たように、そもそもパラメータの数が少ないもの)です。したがって、過学習を緩和するための一般的な手法は、重みが小さい値のみをとることで、重み値の分布がより整然となる(正則)様に制約を与えるものです。これを「重みの正則化」と呼ばれ、ネットワークの損失関数に、重みの大きさに関連するコストを加えることで行われます。このコストには2つの種類があります。

  • L1正則化 重み係数の絶対値に比例するコストを加える(重みの「L1ノルム」と呼ばれる)。

  • L2正則化 重み係数の二乗に比例するコストを加える(重み係数の二乗「L2ノルム」と呼ばれる)。L2正則化はニューラルネットワーク用語では重み減衰(Weight Decay)と呼ばれる。呼び方が違うので混乱しないように。重み減衰は数学的にはL2正則化と同義である。

L1正則化は重みパラメータの一部を0にすることでモデルを疎にする効果があります。L2正則化は重みパラメータにペナルティを加えますがモデルを疎にすることはありません。これは、L2正則化のほうが一般的である理由の一つです。

tf.kerasでは、重みの正則化をするために、重み正則化のインスタンスをキーワード引数として層に加えます。ここでは、L2正則化を追加してみましょう。

l2_model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])

l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)
Train on 25000 samples, validate on 25000 samples
Epoch 1/20
25000/25000 - 2s - loss: 0.5173 - accuracy: 0.8109 - binary_crossentropy: 0.4797 - val_loss: 0.3759 - val_accuracy: 0.8741 - val_binary_crossentropy: 0.3367
Epoch 2/20
25000/25000 - 2s - loss: 0.2998 - accuracy: 0.9075 - binary_crossentropy: 0.2562 - val_loss: 0.3316 - val_accuracy: 0.8882 - val_binary_crossentropy: 0.2846
Epoch 3/20
25000/25000 - 2s - loss: 0.2504 - accuracy: 0.9290 - binary_crossentropy: 0.2006 - val_loss: 0.3354 - val_accuracy: 0.8865 - val_binary_crossentropy: 0.2837
Epoch 4/20
25000/25000 - 2s - loss: 0.2262 - accuracy: 0.9404 - binary_crossentropy: 0.1725 - val_loss: 0.3502 - val_accuracy: 0.8815 - val_binary_crossentropy: 0.2952
Epoch 5/20
25000/25000 - 2s - loss: 0.2103 - accuracy: 0.9478 - binary_crossentropy: 0.1537 - val_loss: 0.3690 - val_accuracy: 0.8765 - val_binary_crossentropy: 0.3113
Epoch 6/20
25000/25000 - 2s - loss: 0.1998 - accuracy: 0.9522 - binary_crossentropy: 0.1406 - val_loss: 0.3772 - val_accuracy: 0.8769 - val_binary_crossentropy: 0.3168
Epoch 7/20
25000/25000 - 2s - loss: 0.1889 - accuracy: 0.9581 - binary_crossentropy: 0.1275 - val_loss: 0.4048 - val_accuracy: 0.8704 - val_binary_crossentropy: 0.3426
Epoch 8/20
25000/25000 - 2s - loss: 0.1797 - accuracy: 0.9613 - binary_crossentropy: 0.1164 - val_loss: 0.4049 - val_accuracy: 0.8736 - val_binary_crossentropy: 0.3409
Epoch 9/20
25000/25000 - 2s - loss: 0.1715 - accuracy: 0.9656 - binary_crossentropy: 0.1071 - val_loss: 0.4209 - val_accuracy: 0.8692 - val_binary_crossentropy: 0.3558
Epoch 10/20
25000/25000 - 2s - loss: 0.1646 - accuracy: 0.9707 - binary_crossentropy: 0.0990 - val_loss: 0.4330 - val_accuracy: 0.8685 - val_binary_crossentropy: 0.3669
Epoch 11/20
25000/25000 - 2s - loss: 0.1588 - accuracy: 0.9708 - binary_crossentropy: 0.0918 - val_loss: 0.4555 - val_accuracy: 0.8629 - val_binary_crossentropy: 0.3881
Epoch 12/20
25000/25000 - 2s - loss: 0.1505 - accuracy: 0.9772 - binary_crossentropy: 0.0827 - val_loss: 0.4572 - val_accuracy: 0.8658 - val_binary_crossentropy: 0.3891
Epoch 13/20
25000/25000 - 2s - loss: 0.1437 - accuracy: 0.9789 - binary_crossentropy: 0.0757 - val_loss: 0.4718 - val_accuracy: 0.8633 - val_binary_crossentropy: 0.4036
Epoch 14/20
25000/25000 - 2s - loss: 0.1384 - accuracy: 0.9812 - binary_crossentropy: 0.0701 - val_loss: 0.4812 - val_accuracy: 0.8636 - val_binary_crossentropy: 0.4128
Epoch 15/20
25000/25000 - 2s - loss: 0.1370 - accuracy: 0.9818 - binary_crossentropy: 0.0679 - val_loss: 0.4921 - val_accuracy: 0.8621 - val_binary_crossentropy: 0.4224
Epoch 16/20
25000/25000 - 2s - loss: 0.1392 - accuracy: 0.9799 - binary_crossentropy: 0.0687 - val_loss: 0.5116 - val_accuracy: 0.8610 - val_binary_crossentropy: 0.4402
Epoch 17/20
25000/25000 - 2s - loss: 0.1322 - accuracy: 0.9828 - binary_crossentropy: 0.0606 - val_loss: 0.5161 - val_accuracy: 0.8594 - val_binary_crossentropy: 0.4445
Epoch 18/20
25000/25000 - 2s - loss: 0.1233 - accuracy: 0.9874 - binary_crossentropy: 0.0524 - val_loss: 0.5213 - val_accuracy: 0.8610 - val_binary_crossentropy: 0.4507
Epoch 19/20
25000/25000 - 2s - loss: 0.1194 - accuracy: 0.9896 - binary_crossentropy: 0.0491 - val_loss: 0.5335 - val_accuracy: 0.8590 - val_binary_crossentropy: 0.4635
Epoch 20/20
25000/25000 - 2s - loss: 0.1148 - accuracy: 0.9908 - binary_crossentropy: 0.0452 - val_loss: 0.5394 - val_accuracy: 0.8607 - val_binary_crossentropy: 0.4701

l2(0.001)というのは、層の重み行列の係数全てに対して0.001 * 重み係数の値 **2をネットワークの損失値合計に加えることを意味します。このペナルティは訓練時のみに加えられるため、このネットワークの損失値は、訓練時にはテスト時に比べて大きくなることに注意してください。

L2正則化の影響を見てみましょう。

plot_history([('baseline', baseline_history),
              ('l2', l2_model_history)])

png

ご覧のように、L2正則化ありのモデルは比較基準のモデルに比べて過学習しにくくなっています。両方のモデルのパラメータ数は同じであるにもかかわらずです。

ドロップアウトを追加する

ドロップアウトは、ニューラルネットワークの正則化テクニックとして最もよく使われる手法の一つです。この手法は、トロント大学のヒントンと彼の学生が開発したものです。ドロップアウトは層に適用するもので、訓練時に層から出力された特徴量に対してランダムに「ドロップアウト(つまりゼロ化)」を行うものです。例えば、ある層が訓練時にある入力サンプルに対して、普通は[0.2, 0.5, 1.3, 0.8, 1.1] というベクトルを出力するとします。ドロップアウトを適用すると、このベクトルは例えば[0, 0.5, 1.3, 0, 1.1]のようにランダムに散らばったいくつかのゼロを含むようになります。「ドロップアウト率」はゼロ化される特徴の割合で、通常は0.2から0.5の間に設定します。テスト時は、どのユニットもドロップアウトされず、代わりに出力値がドロップアウト率と同じ比率でスケールダウンされます。これは、訓練時に比べてたくさんのユニットがアクティブであることに対してバランスをとるためです。

tf.kerasでは、Dropout層を使ってドロップアウトをネットワークに導入できます。ドロップアウト層は、その直前の層の出力に対してドロップアウトを適用します。

それでは、IMDBネットワークに2つのドロップアウト層を追加しましょう。

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='sigmoid')
])

dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])

dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Train on 25000 samples, validate on 25000 samples
Epoch 1/20
25000/25000 - 2s - loss: 0.6273 - accuracy: 0.6470 - binary_crossentropy: 0.6273 - val_loss: 0.4964 - val_accuracy: 0.8526 - val_binary_crossentropy: 0.4964
Epoch 2/20
25000/25000 - 2s - loss: 0.4786 - accuracy: 0.7974 - binary_crossentropy: 0.4786 - val_loss: 0.3612 - val_accuracy: 0.8795 - val_binary_crossentropy: 0.3612
Epoch 3/20
25000/25000 - 2s - loss: 0.3765 - accuracy: 0.8567 - binary_crossentropy: 0.3765 - val_loss: 0.3029 - val_accuracy: 0.8867 - val_binary_crossentropy: 0.3029
Epoch 4/20
25000/25000 - 2s - loss: 0.3101 - accuracy: 0.8911 - binary_crossentropy: 0.3101 - val_loss: 0.2785 - val_accuracy: 0.8895 - val_binary_crossentropy: 0.2785
Epoch 5/20
25000/25000 - 2s - loss: 0.2644 - accuracy: 0.9096 - binary_crossentropy: 0.2644 - val_loss: 0.2802 - val_accuracy: 0.8870 - val_binary_crossentropy: 0.2802
Epoch 6/20
25000/25000 - 2s - loss: 0.2267 - accuracy: 0.9235 - binary_crossentropy: 0.2267 - val_loss: 0.2841 - val_accuracy: 0.8882 - val_binary_crossentropy: 0.2841
Epoch 7/20
25000/25000 - 2s - loss: 0.2016 - accuracy: 0.9312 - binary_crossentropy: 0.2016 - val_loss: 0.2898 - val_accuracy: 0.8844 - val_binary_crossentropy: 0.2898
Epoch 8/20
25000/25000 - 2s - loss: 0.1792 - accuracy: 0.9419 - binary_crossentropy: 0.1792 - val_loss: 0.3032 - val_accuracy: 0.8853 - val_binary_crossentropy: 0.3032
Epoch 9/20
25000/25000 - 2s - loss: 0.1625 - accuracy: 0.9452 - binary_crossentropy: 0.1625 - val_loss: 0.3379 - val_accuracy: 0.8826 - val_binary_crossentropy: 0.3379
Epoch 10/20
25000/25000 - 2s - loss: 0.1481 - accuracy: 0.9505 - binary_crossentropy: 0.1481 - val_loss: 0.3473 - val_accuracy: 0.8821 - val_binary_crossentropy: 0.3473
Epoch 11/20
25000/25000 - 2s - loss: 0.1334 - accuracy: 0.9543 - binary_crossentropy: 0.1334 - val_loss: 0.3567 - val_accuracy: 0.8814 - val_binary_crossentropy: 0.3567
Epoch 12/20
25000/25000 - 2s - loss: 0.1222 - accuracy: 0.9587 - binary_crossentropy: 0.1222 - val_loss: 0.3868 - val_accuracy: 0.8798 - val_binary_crossentropy: 0.3868
Epoch 13/20
25000/25000 - 2s - loss: 0.1148 - accuracy: 0.9611 - binary_crossentropy: 0.1148 - val_loss: 0.4045 - val_accuracy: 0.8800 - val_binary_crossentropy: 0.4045
Epoch 14/20
25000/25000 - 2s - loss: 0.1059 - accuracy: 0.9642 - binary_crossentropy: 0.1059 - val_loss: 0.4294 - val_accuracy: 0.8784 - val_binary_crossentropy: 0.4294
Epoch 15/20
25000/25000 - 2s - loss: 0.0998 - accuracy: 0.9647 - binary_crossentropy: 0.0998 - val_loss: 0.4613 - val_accuracy: 0.8775 - val_binary_crossentropy: 0.4613
Epoch 16/20
25000/25000 - 2s - loss: 0.0940 - accuracy: 0.9665 - binary_crossentropy: 0.0940 - val_loss: 0.4511 - val_accuracy: 0.8766 - val_binary_crossentropy: 0.4511
Epoch 17/20
25000/25000 - 2s - loss: 0.0897 - accuracy: 0.9694 - binary_crossentropy: 0.0897 - val_loss: 0.4781 - val_accuracy: 0.8763 - val_binary_crossentropy: 0.4781
Epoch 18/20
25000/25000 - 2s - loss: 0.0807 - accuracy: 0.9717 - binary_crossentropy: 0.0807 - val_loss: 0.5248 - val_accuracy: 0.8756 - val_binary_crossentropy: 0.5248
Epoch 19/20
25000/25000 - 2s - loss: 0.0867 - accuracy: 0.9683 - binary_crossentropy: 0.0867 - val_loss: 0.5079 - val_accuracy: 0.8770 - val_binary_crossentropy: 0.5079
Epoch 20/20
25000/25000 - 2s - loss: 0.0798 - accuracy: 0.9703 - binary_crossentropy: 0.0798 - val_loss: 0.5265 - val_accuracy: 0.8752 - val_binary_crossentropy: 0.5265
plot_history([('baseline', baseline_history),
              ('dropout', dpt_model_history)])

png

ドロップアウトを追加することで、比較対象モデルより明らかに改善が見られます。

まとめ:ニューラルネットワークにおいて過学習を防ぐ最も一般的な方法は次のとおりです。

  • 訓練データを増やす
  • ネットワークの容量をへらす
  • 重みの正則化を行う
  • ドロップアウトを追加する

このガイドで触れていない2つの重要なアプローチがあります。データ拡張とバッチ正規化です。

#@title MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.