Explore overfit and underfit

Смотрите на TensorFlow.org Запустите в Google Colab Изучайте код на GitHub Скачайте ноутбук

Как всегда, код в этом примере будет использовать API tf.keras, о котором вы можете узнать больше в руководстве TensorFlow Keras.

В обоих предыдщих примерах с классификацией обзоров фильмов и предсказанием эффективности расхода топлива, мы увидели, что точность нашей модели на проверочных данных достигает пика после обучения за некоторое количество эпох, а затем начинает снижаться.

Другими словами, наша модель переобучилась на тренировочных данных. Важно научиться работать с переобученностью. Хотя часто возможно достичь высокой точности на обучающей выборке, на самом деле мы хотим построить модель которая хорошо обобщается на тестовой выборке (данных которые модель не видела ранее).

Обратным случаем переобучения является недообучение. Недообучение возникает когда еще есть возможность для улучшения модели на тестовых данных. Это может случитья по ряду причин: модель недостаточно сильная, с избыточной регуляризацией или просто недостаточно долго обучалась. Это значит, что сеть не изучила релевантные паттерны в обучающей выборке.

Если ты будешь обучать модель слишком долго, модель начнет переобучаться и настроится на паттерны тренировочных данных которые не обобщаются на тестовые данные. Нам нужно найти баланс. Понимание того, как обучать модель за подходящее количество эпох, как мы выясним ниже - очень полезный навык.

Лучшее решение для предотвращения переобученности - использовать больше тренировочных данных. Модель обученная на большем количестве данных естественным образом обобщает лучше. Когда это более невозможно, следующее решение - использовать техники наподобие регуляризации. Они накладывают ограничения на количество и тип информации которую ваша модель может хранить. Если нейросеть может запомнить только небольшое число паттернов, то процесс оптимизации заставит ее сфокусироваться на наиболее заметных паттернах, у которых более высокий шанс обобщения.

В этом уроке мы познакомимся с двумя распространенными техниками регуляризации - регуляризацией весов и исключением (dropout) и используем их для того, чтобы улучшить нашу модель классификации обзоров фильмов из IMDB.

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
2.3.0

Загрузите датасет IMDB

Вместо использования вложения (embedding) как в предыдущем уроке, здесь мы используем multi-hot encode предложений. Эта модель быстро переобучится на тренировочных данных. Мы посмотрим как произойдет переобучение и как его предотвратить.

Multi-hot-encoding наших списков означет их преобразование в вектора из 0 и 1. Конкретнее это значит что например последовательность [3, 5] преобразуется в 10 000-мерный вектор, который будет состоять полностью из нулей за исключением индексов 3 и 5 которые будут единицами.

NUM_WORDS = 10000

(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)

def multi_hot_sequences(sequences, dimension):
    # Создадим нулевую матрицу размерности (len(sequences), dimension)
    results = np.zeros((len(sequences), dimension))
    for i, word_indices in enumerate(sequences):
        results[i, word_indices] = 1.0  # приравняем требуемые индексы results[i] к 1
    return results


train_data = multi_hot_sequences(train_data, dimension=NUM_WORDS)
test_data = multi_hot_sequences(test_data, dimension=NUM_WORDS)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17465344/17464789 [==============================] - 0s 0us/step

Давайте посмотрим на один из получившихся multi-hot векторов. Индексы слов были отсортированы по частоте поэтому ожидаемо много значений 1 возле нулевого индекса, что мы и видим на этом графике:

plt.plot(train_data[0])
[<matplotlib.lines.Line2D at 0x7fe94b78f470>]

png

Продемонстрируем переобучение

Простейший способ предотвратить переобучение это сократить размер модели, т.е. количество обучаемых параметров модели (которые определяются числом слоев и элементов в каждом слое). В глубоком обучении количество обучаемых параметров модели часто называют "емкостью" модели. Интуитивно, модель с большим количеством параметров будет иметь большую "запоминающую емкость" и поэтому легко сможет выучить идеальный словарь - как отображение между обучающими примерами и их целевыми значениями, отображение безо всякой обобщающей силы. Но это будет бесполезно при прогнозировании на новых, ранее не виденных данных.

Всегда имейте это ввиду: модели глубокого обучения хорошо настраиваются на тренировочных данных, но настоящим вызовом является обобщение, не обучение.

С другой стороны, если нейросеть имеет ограниченные ресурсы ззапоминания, то она не сможет выучить отображение так легко. Для минимизации функции потерь модель вынуждена выучить только сжатые представления у которых больше предсказательной силы. В то же время, если вы сделаете вашу модель слишком маленькой, она с трудом подстроится под тренировочный сет. Существует баланс между "слишком большой емкостью" и "недостаточной емкостью".

