今日のローカルTensorFlowEverywhereイベントの出欠確認!

コールバックを書く

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示{ Download notebook

はじめに

コールバックは、トレーニング、評価、推論の間に Keras モデルの動作をカスタマイズするための強力なツールです。例には、TensorBoard でトレーニングの進捗状況や結果を可視化できる tf.keras.callbacks.TensorBoard や、トレーニング中にモデルを定期的に保存できる tf.keras.callbacks.ModelCheckpoint などを含みます。

このガイドでは、Keras コールバックとは何か、それができること、そして独自のコールバックを構築する方法を学ぶことができます。まずは、簡単なコールバックアプリケーションのデモをいくつか紹介します。

セットアップ

import tensorflow as tf
from tensorflow import keras

Keras コールバックの概要

全てのコールバックは keras.callbacks.Callbacks.Callback クラスをサブクラス化し、トレーニング、テスト、予測のさまざまな段階で呼び出される一連のメソッドをオーバーライドします。コールバックは、トレーニング中にモデルの内部状態や統計上のビューを取得するのに有用です。

以下のモデルメソッドには、(キーワード引数 callbacks として)コールバックのリストを渡すことができます。

コールバックメソッドの概要

グローバルメソッド

on_(train|test|predict)_begin(self, logs=None)

fit/evaluate/predict の先頭で呼び出されます。

on_(train|test|predict)_end(self, logs=None)

fit/evaluate/predict の最後に呼び出されます。

トレーニング/テスト/予測のためのバッチレベルのメソッド

on_(train|test|predict)_batch_begin(self, batch, logs=None)

トレーニング/テスト/予測中に、バッチを処理する直前に呼び出されます。

on_(train|test|predict)_batch_end(self, batch, logs=None)

バッチのトレーニング/テスト/予測の終了時に呼び出されます。このメソッド内では、logs はメトリクスの結果を含むディクショナリです。

エポックレベルのメソッド(トレーニングのみ)

on_epoch_begin(self, epoch, logs=None)

トレーニング中に、エポックの最初に呼び出されます。

on_epoch_end(self, epoch, logs=None)

トレーニング中、エポックの最後に呼び出されます。

基本的な例

具体的な例を見てみましょう。まず最初に、TensorFlow をインポートして単純な Sequential Keras モデルを定義してみます。

# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model

次に、Keras データセット API からトレーニングとテスト用の MNIST データを読み込みます。

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

今度は、以下のログを記録する単純なカスタムコールバックを定義します。

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • 各トレーニングバッチの開始時と終了時
  • 各評価(テスト)バッチの開始時と終了時
  • 各推論(予測)バッチの開始時と終了時
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

試してみましょう。

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: []
Start epoch 0 of training; got log keys: []
...Training: start of batch 0; got log keys: []
...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 1; got log keys: []
...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 2; got log keys: []
...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 3; got log keys: []
...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 4; got log keys: []
...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 5; got log keys: []
...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 6; got log keys: []
...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 7; got log keys: []
...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
Start predicting; got log keys: []
...Predicting: start of batch 0; got log keys: []
...Predicting: end of batch 0; got log keys: ['outputs']
...Predicting: start of batch 1; got log keys: []
...Predicting: end of batch 1; got log keys: ['outputs']
...Predicting: start of batch 2; got log keys: []
...Predicting: end of batch 2; got log keys: ['outputs']
...Predicting: start of batch 3; got log keys: []
...Predicting: end of batch 3; got log keys: ['outputs']
...Predicting: start of batch 4; got log keys: []
...Predicting: end of batch 4; got log keys: ['outputs']
...Predicting: start of batch 5; got log keys: []
...Predicting: end of batch 5; got log keys: ['outputs']
...Predicting: start of batch 6; got log keys: []
...Predicting: end of batch 6; got log keys: ['outputs']
...Predicting: start of batch 7; got log keys: []
...Predicting: end of batch 7; got log keys: ['outputs']
Stop predicting; got log keys: []

logs ディクショナリを使用する

logs ディクショナリは、バッチまたはエポックの最後の損失値と全てのメトリクスを含みます。次の例は、損失値と平均絶対誤差を含んでいます。

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
For batch 0, loss is   31.57.
For batch 1, loss is  452.68.
For batch 2, loss is  311.45.
For batch 3, loss is  236.22.
For batch 4, loss is  190.69.
For batch 5, loss is  159.97.
For batch 6, loss is  138.05.
For batch 7, loss is  124.27.
The average loss for epoch 0 is  124.27 and mean absolute error is    6.12.
For batch 0, loss is    4.65.
For batch 1, loss is    4.74.
For batch 2, loss is    5.01.
For batch 3, loss is    4.82.
For batch 4, loss is    4.79.
For batch 5, loss is    4.70.
For batch 6, loss is    4.66.
For batch 7, loss is    4.56.
The average loss for epoch 1 is    4.56 and mean absolute error is    1.72.
For batch 0, loss is    4.67.
For batch 1, loss is    4.24.
For batch 2, loss is    4.29.
For batch 3, loss is    4.22.
For batch 4, loss is    4.36.
For batch 5, loss is    4.35.
For batch 6, loss is    4.30.
For batch 7, loss is    4.25.

self.model 属性を使用する

コールバックは、そのメソッドの 1 つが呼び出された時にログ情報を受け取ることに加え、現在のトレーニング/評価/推論のラウンドに関連付けられたモデルに、self.model でアクセスすることができます。

コールバックで self.model を使用してできることを幾つか次に示します。

  • self.model.stop_training = True を設定して直ちにトレーニングを中断する。
  • self.model.optimizer.learning_rate など、オプティマイザ(self.model.optimizer として使用可能)のハイパーパラメータを変化させる。
  • 一定間隔でモデルを保存する。
  • 各エポックの終了時に幾つかのテストサンプルの model.predict() の出力を記録し、トレーニング中にサ二ティーチェックとして使用する。
  • 各エポックの終了時に中間特徴の可視化を抽出して、モデルが何を学習しているかを経時的に監視する。
  • など

これを確認するために、2 つの例で見てみましょう。

Keras コールバックアプリケーションの例

最小損失で Early stopping する

この最初の例は、属性 self.model.stop_training(ブール)を設定して、損失の最小値に達した時点でトレーニングを停止する Callback を作成しています。オプションで、ローカル最小値に到達した後、実際に停止するまでに幾つのエポックを待つべきか、引数 patience で指定することが可能です。

tf.keras.callbacks.EarlyStopping は、より完全で一般的な実装を提供します。

import numpy as np


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of no improvement, training stops.
  """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
For batch 0, loss is   31.99.
For batch 1, loss is  456.49.
For batch 2, loss is  316.17.
For batch 3, loss is  240.72.
For batch 4, loss is  194.07.
The average loss for epoch 0 is  194.07 and mean absolute error is    8.71.
For batch 0, loss is    5.99.
For batch 1, loss is    5.77.
For batch 2, loss is    5.26.
For batch 3, loss is    5.28.
For batch 4, loss is    5.48.
The average loss for epoch 1 is    5.48 and mean absolute error is    1.90.
For batch 0, loss is    6.50.
For batch 1, loss is    4.84.
For batch 2, loss is    4.86.
For batch 3, loss is    4.95.
For batch 4, loss is    4.73.
The average loss for epoch 2 is    4.73 and mean absolute error is    1.76.
For batch 0, loss is    4.44.
For batch 1, loss is    5.31.
For batch 2, loss is    5.56.
For batch 3, loss is    5.97.
For batch 4, loss is    6.98.
The average loss for epoch 3 is    6.98 and mean absolute error is    2.10.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping

<tensorflow.python.keras.callbacks.History at 0x7f5fd019fb00>

学習率をスケジューリングする

この例では、トレーニングの過程でカスタムコールバックを使用して、オプティマイザの学習率を動的に変更する方法を示します。

より一般的な実装については、callbacks.LearningRateScheduler をご覧ください。

class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """

    def __init__(self, schedule):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)

Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   33.29.
For batch 1, loss is  387.77.
For batch 2, loss is  268.60.
For batch 3, loss is  205.39.
For batch 4, loss is  166.73.
The average loss for epoch 0 is  166.73 and mean absolute error is    8.26.

Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is    6.32.
For batch 1, loss is    6.86.
For batch 2, loss is    6.77.
For batch 3, loss is    6.16.
For batch 4, loss is    6.01.
The average loss for epoch 1 is    6.01 and mean absolute error is    2.01.

Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    6.02.
For batch 1, loss is    6.14.
For batch 2, loss is    6.53.
For batch 3, loss is    6.60.
For batch 4, loss is    7.06.
The average loss for epoch 2 is    7.06 and mean absolute error is    2.24.

Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is   22.22.
For batch 1, loss is   12.58.
For batch 2, loss is    9.46.
For batch 3, loss is    8.06.
For batch 4, loss is    7.24.
The average loss for epoch 3 is    7.24 and mean absolute error is    2.05.

Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is    4.29.
For batch 1, loss is    3.73.
For batch 2, loss is    4.17.
For batch 3, loss is    4.00.
For batch 4, loss is    4.07.
The average loss for epoch 4 is    4.07 and mean absolute error is    1.55.

Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is    4.19.
For batch 1, loss is    4.38.
For batch 2, loss is    4.37.
For batch 3, loss is    4.78.
For batch 4, loss is    6.04.
The average loss for epoch 5 is    6.04 and mean absolute error is    1.97.

Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is   18.12.
For batch 1, loss is   13.69.
For batch 2, loss is   10.70.
For batch 3, loss is    8.68.
For batch 4, loss is    7.65.
The average loss for epoch 6 is    7.65 and mean absolute error is    2.21.

Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is    4.44.
For batch 1, loss is    3.88.
For batch 2, loss is    3.77.
For batch 3, loss is    3.73.
For batch 4, loss is    3.73.
The average loss for epoch 7 is    3.73 and mean absolute error is    1.50.

Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is    3.97.
For batch 1, loss is    4.00.
For batch 2, loss is    3.94.
For batch 3, loss is    3.86.
For batch 4, loss is    3.75.
The average loss for epoch 8 is    3.75 and mean absolute error is    1.49.

Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is    2.06.
For batch 1, loss is    2.53.
For batch 2, loss is    2.68.
For batch 3, loss is    2.75.
For batch 4, loss is    3.13.
The average loss for epoch 9 is    3.13 and mean absolute error is    1.39.

Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is    4.04.
For batch 1, loss is    3.53.
For batch 2, loss is    3.26.
For batch 3, loss is    2.98.
For batch 4, loss is    2.97.
The average loss for epoch 10 is    2.97 and mean absolute error is    1.35.

Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is    4.37.
For batch 1, loss is    3.51.
For batch 2, loss is    3.39.
For batch 3, loss is    3.55.
For batch 4, loss is    3.48.
The average loss for epoch 11 is    3.48 and mean absolute error is    1.46.

Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is    3.20.
For batch 1, loss is    3.12.
For batch 2, loss is    3.21.
For batch 3, loss is    3.13.
For batch 4, loss is    3.41.
The average loss for epoch 12 is    3.41 and mean absolute error is    1.45.

Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is    2.97.
For batch 1, loss is    3.40.
For batch 2, loss is    3.36.
For batch 3, loss is    3.19.
For batch 4, loss is    3.33.
The average loss for epoch 13 is    3.33 and mean absolute error is    1.41.

Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is    3.69.
For batch 1, loss is    2.83.
For batch 2, loss is    3.09.
For batch 3, loss is    3.07.
For batch 4, loss is    3.07.
The average loss for epoch 14 is    3.07 and mean absolute error is    1.37.

<tensorflow.python.keras.callbacks.History at 0x7f5fd01f7630>

組み込みの Keras コールバック

既存の Keras コールバックについては、API ドキュメントを読んで必ず確認してください。アプリケーションには、CSV へのロギング、モデルの保存、TensorBoard でのメトリクスの可視化、その他多数があります。