Кастомные колбеки в Keras

Смотрите на TensorFlow.org Запустите в Google Colab Изучайте код на GitHub Скачайте ноутбук

Кастомный колбек это мощный инструмент для настройки поведения модели Keras во время обучения, оценки или вывода, включая чтение/изменение модели Keras. Примеры включают tf.keras.callbacks.TensorBoard где процесс обучения и результаты могут быть экспортированы и визуализированы в TensorBoard, или tf.keras.callbacks.ModelCheckpoint где модель автоматически сохраняется во время обучения, и т.д. В этом руководстве вы узнаете, что такое колбек Keras, когда он будет вызван, что он может делать, и как вы можете построить свой колбек. Ближе к концу руководства будет несколько демонстраций создания пары простых колбек-приложений, чтобы помочь вам начать делать собственый колбек.

Установка

from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version существует только в Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

Введение в колбеки Keras

В Keras, Callback это класс python предназначенный для субклассирования и обеспечивающий определенную функциональность, с набором методов, вызываемых на различных этапах обучения (включая начало и конец пакета/эпохи), тестирования и прогнозирования. Колбеки полезны для того, чтобы получить представлениее о внутренних состояниях и статистике модели во время обучения. Вы можете передать список колбеков (в качестве ключевого слова аргумента callbacks) любому из методов tf.keras.Model.fit(), tf.keras.Model.evaluate() и tf.keras.Model.predict(). Методы колбеков будут вызываться на разных этапах обучения/оценки/вывода.

Чтобы начать давайте импортируем tensorflow и определим простую Sequential модель Keras:

# Определим модель Keras чтобы добавить в нее колбеки
def get_model():
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))
  model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
  return model

Затем, загрузим данные MNIST из Keras datasets API для обучения и тестирования:

# Загрузим данные MNIST и предобработаем их
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

Сейчас определим простой колбек чтобы отслеживать начало и конец каждого пакета данных. Во время этих вызовов, он будет печатать индекс текущего пакета.

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

Предоставленние колбека для методов модели, таких как tf.keras.Model.fit() гарантирует, что методы вызываются на тех этапах:

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[MyCustomCallback()])
Training: batch 0 begins at 18:20:05.144712
Training: batch 0 ends at 18:20:05.614571
Training: batch 1 begins at 18:20:05.614747
Training: batch 1 ends at 18:20:05.616564
Training: batch 2 begins at 18:20:05.616677
Training: batch 2 ends at 18:20:05.618441
Training: batch 3 begins at 18:20:05.618563
Training: batch 3 ends at 18:20:05.620313
Training: batch 4 begins at 18:20:05.620407
Training: batch 4 ends at 18:20:05.622052

Методы Model работающие с колбеками

Пользователи могут добавлять список колбеков к следующим методам tf.keras.Model:

fit(), fit_generator()

Обучает модель за фиксированное количество эпох (итерации по датасету, или данные полученные попакетно с помощью генератора Python).

evaluate(), evaluate_generator()

Оценивает модель для имеющихся данных или генератора данных. Выводит значения потерь и метрик во время оценки.

predict(), predict_generator()

Генерирует предсказания для входных данных или генератора данных.

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
          callbacks=[MyCustomCallback()])
Evaluating: batch 0 begins at 18:20:05.676864
Evaluating: batch 0 ends at 18:20:05.746044
Evaluating: batch 1 begins at 18:20:05.746180
Evaluating: batch 1 ends at 18:20:05.747697
Evaluating: batch 2 begins at 18:20:05.747787
Evaluating: batch 2 ends at 18:20:05.749059
Evaluating: batch 3 begins at 18:20:05.749139
Evaluating: batch 3 ends at 18:20:05.750444
Evaluating: batch 4 begins at 18:20:05.750527
Evaluating: batch 4 ends at 18:20:05.751895

Обзор методов колбеков

Общие методы для обучения/тестирования/предсказания

Для обучения, тестирования и предсказания предоставляются следующие методы для переопределения.

on_(train|test|predict)_begin(self, logs=None)

Вызывается в начале fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Вызывается в конце fit/evaluate/predict.

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Вызывается непосредственно перед обработкой пакета во время обучения/тестирования/предсказания. С этим методом, logs это словарь с ключами batch и size, представляющие номер текущего пакета и размер пакета.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Вызывается в конце обучения/тестирования/предсказания на пакете. С этим методом, logs это словарь содержащий результат метрик с состоянием.

Специфичные для обучения методы

В дополнение, для обучения предоставляются следующие методы.

on_epoch_begin(self, epoch, logs=None)

Вызывается в начале эпохи во время обучения.

on_epoch_end(self, epoch, logs=None)

Вызывается в конце эпохи во время обучения.

Использование словаря logs