К сожалению, не существует магической формулы, чтобы определить правильный размер или архитектуру модели, говоря о количестве слоев или размере каждого слоя. Вам необходимо поэкспериментировать с использованием разных архитектур модели.

Чтобы найди подходящий размер модели лучше начать с относительно небольшого количества слоев и параметров, затем начать увеличивать размер слоев или добавлять новые до тех пор, пока вы не увидите уменьшение отдачи функции ошибок на проверочных данных. Давай попробуем это на примере нашей сети для классификации обзоров фильмов.

Мы построим простую модель используя только Dense слои в качестве базовой, затем создадим меньшую и большую версии модели и сравним их.

Создайте базовую модель

baseline_model = keras.Sequential([
    # Параметр `input_shape` нужен только для того, чтобы заработал `.summary`.
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

baseline_model.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy', 'binary_crossentropy'])

baseline_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                160016    
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,305
Trainable params: 160,305
Non-trainable params: 0
_________________________________________________________________

baseline_history = baseline_model.fit(train_data,
                                      train_labels,
                                      epochs=20,
                                      batch_size=512,
                                      validation_data=(test_data, test_labels),
                                      verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.4687 - accuracy: 0.8131 - binary_crossentropy: 0.4687 - val_loss: 0.3265 - val_accuracy: 0.8761 - val_binary_crossentropy: 0.3265
Epoch 2/20
49/49 - 1s - loss: 0.2387 - accuracy: 0.9136 - binary_crossentropy: 0.2387 - val_loss: 0.2847 - val_accuracy: 0.8862 - val_binary_crossentropy: 0.2847
Epoch 3/20
49/49 - 1s - loss: 0.1758 - accuracy: 0.9379 - binary_crossentropy: 0.1758 - val_loss: 0.2928 - val_accuracy: 0.8838 - val_binary_crossentropy: 0.2928
Epoch 4/20
49/49 - 1s - loss: 0.1404 - accuracy: 0.9526 - binary_crossentropy: 0.1404 - val_loss: 0.3198 - val_accuracy: 0.8782 - val_binary_crossentropy: 0.3198
Epoch 5/20
49/49 - 1s - loss: 0.1164 - accuracy: 0.9615 - binary_crossentropy: 0.1164 - val_loss: 0.3491 - val_accuracy: 0.8734 - val_binary_crossentropy: 0.3491
Epoch 6/20
49/49 - 1s - loss: 0.0969 - accuracy: 0.9690 - binary_crossentropy: 0.0969 - val_loss: 0.3836 - val_accuracy: 0.8690 - val_binary_crossentropy: 0.3836
Epoch 7/20
49/49 - 1s - loss: 0.0820 - accuracy: 0.9754 - binary_crossentropy: 0.0820 - val_loss: 0.4256 - val_accuracy: 0.8647 - val_binary_crossentropy: 0.4256
Epoch 8/20
49/49 - 1s - loss: 0.0698 - accuracy: 0.9798 - binary_crossentropy: 0.0698 - val_loss: 0.4648 - val_accuracy: 0.8624 - val_binary_crossentropy: 0.4648
Epoch 9/20
49/49 - 1s - loss: 0.0580 - accuracy: 0.9852 - binary_crossentropy: 0.0580 - val_loss: 0.5067 - val_accuracy: 0.8594 - val_binary_crossentropy: 0.5067
Epoch 10/20
49/49 - 1s - loss: 0.0466 - accuracy: 0.9896 - binary_crossentropy: 0.0466 - val_loss: 0.5516 - val_accuracy: 0.8562 - val_binary_crossentropy: 0.5516
Epoch 11/20
49/49 - 1s - loss: 0.0383 - accuracy: 0.9926 - binary_crossentropy: 0.0383 - val_loss: 0.5991 - val_accuracy: 0.8541 - val_binary_crossentropy: 0.5991
Epoch 12/20
49/49 - 1s - loss: 0.0312 - accuracy: 0.9945 - binary_crossentropy: 0.0312 - val_loss: 0.6492 - val_accuracy: 0.8514 - val_binary_crossentropy: 0.6492
Epoch 13/20
49/49 - 1s - loss: 0.0254 - accuracy: 0.9960 - binary_crossentropy: 0.0254 - val_loss: 0.6900 - val_accuracy: 0.8519 - val_binary_crossentropy: 0.6900
Epoch 14/20
49/49 - 1s - loss: 0.0198 - accuracy: 0.9974 - binary_crossentropy: 0.0198 - val_loss: 0.7380 - val_accuracy: 0.8496 - val_binary_crossentropy: 0.7380
Epoch 15/20
49/49 - 1s - loss: 0.0147 - accuracy: 0.9988 - binary_crossentropy: 0.0147 - val_loss: 0.7816 - val_accuracy: 0.8493 - val_binary_crossentropy: 0.7816
Epoch 16/20
49/49 - 1s - loss: 0.0107 - accuracy: 0.9997 - binary_crossentropy: 0.0107 - val_loss: 0.8228 - val_accuracy: 0.8485 - val_binary_crossentropy: 0.8228
Epoch 17/20
49/49 - 1s - loss: 0.0080 - accuracy: 0.9998 - binary_crossentropy: 0.0080 - val_loss: 0.8581 - val_accuracy: 0.8483 - val_binary_crossentropy: 0.8581
Epoch 18/20
49/49 - 1s - loss: 0.0060 - accuracy: 0.9999 - binary_crossentropy: 0.0060 - val_loss: 0.8937 - val_accuracy: 0.8471 - val_binary_crossentropy: 0.8937
Epoch 19/20
49/49 - 1s - loss: 0.0047 - accuracy: 1.0000 - binary_crossentropy: 0.0047 - val_loss: 0.9230 - val_accuracy: 0.8470 - val_binary_crossentropy: 0.9230
Epoch 20/20
49/49 - 1s - loss: 0.0038 - accuracy: 1.0000 - binary_crossentropy: 0.0038 - val_loss: 0.9490 - val_accuracy: 0.8473 - val_binary_crossentropy: 0.9490

Создайте меньшую модель

Давайте построим модель с меньшим количеством скрытых нейронов чтобы сравнить ее с базовой моделью, которую мы только создали:

smaller_model = keras.Sequential([
    keras.layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(4, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

smaller_model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy', 'binary_crossentropy'])

smaller_model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4)                 40004     
_________________________________________________________________
dense_4 (Dense)              (None, 4)                 20        
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 5         
=================================================================
Total params: 40,029
Trainable params: 40,029
Non-trainable params: 0
_________________________________________________________________

И обучим модель используя те же данные:

smaller_history = smaller_model.fit(train_data,
                                    train_labels,
                                    epochs=20,
                                    batch_size=512,
                                    validation_data=(test_data, test_labels),
                                    verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.6131 - accuracy: 0.6892 - binary_crossentropy: 0.6131 - val_loss: 0.5560 - val_accuracy: 0.8004 - val_binary_crossentropy: 0.5560
Epoch 2/20
49/49 - 1s - loss: 0.5063 - accuracy: 0.8364 - binary_crossentropy: 0.5063 - val_loss: 0.5031 - val_accuracy: 0.8335 - val_binary_crossentropy: 0.5031
Epoch 3/20
49/49 - 1s - loss: 0.4535 - accuracy: 0.8838 - binary_crossentropy: 0.4535 - val_loss: 0.4749 - val_accuracy: 0.8613 - val_binary_crossentropy: 0.4749
Epoch 4/20
49/49 - 1s - loss: 0.4171 - accuracy: 0.9106 - binary_crossentropy: 0.4171 - val_loss: 0.4593 - val_accuracy: 0.8632 - val_binary_crossentropy: 0.4593
Epoch 5/20
49/49 - 1s - loss: 0.3888 - accuracy: 0.9263 - binary_crossentropy: 0.3888 - val_loss: 0.4452 - val_accuracy: 0.8760 - val_binary_crossentropy: 0.4452
Epoch 6/20
49/49 - 1s - loss: 0.3643 - accuracy: 0.9384 - binary_crossentropy: 0.3643 - val_loss: 0.4425 - val_accuracy: 0.8706 - val_binary_crossentropy: 0.4425
Epoch 7/20
49/49 - 1s - loss: 0.3429 - accuracy: 0.9492 - binary_crossentropy: 0.3429 - val_loss: 0.4337 - val_accuracy: 0.8759 - val_binary_crossentropy: 0.4337
Epoch 8/20
49/49 - 1s - loss: 0.3243 - accuracy: 0.9560 - binary_crossentropy: 0.3243 - val_loss: 0.4309 - val_accuracy: 0.8753 - val_binary_crossentropy: 0.4309
Epoch 9/20
49/49 - 1s - loss: 0.3079 - accuracy: 0.9616 - binary_crossentropy: 0.3079 - val_loss: 0.4278 - val_accuracy: 0.8753 - val_binary_crossentropy: 0.4278
Epoch 10/20
49/49 - 1s - loss: 0.2922 - accuracy: 0.9657 - binary_crossentropy: 0.2922 - val_loss: 0.4278 - val_accuracy: 0.8734 - val_binary_crossentropy: 0.4278
Epoch 11/20
49/49 - 1s - loss: 0.2774 - accuracy: 0.9710 - binary_crossentropy: 0.2774 - val_loss: 0.4328 - val_accuracy: 0.8714 - val_binary_crossentropy: 0.4328
Epoch 12/20
49/49 - 1s - loss: 0.2647 - accuracy: 0.9741 - binary_crossentropy: 0.2647 - val_loss: 0.4354 - val_accuracy: 0.8699 - val_binary_crossentropy: 0.4354
Epoch 13/20
49/49 - 1s - loss: 0.2526 - accuracy: 0.9770 - binary_crossentropy: 0.2526 - val_loss: 0.4358 - val_accuracy: 0.8694 - val_binary_crossentropy: 0.4358
Epoch 14/20
49/49 - 1s - loss: 0.2414 - accuracy: 0.9792 - binary_crossentropy: 0.2414 - val_loss: 0.4405 - val_accuracy: 0.8682 - val_binary_crossentropy: 0.4405
Epoch 15/20
49/49 - 1s - loss: 0.2308 - accuracy: 0.9812 - binary_crossentropy: 0.2308 - val_loss: 0.4460 - val_accuracy: 0.8673 - val_binary_crossentropy: 0.4460
Epoch 16/20
49/49 - 1s - loss: 0.2208 - accuracy: 0.9829 - binary_crossentropy: 0.2208 - val_loss: 0.4640 - val_accuracy: 0.8649 - val_binary_crossentropy: 0.4640
Epoch 17/20
49/49 - 1s - loss: 0.2126 - accuracy: 0.9836 - binary_crossentropy: 0.2126 - val_loss: 0.4609 - val_accuracy: 0.8650 - val_binary_crossentropy: 0.4609
Epoch 18/20
49/49 - 1s - loss: 0.2041 - accuracy: 0.9850 - binary_crossentropy: 0.2041 - val_loss: 0.4643 - val_accuracy: 0.8648 - val_binary_crossentropy: 0.4643
Epoch 19/20
49/49 - 1s - loss: 0.1958 - accuracy: 0.9860 - binary_crossentropy: 0.1958 - val_loss: 0.4695 - val_accuracy: 0.8643 - val_binary_crossentropy: 0.4695
Epoch 20/20
49/49 - 1s - loss: 0.1887 - accuracy: 0.9869 - binary_crossentropy: 0.1887 - val_loss: 0.4812 - val_accuracy: 0.8634 - val_binary_crossentropy: 0.4812

Создайте большую модель

В качестве упражнения вы можете построить еще большую модель и увидеть как быстро она начнет переобучаться. Далее давайте сравним с эталоном нейросеть которая имеет намного большую емкость чем того требует задача:

bigger_model = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

bigger_model.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     metrics=['accuracy','binary_crossentropy'])

bigger_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (None, 512)               5120512   
_________________________________________________________________
dense_7 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_8 (Dense)              (None, 1)                 513       
=================================================================
Total params: 5,383,681
Trainable params: 5,383,681
Non-trainable params: 0
_________________________________________________________________

И опять обучим модель используя те же данные:

bigger_history = bigger_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.3455 - accuracy: 0.8468 - binary_crossentropy: 0.3455 - val_loss: 0.2857 - val_accuracy: 0.8851 - val_binary_crossentropy: 0.2857
Epoch 2/20
49/49 - 1s - loss: 0.1334 - accuracy: 0.9531 - binary_crossentropy: 0.1334 - val_loss: 0.3244 - val_accuracy: 0.8740 - val_binary_crossentropy: 0.3244
Epoch 3/20
49/49 - 1s - loss: 0.0359 - accuracy: 0.9908 - binary_crossentropy: 0.0359 - val_loss: 0.4516 - val_accuracy: 0.8719 - val_binary_crossentropy: 0.4516
Epoch 4/20
49/49 - 1s - loss: 0.0043 - accuracy: 0.9995 - binary_crossentropy: 0.0043 - val_loss: 0.6226 - val_accuracy: 0.8712 - val_binary_crossentropy: 0.6226
Epoch 5/20
49/49 - 1s - loss: 5.5439e-04 - accuracy: 1.0000 - binary_crossentropy: 5.5439e-04 - val_loss: 0.6803 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.6803
Epoch 6/20
49/49 - 1s - loss: 1.7775e-04 - accuracy: 1.0000 - binary_crossentropy: 1.7775e-04 - val_loss: 0.7234 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.7234
Epoch 7/20
49/49 - 1s - loss: 1.1010e-04 - accuracy: 1.0000 - binary_crossentropy: 1.1010e-04 - val_loss: 0.7514 - val_accuracy: 0.8735 - val_binary_crossentropy: 0.7514
Epoch 8/20
49/49 - 1s - loss: 7.9917e-05 - accuracy: 1.0000 - binary_crossentropy: 7.9917e-05 - val_loss: 0.7748 - val_accuracy: 0.8730 - val_binary_crossentropy: 0.7748
Epoch 9/20
49/49 - 1s - loss: 6.1641e-05 - accuracy: 1.0000 - binary_crossentropy: 6.1641e-05 - val_loss: 0.7924 - val_accuracy: 0.8729 - val_binary_crossentropy: 0.7924
Epoch 10/20
49/49 - 1s - loss: 4.9283e-05 - accuracy: 1.0000 - binary_crossentropy: 4.9283e-05 - val_loss: 0.8091 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.8091
Epoch 11/20
49/49 - 1s - loss: 4.0419e-05 - accuracy: 1.0000 - binary_crossentropy: 4.0419e-05 - val_loss: 0.8229 - val_accuracy: 0.8732 - val_binary_crossentropy: 0.8229
Epoch 12/20
49/49 - 1s - loss: 3.3850e-05 - accuracy: 1.0000 - binary_crossentropy: 3.3850e-05 - val_loss: 0.8360 - val_accuracy: 0.8729 - val_binary_crossentropy: 0.8360
Epoch 13/20
49/49 - 1s - loss: 2.8720e-05 - accuracy: 1.0000 - binary_crossentropy: 2.8720e-05 - val_loss: 0.8487 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.8487
Epoch 14/20
49/49 - 1s - loss: 2.4650e-05 - accuracy: 1.0000 - binary_crossentropy: 2.4650e-05 - val_loss: 0.8589 - val_accuracy: 0.8729 - val_binary_crossentropy: 0.8589
Epoch 15/20
49/49 - 1s - loss: 2.1409e-05 - accuracy: 1.0000 - binary_crossentropy: 2.1409e-05 - val_loss: 0.8693 - val_accuracy: 0.8729 - val_binary_crossentropy: 0.8693
Epoch 16/20
49/49 - 1s - loss: 1.8755e-05 - accuracy: 1.0000 - binary_crossentropy: 1.8755e-05 - val_loss: 0.8792 - val_accuracy: 0.8727 - val_binary_crossentropy: 0.8792
Epoch 17/20
49/49 - 1s - loss: 1.6545e-05 - accuracy: 1.0000 - binary_crossentropy: 1.6545e-05 - val_loss: 0.8881 - val_accuracy: 0.8726 - val_binary_crossentropy: 0.8881
Epoch 18/20
49/49 - 1s - loss: 1.4707e-05 - accuracy: 1.0000 - binary_crossentropy: 1.4707e-05 - val_loss: 0.8966 - val_accuracy: 0.8727 - val_binary_crossentropy: 0.8966
Epoch 19/20
49/49 - 1s - loss: 1.3137e-05 - accuracy: 1.0000 - binary_crossentropy: 1.3137e-05 - val_loss: 0.9054 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.9054
Epoch 20/20
49/49 - 1s - loss: 1.1812e-05 - accuracy: 1.0000 - binary_crossentropy: 1.1812e-05 - val_loss: 0.9127 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.9127

Постройте графики потерь на тренировочных и проверочных данных

Непрерывные линии показывают потери во время обучения, а прерывистые - во время проверки (помни - меньшие потери на проверочных данных указывают на лучшую модель). В нашем случае самая маленькая модель начинает переобучаться позже, чем основная (после 6 эпох вместо 4) и ее показатели ухудшаются гораздо медленее после переобучения.

def plot_history(histories, key='binary_crossentropy'):
  plt.figure(figsize=(16,10))

  for name, history in histories:
    val = plt.plot(history.epoch, history.history['val_'+key],
                   '--', label=name.title()+' Val')
    plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
             label=name.title()+' Train')

  plt.xlabel('Epochs')
  plt.ylabel(key.replace('_',' ').title())
  plt.legend()

  plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
              ('smaller', smaller_history),
              ('bigger', bigger_history)])

png

Обратите внимание, что большая сеть начинает переобучаться почти сразу же после первой эпохи, и переобучение происходит гораздо быстрее. Чем больше емкость модели, тем легче она смоделирует тренировочные данные (и мы получим низкое значение потерь на тренировочных данных). Но в таком случае она будет более чувствительна к переобучению: разница потерь между обучением и проверкой будет очень велика.

Стратегии предотвращения переобучения

Добавить регуляризацию весов

Вам может быть знаком принцип бритвы Оккама: из двух толкований некоторого явления, правильным скорее всего является самое "простое" - то, которое содержит меньше всего предположений. Этот принцип также применим к моделям, обучемым при помощи нейронных сетей: если у наших данных и сетевой архитектуры существует несколько наборов значений весов (несколько моделей) которые могут объяснить данные и более простые модели переобучаются реже, чем сложные.

В этом контексте "простая модель" это модель в которой распределение значений параметров имеет меньшую энтропию (или модель с меньшим количеством параметров, как та которую мы строили выше). Таким образом, для предотвращение переобучения часто используется ограничение сложности сети путем принуждения ее коэфицентов принимать только небольшие значения, что делает распределение весов более "регулярным". Этот метод называется "регуляризация весов": к функции потерь нашей сети мы добавляем штраф (или cost, стоимость) за использование больших весов. Регуляризация бывает двух видов:

  • L1 регуляризация, где добавляемый штраф пропорционален абсолютным значениям коэффициентов весов (т.е. то что называется "L1 нормой" весов).

  • L2 регуляризация, где добавляемый штраф пропорционален квадрату значений коэффициентов весов (т.е. то, что называется квадратом "L2 нормы" весов). L2 регуляризацию также называют сокращением весов в контексте нейросетей. Не дайте разным названиям запутать себя: сокращение весов математически ровно то же самое что и L2 регуляризация.

L1 регуляризация вводит разреженность обнуляя некотороые из ваших весовых параметров. L2 регуляризация оштрафует весовые параметры не делая их разреженными - это одна из причин, почему L2 более распространена.

В tf.keras регуляризация весов добавляется передачей экземпляров регуляризатора слоям в качестве аргумента. Добавим сейчас L2 регуляризатор весов.

l2_model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])

