Writing your own callbacks

Introduction

Callbacks are a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include tf.keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or tf.keras.callbacks.ModelCheckpoint to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. To get you started, it also provides a few demos of simple callback applications.

Setup

import tensorflow as tf
from tensorflow import keras

Keras callbacks overview

All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful to get a view on the internal states and statistics of the model during training.

You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:

  • model.fit()
  • model.evaluate()
  • model.predict()
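
For example, assuming a compiled model and training data like the ones defined below, attaching a callback is a one-liner (my_callback here stands for any keras.callbacks.Callback instance):

model.fit(x_train, y_train, epochs=10, callbacks=[my_callback])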

An overview of callback methods

Global methods

on_(train|test|predict)_begin(self, logs=None)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Called at the end of fit/evaluate/predict.

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)

Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)

Called at the end of an epoch during training.

A basic example

Let's take a look at a concrete example. To get started, let's import tensorflow and define a simple Sequential Keras model:

# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model

Then, load the MNIST data for training and testing from the Keras datasets API:

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

Now, define a simple custom callback that logs:

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • When each training batch starts & ends
  • When each evaluation (test) batch starts & ends
  • When each inference (prediction) batch starts & ends
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

Let's try it out:

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: []
Start epoch 0 of training; got log keys: []
...Training: start of batch 0; got log keys: []
...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 1; got log keys: []
...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 2; got log keys: []
...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 3; got log keys: []
...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 4; got log keys: []
...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 5; got log keys: []
...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 6; got log keys: []
...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 7; got log keys: []
...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
Start predicting; got log keys: []
...Predicting: start of batch 0; got log keys: []
...Predicting: end of batch 0; got log keys: ['outputs']
...Predicting: start of batch 1; got log keys: []
...Predicting: end of batch 1; got log keys: ['outputs']
...Predicting: start of batch 2; got log keys: []
...Predicting: end of batch 2; got log keys: ['outputs']
...Predicting: start of batch 3; got log keys: []
...Predicting: end of batch 3; got log keys: ['outputs']
...Predicting: start of batch 4; got log keys: []
...Predicting: end of batch 4; got log keys: ['outputs']
...Predicting: start of batch 5; got log keys: []
...Predicting: end of batch 5; got log keys: ['outputs']
...Predicting: start of batch 6; got log keys: []
...Predicting: end of batch 6; got log keys: ['outputs']
...Predicting: start of batch 7; got log keys: []
...Predicting: end of batch 7; got log keys: ['outputs']
Stop predicting; got log keys: []

Usage of the logs dict

The logs dict contains the loss value, and all the metrics at the end of a batch or epoch. This example includes the loss and mean absolute error.

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
For batch 0, loss is   25.20.
For batch 1, loss is  433.41.
For batch 2, loss is  296.65.
For batch 3, loss is  225.30.
For batch 4, loss is  181.68.
For batch 5, loss is  152.45.
For batch 6, loss is  131.45.
For batch 7, loss is  118.27.
The average loss for epoch 0 is  118.27 and mean absolute error is    5.86.
For batch 0, loss is    4.57.
For batch 1, loss is    4.56.
For batch 2, loss is    4.62.
For batch 3, loss is    4.57.
For batch 4, loss is    4.56.
For batch 5, loss is    4.63.
For batch 6, loss is    4.52.
For batch 7, loss is    4.44.
The average loss for epoch 1 is    4.44 and mean absolute error is    1.70.
For batch 0, loss is    5.00.
For batch 1, loss is    4.46.
For batch 2, loss is    4.60.
For batch 3, loss is    4.53.
For batch 4, loss is    4.64.
For batch 5, loss is    4.60.
For batch 6, loss is    4.52.
For batch 7, loss is    4.45.

Usage of the self.model attribute

In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.

Here are a few of the things you can do with self.model in a callback:

  • Set self.model.stop_training = True to immediately interrupt training.
  • Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
  • Save the model at periodic intervals (see the sketch right after this list).
  • Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
  • etc.
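
Here is a minimal sketch of the periodic-saving idea from the list above. The PeriodicSaver class, its name, and its arguments are illustrative, not part of Keras; the built-in tf.keras.callbacks.ModelCheckpoint covers this use case more completely.

class PeriodicSaver(keras.callbacks.Callback):
    """Illustrative sketch: save the model every `every_epochs` epochs."""

    def __init__(self, path_pattern="model_at_epoch_{epoch}", every_epochs=5):
        super(PeriodicSaver, self).__init__()
        self.path_pattern = path_pattern
        self.every_epochs = every_epochs

    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model being trained; save it at the chosen interval.
        if (epoch + 1) % self.every_epochs == 0:
            self.model.save(self.path_pattern.format(epoch=epoch + 1))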

Let's see this in action in a couple of examples.

Examples of Keras callback applications

Early stopping at minimum loss

This first example shows the creation of a Callback that stops training when the minimum of loss has been reached, by setting the attribute self.model.stop_training (boolean). Optionally, you can provide an argument patience to specify how many epochs we should wait before stopping after having reached a local minimum.

tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.
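
For comparison, the built-in callback could be used like this (a minimal sketch; the monitor and patience values are illustrative):

model = get_model()
model.fit(
    x_train,
    y_train,
    epochs=30,
    verbose=0,
    callbacks=[
        # Stop when the training loss stops improving, keeping the best weights.
        tf.keras.callbacks.EarlyStopping(
            monitor="loss", patience=2, restore_best_weights=True
        )
    ],
)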

import numpy as np


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of no improvement, training stops.
  """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs it has waited while the loss is no longer the minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower loss).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
For batch 0, loss is   23.44.
For batch 1, loss is  406.69.
For batch 2, loss is  279.72.
For batch 3, loss is  212.43.
For batch 4, loss is  171.98.
The average loss for epoch 0 is  171.98 and mean absolute error is    7.90.
For batch 0, loss is    4.90.
For batch 1, loss is    5.80.
For batch 2, loss is    6.08.
For batch 3, loss is    5.92.
For batch 4, loss is    5.71.
The average loss for epoch 1 is    5.71 and mean absolute error is    1.94.
For batch 0, loss is    5.28.
For batch 1, loss is    4.79.
For batch 2, loss is    4.87.
For batch 3, loss is    5.29.
For batch 4, loss is    5.65.
The average loss for epoch 2 is    5.65 and mean absolute error is    1.95.
For batch 0, loss is    8.66.
For batch 1, loss is   12.04.
For batch 2, loss is   15.36.
For batch 3, loss is   23.19.
For batch 4, loss is   37.54.
The average loss for epoch 3 is   37.54 and mean absolute error is    5.09.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping

<tensorflow.python.keras.callbacks.History at 0x7fa82a016ac8>

Learning rate scheduling

This example shows how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.

See callbacks.LearningRateScheduler for a more general implementation.
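
As a quick sketch of that built-in version (the schedule function here is illustrative):

def halve_every_ten_epochs(epoch, lr):
    # Halve the learning rate every 10 epochs, starting from epoch 10.
    return lr * 0.5 if epoch > 0 and epoch % 10 == 0 else lr

model = get_model()
model.fit(
    x_train,
    y_train,
    epochs=15,
    verbose=0,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(halve_every_ten_epochs)],
)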

class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """

    def __init__(self, schedule):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)

Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   28.12.
For batch 1, loss is  486.39.
For batch 2, loss is  333.25.
For batch 3, loss is  251.94.
For batch 4, loss is  202.90.
The average loss for epoch 0 is  202.90 and mean absolute error is    8.33.

Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is    6.74.
For batch 1, loss is    5.65.
For batch 2, loss is    5.89.
For batch 3, loss is    5.58.
For batch 4, loss is    5.61.
The average loss for epoch 1 is    5.61 and mean absolute error is    1.93.

Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    4.10.
For batch 1, loss is    3.94.
For batch 2, loss is    4.34.
For batch 3, loss is    4.33.
For batch 4, loss is    4.79.
The average loss for epoch 2 is    4.79 and mean absolute error is    1.74.

Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is    6.13.
For batch 1, loss is    4.77.
For batch 2, loss is    5.06.
For batch 3, loss is    4.60.
For batch 4, loss is    4.41.
The average loss for epoch 3 is    4.41 and mean absolute error is    1.67.

Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is    4.81.
For batch 1, loss is    4.73.
For batch 2, loss is    4.55.
For batch 3, loss is    4.63.
For batch 4, loss is    4.58.
The average loss for epoch 4 is    4.58 and mean absolute error is    1.70.

Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is    4.45.
For batch 1, loss is    4.62.
For batch 2, loss is    4.30.
For batch 3, loss is    4.67.
For batch 4, loss is    5.11.
The average loss for epoch 5 is    5.11 and mean absolute error is    1.85.

Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is   11.94.
For batch 1, loss is    9.66.
For batch 2, loss is    7.27.
For batch 3, loss is    5.80.
For batch 4, loss is    5.47.
The average loss for epoch 6 is    5.47 and mean absolute error is    1.85.

Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is    4.25.
For batch 1, loss is    3.60.
For batch 2, loss is    4.19.
For batch 3, loss is    3.94.
For batch 4, loss is    3.74.
The average loss for epoch 7 is    3.74 and mean absolute error is    1.46.

Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is    3.72.
For batch 1, loss is    3.55.
For batch 2, loss is    3.54.
For batch 3, loss is    3.60.
For batch 4, loss is    3.57.
The average loss for epoch 8 is    3.57 and mean absolute error is    1.49.

Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is    3.55.
For batch 1, loss is    3.74.
For batch 2, loss is    3.68.
For batch 3, loss is    3.76.
For batch 4, loss is    3.57.
The average loss for epoch 9 is    3.57 and mean absolute error is    1.46.

Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is    4.07.
For batch 1, loss is    3.84.
For batch 2, loss is    3.73.
For batch 3, loss is    3.46.
For batch 4, loss is    3.53.
The average loss for epoch 10 is    3.53 and mean absolute error is    1.43.

Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is    3.49.
For batch 1, loss is    2.75.
For batch 2, loss is    2.67.
For batch 3, loss is    3.18.
For batch 4, loss is    3.47.
The average loss for epoch 11 is    3.47 and mean absolute error is    1.41.

Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is    2.99.
For batch 1, loss is    3.01.
For batch 2, loss is    3.12.
For batch 3, loss is    3.07.
For batch 4, loss is    2.87.
The average loss for epoch 12 is    2.87 and mean absolute error is    1.35.

Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is    2.93.
For batch 1, loss is    3.15.
For batch 2, loss is    3.41.
For batch 3, loss is    3.44.
For batch 4, loss is    3.44.
The average loss for epoch 13 is    3.44 and mean absolute error is    1.44.

Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is    3.07.
For batch 1, loss is    3.00.
For batch 2, loss is    2.68.
For batch 3, loss is    3.21.
For batch 4, loss is    3.16.
The average loss for epoch 14 is    3.16 and mean absolute error is    1.36.

<tensorflow.python.keras.callbacks.History at 0x7fa829f018d0>

Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
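
For instance, several built-in callbacks can be combined in a single call to fit() (a minimal sketch; the file paths are illustrative):

model = get_model()
model.fit(
    x_train,
    y_train,
    epochs=5,
    verbose=0,
    callbacks=[
        tf.keras.callbacks.CSVLogger("training_log.csv"),  # log metrics to a CSV file
        tf.keras.callbacks.ModelCheckpoint("checkpoints/model_{epoch}"),  # save the model each epoch
        tf.keras.callbacks.TensorBoard(log_dir="logs"),  # write TensorBoard summaries
    ],
)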