Словарь logs содержит значение потерь и все метрики в конце пакета пакета или эпохи. Пример включает потери и среднеквадратичную ошибку.

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

  def on_train_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_test_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_epoch_end(self, epoch, logs=None):
    print('Средние потери за эпоху {} равны {:7.2f}, а среднеквадратичная ошибка равна {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=3,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is   22.80.
For batch 1, loss is  522.04.
For batch 2, loss is  356.54.
For batch 3, loss is  269.76.
For batch 4, loss is  217.35.
Средние потери за эпоху 0 равны  217.35, а среднеквадратичная ошибка равна    8.57.
For batch 0, loss is    5.76.
For batch 1, loss is    6.34.
For batch 2, loss is    6.51.
For batch 3, loss is    6.25.
For batch 4, loss is    5.97.
Средние потери за эпоху 1 равны    5.97, а среднеквадратичная ошибка равна    2.07.
For batch 0, loss is    6.07.
For batch 1, loss is    5.37.
For batch 2, loss is    4.98.
For batch 3, loss is    5.20.
For batch 4, loss is    4.94.
Средние потери за эпоху 2 равны    4.94, а среднеквадратичная ошибка равна    1.81.

Аналогично, можно обеспечить колбеки в вызовах evaluate().

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20,
          callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is    5.08.
For batch 1, loss is    4.46.
For batch 2, loss is    4.64.
For batch 3, loss is    4.61.
For batch 4, loss is    4.82.
For batch 5, loss is    4.70.
For batch 6, loss is    4.70.
For batch 7, loss is    4.63.
For batch 8, loss is    4.65.
For batch 9, loss is    4.80.
For batch 10, loss is    4.88.
For batch 11, loss is    4.95.
For batch 12, loss is    5.03.
For batch 13, loss is    5.19.
For batch 14, loss is    5.16.
For batch 15, loss is    5.15.
For batch 16, loss is    5.20.
For batch 17, loss is    5.23.
For batch 18, loss is    5.30.
For batch 19, loss is    5.28.

Примеры колбек-приложений Keras

Следующий раздел поможет вам в создании простых Callback приложений.

Ранняя остановка при минимальных потерях

Первый пример демонстрирует создание Callback который останавливает обучение Keras когда достигнут минимум потерь путем изменения аргумента model.stop_training (булево значение). Опционально, пользователь может использовать аргумент patience чтобы указать сколько эпох еще обучаться, перед остановкой.

tf.keras.callbacks.EarlyStopping обеспечиваает более полную и общую реализацию.

import numpy as np

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
  """Остановить обучение, когда loss на минимуме, т.е. loss прекратил уменьшаться.

  Аргументы:
      patience: Количество эпох ожидания после достижения минимума. Если столько
      эпох нет улучшения, обучение останавливается.
  """

  def __init__(self, patience=0):
    super(EarlyStoppingAtMinLoss, self).__init__()

    self.patience = patience

    # best_weights для хранения весов на которых достигнут минимум потерь.
    self.best_weights = None

  def on_train_begin(self, logs=None):
    # Количество эпох за время которых потери не уменьшаются.
    self.wait = 0
    # Эпоха на которой остановилось обучение.
    self.stopped_epoch = 0
    # Инициализация лучшего значения равным бесконечности.
    self.best = np.Inf

  def on_epoch_end(self, epoch, logs=None):
    current = logs.get('loss')
    if np.less(current, self.best):
      self.best = current
      self.wait = 0
      # Записать лучшие веса если текущие результаты лучше (меньше).
      self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        print('Восстановление весов моедли с конца лучшей эпохи.')
        self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
      print('Эпоха %05d: ранняя остановка' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=30,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
For batch 0, loss is   32.99.
For batch 1, loss is  438.37.
For batch 2, loss is  300.09.
For batch 3, loss is  228.12.
For batch 4, loss is  183.96.
Средние потери за эпоху 0 равны  183.96, а среднеквадратичная ошибка равна    8.20.
For batch 0, loss is    6.44.
For batch 1, loss is    6.46.
For batch 2, loss is    6.05.
For batch 3, loss is    5.64.
For batch 4, loss is    5.60.
Средние потери за эпоху 1 равны    5.60, а среднеквадратичная ошибка равна    1.93.
For batch 0, loss is    4.03.
For batch 1, loss is    3.92.
For batch 2, loss is    4.34.
For batch 3, loss is    4.51.
For batch 4, loss is    4.93.
Средние потери за эпоху 2 равны    4.93, а среднеквадратичная ошибка равна    1.80.
For batch 0, loss is    4.99.
For batch 1, loss is    5.38.
For batch 2, loss is    5.67.
For batch 3, loss is    6.36.
For batch 4, loss is    8.70.
Средние потери за эпоху 3 равны    8.70, а среднеквадратичная ошибка равна    2.35.
Восстановление весов моедли с конца лучшей эпохи.
Эпоха 00004: ранняя остановка

Планирование скорости обучения

Одна вещь которую обычно делают при обучении модели, это изменение скорости обучения по мере того как проходит больше эпох.

Замечание: это лишь реализация для примера см. callbacks.LearningRateScheduler и keras.optimizers.schedules для более общей реализации.

class LearningRateScheduler(tf.keras.callbacks.Callback):
  """Планировщик скорости обучения, устанавливающий скорость в соответствии с расписанием.

  Аргументы:
      schedule: функция которая получает на вход индекс эпохи
          (целое число, индексируемое с нулЯ 0) и текущую скорость обучения
          и возвращая новую скорость обучения на выходе (float).
  """

  def __init__(self, schedule):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    # Получаем текущую скорость обучения от оптимизатора модели.
    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
    # Вызываем функцию расписания, чтобы получить запланированную скорость обучения.
    scheduled_lr = self.schedule(epoch, lr)
    # Установим значение обратно в оптимизатор до начала этой эпохи
    tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
    print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
  """Вспомогательная функция для получения запланированной скорости обучения на основе порядкового номера эпохи."""
  if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
    return lr
  for i in range(len(LR_SCHEDULE)):
    if epoch == LR_SCHEDULE[i][0]:
      return LR_SCHEDULE[i][1]
  return lr

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=15,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])

Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   24.62.
For batch 1, loss is  449.40.
For batch 2, loss is  307.77.
For batch 3, loss is  232.41.
For batch 4, loss is  187.26.
Средние потери за эпоху 0 равны  187.26, а среднеквадратичная ошибка равна    7.83.

Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is    9.33.
For batch 1, loss is    8.00.
For batch 2, loss is    6.92.
For batch 3, loss is    6.39.
For batch 4, loss is    6.15.
Средние потери за эпоху 1 равны    6.15, а среднеквадратичная ошибка равна    1.96.

Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    6.47.
For batch 1, loss is    5.36.
For batch 2, loss is    5.47.
For batch 3, loss is    5.17.
For batch 4, loss is    5.05.
Средние потери за эпоху 2 равны    5.05, а среднеквадратичная ошибка равна    1.83.

Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is    6.05.
For batch 1, loss is    5.02.
For batch 2, loss is    5.06.
For batch 3, loss is    5.48.
For batch 4, loss is    5.68.
Средние потери за эпоху 3 равны    5.68, а среднеквадратичная ошибка равна    1.91.

Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is    5.82.
For batch 1, loss is    5.00.
For batch 2, loss is    4.79.
For batch 3, loss is    4.64.
For batch 4, loss is    4.99.
Средние потери за эпоху 4 равны    4.99, а среднеквадратичная ошибка равна    1.78.

Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is    7.55.
For batch 1, loss is    6.33.
For batch 2, loss is    6.12.
For batch 3, loss is    6.39.
For batch 4, loss is    6.41.
Средние потери за эпоху 5 равны    6.41, а среднеквадратичная ошибка равна    2.04.

Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is    6.12.
For batch 1, loss is    5.39.
For batch 2, loss is    4.85.
For batch 3, loss is    4.67.
For batch 4, loss is    4.60.
Средние потери за эпоху 6 равны    4.60, а среднеквадратичная ошибка равна    1.73.

Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is    3.96.
For batch 1, loss is    4.34.
For batch 2, loss is    4.35.
For batch 3, loss is    3.98.
For batch 4, loss is    3.93.
Средние потери за эпоху 7 равны    3.93, а среднеквадратичная ошибка равна    1.60.

Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is    5.25.
For batch 1, loss is    4.37.
For batch 2, loss is    4.51.
For batch 3, loss is    4.50.
For batch 4, loss is    4.57.
Средние потери за эпоху 8 равны    4.57, а среднеквадратичная ошибка равна    1.74.

Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is    3.88.
For batch 1, loss is    3.53.
For batch 2, loss is    3.41.
For batch 3, loss is    3.75.
For batch 4, loss is    3.65.
Средние потери за эпоху 9 равны    3.65, а среднеквадратичная ошибка равна    1.50.

Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is    3.53.
For batch 1, loss is    3.96.
For batch 2, loss is    3.85.
For batch 3, loss is    3.94.
For batch 4, loss is    4.13.
Средние потери за эпоху 10 равны    4.13, а среднеквадратичная ошибка равна    1.62.

Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is    4.54.
For batch 1, loss is    4.19.
For batch 2, loss is    4.19.
For batch 3, loss is    4.19.
For batch 4, loss is    4.26.
Средние потери за эпоху 11 равны    4.26, а среднеквадратичная ошибка равна    1.67.

Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is    4.28.
For batch 1, loss is    4.27.
For batch 2, loss is    4.48.
For batch 3, loss is    4.60.
For batch 4, loss is    4.57.
Средние потери за эпоху 12 равны    4.57, а среднеквадратичная ошибка равна    1.66.

Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is    4.18.
For batch 1, loss is    3.94.
For batch 2, loss is    4.14.
For batch 3, loss is    4.45.
For batch 4, loss is    4.21.
Средние потери за эпоху 13 равны    4.21, а среднеквадратичная ошибка равна    1.58.

Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is    3.86.
For batch 1, loss is    3.94.
For batch 2, loss is    3.49.
For batch 3, loss is    3.88.
For batch 4, loss is    4.40.
Средние потери за эпоху 14 равны    4.40, а среднеквадратичная ошибка равна    1.68.

Стандартные колбеки Keras

Не забудьте проверить существующие колбеки Keras посетив документацию API. Приложения включающие логирование в CSV, сохранение модели, визуализацию на TensorBoard и многое другое.