l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.5141 - accuracy: 0.8159 - binary_crossentropy: 0.4718 - val_loss: 0.3744 - val_accuracy: 0.8771 - val_binary_crossentropy: 0.3308
Epoch 2/20
49/49 - 1s - loss: 0.2977 - accuracy: 0.9101 - binary_crossentropy: 0.2506 - val_loss: 0.3356 - val_accuracy: 0.8854 - val_binary_crossentropy: 0.2859
Epoch 3/20
49/49 - 1s - loss: 0.2470 - accuracy: 0.9324 - binary_crossentropy: 0.1952 - val_loss: 0.3377 - val_accuracy: 0.8856 - val_binary_crossentropy: 0.2842
Epoch 4/20
49/49 - 1s - loss: 0.2270 - accuracy: 0.9402 - binary_crossentropy: 0.1718 - val_loss: 0.3507 - val_accuracy: 0.8814 - val_binary_crossentropy: 0.2941
Epoch 5/20
49/49 - 1s - loss: 0.2117 - accuracy: 0.9473 - binary_crossentropy: 0.1539 - val_loss: 0.3755 - val_accuracy: 0.8758 - val_binary_crossentropy: 0.3168
Epoch 6/20
49/49 - 1s - loss: 0.2003 - accuracy: 0.9515 - binary_crossentropy: 0.1407 - val_loss: 0.3802 - val_accuracy: 0.8759 - val_binary_crossentropy: 0.3200
Epoch 7/20
49/49 - 1s - loss: 0.1906 - accuracy: 0.9582 - binary_crossentropy: 0.1298 - val_loss: 0.3940 - val_accuracy: 0.8728 - val_binary_crossentropy: 0.3330
Epoch 8/20
49/49 - 1s - loss: 0.1846 - accuracy: 0.9593 - binary_crossentropy: 0.1231 - val_loss: 0.4068 - val_accuracy: 0.8706 - val_binary_crossentropy: 0.3447
Epoch 9/20
49/49 - 1s - loss: 0.1787 - accuracy: 0.9628 - binary_crossentropy: 0.1160 - val_loss: 0.4221 - val_accuracy: 0.8691 - val_binary_crossentropy: 0.3591
Epoch 10/20
49/49 - 1s - loss: 0.1742 - accuracy: 0.9645 - binary_crossentropy: 0.1107 - val_loss: 0.4351 - val_accuracy: 0.8671 - val_binary_crossentropy: 0.3712
Epoch 11/20
49/49 - 1s - loss: 0.1714 - accuracy: 0.9636 - binary_crossentropy: 0.1072 - val_loss: 0.4531 - val_accuracy: 0.8647 - val_binary_crossentropy: 0.3885
Epoch 12/20
49/49 - 1s - loss: 0.1696 - accuracy: 0.9646 - binary_crossentropy: 0.1044 - val_loss: 0.4656 - val_accuracy: 0.8629 - val_binary_crossentropy: 0.4001
Epoch 13/20
49/49 - 1s - loss: 0.1658 - accuracy: 0.9664 - binary_crossentropy: 0.0996 - val_loss: 0.4756 - val_accuracy: 0.8618 - val_binary_crossentropy: 0.4091
Epoch 14/20
49/49 - 1s - loss: 0.1591 - accuracy: 0.9704 - binary_crossentropy: 0.0924 - val_loss: 0.4876 - val_accuracy: 0.8604 - val_binary_crossentropy: 0.4210
Epoch 15/20
49/49 - 1s - loss: 0.1566 - accuracy: 0.9708 - binary_crossentropy: 0.0899 - val_loss: 0.5046 - val_accuracy: 0.8586 - val_binary_crossentropy: 0.4379
Epoch 16/20
49/49 - 1s - loss: 0.1611 - accuracy: 0.9662 - binary_crossentropy: 0.0935 - val_loss: 0.5163 - val_accuracy: 0.8581 - val_binary_crossentropy: 0.4482
Epoch 17/20
49/49 - 1s - loss: 0.1562 - accuracy: 0.9702 - binary_crossentropy: 0.0878 - val_loss: 0.5324 - val_accuracy: 0.8554 - val_binary_crossentropy: 0.4640
Epoch 18/20
49/49 - 1s - loss: 0.1483 - accuracy: 0.9752 - binary_crossentropy: 0.0799 - val_loss: 0.5420 - val_accuracy: 0.8555 - val_binary_crossentropy: 0.4739
Epoch 19/20
49/49 - 1s - loss: 0.1445 - accuracy: 0.9755 - binary_crossentropy: 0.0766 - val_loss: 0.5488 - val_accuracy: 0.8554 - val_binary_crossentropy: 0.4808
Epoch 20/20
49/49 - 1s - loss: 0.1447 - accuracy: 0.9746 - binary_crossentropy: 0.0765 - val_loss: 0.5629 - val_accuracy: 0.8539 - val_binary_crossentropy: 0.4945

l2(0.001) значит что каждый коэффициент в матрице весов слоя добавит 0.001 * weight_coefficient_value**2 к значению потерь нейросети. Заметьте, что поскольку этот штраф доавляется только во время обучения, потери сети во время этой стадии будут гораздо выше чем во время теста.

Так выглядит влияние регуляризации L2:

plot_history([('baseline', baseline_history),
              ('l2', l2_model_history)])

png

Как вы можете видеть L2 регуляризованная модель стала более устойчивой к переобучению чем базовая модель, несмотря на то что обе модели имеют одинаковое количество параметров.

Добавьте дропаут

Дропаут(исключение) одна из наиболее эффективных и часто используемых техник регуляризации нейросетей разработанная Джеффом Хинтоном и его студентами в университете Торонто. Примененный к слою Dropout состоит из некоторого количества случайно "исключенных" (т.е. приравненных к нулю) во время обучения выходных параметров слоя. Допустим что наш слой возвращает вектор [0.2, 0.5, 1.3, 0.8, 1.1] для некоторых входных данных при обучении; после применения дропаута, в этом векторе появится несколько нулевых значений распределенных случайным образом, например [0, 0.5, 1.3, 0, 1.1]. "Коэффициент дропаута(dropout rate)" это доля признаков которые будут обнулены; его обычно устанавливают между 0.2 и 0.5. Во время теста дропаут не используется, вместо этого выходные данные слоев масштабируются на коэффициент равный коэффициенту дропаута, чтобы сбалансировать тот факт, что во время проверки активно больше нейронов чем во время обучения.

В tf.keras можно ввести дропаут с помощью слоя Dropout, который применяется к выходным данным предыдущего слоя.

Давайте применим два слоя Dropout к нашей нейросети IMDB и посмотрим насколько хорошо она сократит переобучение:

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='sigmoid')
])

dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])

dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)
Epoch 1/20
49/49 - 2s - loss: 0.6427 - accuracy: 0.6218 - binary_crossentropy: 0.6427 - val_loss: 0.5146 - val_accuracy: 0.8448 - val_binary_crossentropy: 0.5146
Epoch 2/20
49/49 - 1s - loss: 0.4788 - accuracy: 0.8002 - binary_crossentropy: 0.4788 - val_loss: 0.3634 - val_accuracy: 0.8768 - val_binary_crossentropy: 0.3634
Epoch 3/20
49/49 - 1s - loss: 0.3776 - accuracy: 0.8600 - binary_crossentropy: 0.3776 - val_loss: 0.3012 - val_accuracy: 0.8854 - val_binary_crossentropy: 0.3012
Epoch 4/20
49/49 - 1s - loss: 0.3134 - accuracy: 0.8920 - binary_crossentropy: 0.3134 - val_loss: 0.2804 - val_accuracy: 0.8885 - val_binary_crossentropy: 0.2804
Epoch 5/20
49/49 - 1s - loss: 0.2637 - accuracy: 0.9105 - binary_crossentropy: 0.2637 - val_loss: 0.2810 - val_accuracy: 0.8880 - val_binary_crossentropy: 0.2810
Epoch 6/20
49/49 - 1s - loss: 0.2354 - accuracy: 0.9242 - binary_crossentropy: 0.2354 - val_loss: 0.2875 - val_accuracy: 0.8881 - val_binary_crossentropy: 0.2875
Epoch 7/20
49/49 - 1s - loss: 0.2066 - accuracy: 0.9345 - binary_crossentropy: 0.2066 - val_loss: 0.2944 - val_accuracy: 0.8873 - val_binary_crossentropy: 0.2944
Epoch 8/20
49/49 - 1s - loss: 0.1853 - accuracy: 0.9422 - binary_crossentropy: 0.1853 - val_loss: 0.3257 - val_accuracy: 0.8828 - val_binary_crossentropy: 0.3257
Epoch 9/20
49/49 - 1s - loss: 0.1671 - accuracy: 0.9461 - binary_crossentropy: 0.1671 - val_loss: 0.3306 - val_accuracy: 0.8828 - val_binary_crossentropy: 0.3306
Epoch 10/20
49/49 - 1s - loss: 0.1523 - accuracy: 0.9524 - binary_crossentropy: 0.1523 - val_loss: 0.3544 - val_accuracy: 0.8820 - val_binary_crossentropy: 0.3544
Epoch 11/20
49/49 - 1s - loss: 0.1422 - accuracy: 0.9552 - binary_crossentropy: 0.1422 - val_loss: 0.3732 - val_accuracy: 0.8809 - val_binary_crossentropy: 0.3732
Epoch 12/20
49/49 - 1s - loss: 0.1317 - accuracy: 0.9599 - binary_crossentropy: 0.1317 - val_loss: 0.3820 - val_accuracy: 0.8790 - val_binary_crossentropy: 0.3820
Epoch 13/20
49/49 - 1s - loss: 0.1198 - accuracy: 0.9644 - binary_crossentropy: 0.1198 - val_loss: 0.4069 - val_accuracy: 0.8788 - val_binary_crossentropy: 0.4069
Epoch 14/20
49/49 - 1s - loss: 0.1092 - accuracy: 0.9664 - binary_crossentropy: 0.1092 - val_loss: 0.4402 - val_accuracy: 0.8782 - val_binary_crossentropy: 0.4402
Epoch 15/20
49/49 - 1s - loss: 0.1040 - accuracy: 0.9685 - binary_crossentropy: 0.1040 - val_loss: 0.4378 - val_accuracy: 0.8770 - val_binary_crossentropy: 0.4378
Epoch 16/20
49/49 - 1s - loss: 0.1001 - accuracy: 0.9691 - binary_crossentropy: 0.1001 - val_loss: 0.4651 - val_accuracy: 0.8754 - val_binary_crossentropy: 0.4651
Epoch 17/20
49/49 - 1s - loss: 0.0926 - accuracy: 0.9720 - binary_crossentropy: 0.0926 - val_loss: 0.5046 - val_accuracy: 0.8746 - val_binary_crossentropy: 0.5046
Epoch 18/20
49/49 - 1s - loss: 0.0903 - accuracy: 0.9719 - binary_crossentropy: 0.0903 - val_loss: 0.5174 - val_accuracy: 0.8754 - val_binary_crossentropy: 0.5174
Epoch 19/20
49/49 - 1s - loss: 0.0827 - accuracy: 0.9744 - binary_crossentropy: 0.0827 - val_loss: 0.5312 - val_accuracy: 0.8754 - val_binary_crossentropy: 0.5312
Epoch 20/20
49/49 - 1s - loss: 0.0866 - accuracy: 0.9737 - binary_crossentropy: 0.0866 - val_loss: 0.5330 - val_accuracy: 0.8750 - val_binary_crossentropy: 0.5330

plot_history([('baseline', baseline_history),
              ('dropout', dpt_model_history)])

png

Добавление дропаута явно улучает базовую модель.

Подведем итоги - вот самые основные способы предотвращения переобучения нейросетей:

  • Использовать больше данных для обучения.
  • Уменьшить емкость сети.
  • Использовать регуляризацию весов.
  • Добавить дропаут.

Два важных подхода которые не были рассмотрены в данном руководстве это аугментация данных (data-augmentation) и батч-нормализация (batch normalization).

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.