![]() | ![]() | ![]() | ![]() |
Введение
Обратный вызов - это мощный инструмент для настройки поведения модели Keras во время обучения, оценки или вывода. Примеры включают tf.keras.callbacks.TensorBoard
визуализировать ход подготовки и результаты с TensorBoard или tf.keras.callbacks.ModelCheckpoint
периодически сохранить вашу модель во время тренировки.
В этом руководстве вы узнаете, что такое обратный вызов Keras, что он может делать и как вы можете создать свой собственный. Для начала мы предлагаем несколько демонстраций простых приложений обратного вызова.
Настраивать
import tensorflow as tf
from tensorflow import keras
Обзор обратных вызовов Keras
Все обратные вызовы подкласса keras.callbacks.Callback
класс и переопределить набор методов , вызываемых на различных этапах подготовки, тестирования и предсказывающих. Обратные вызовы полезны, чтобы получить представление о внутренних состояниях и статистике модели во время обучения.
Вы можете передать список обратных вызовов (как ключевой слово аргумента callbacks
) для следующих модельных методов:
Обзор методов обратного вызова
Глобальные методы
on_(train|test|predict)_begin(self, logs=None)
Вызывается в начале fit
/ evaluate
/ predict
.
on_(train|test|predict)_end(self, logs=None)
Вызывается в конце fit
/ evaluate
/ predict
.
Методы пакетного уровня для обучения / тестирования / прогнозирования
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Вызывается прямо перед обработкой пакета во время обучения / тестирования / прогнозирования.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Вызывается в конце обучения / тестирования / прогнозирования партии. В рамках этого метода, logs
является ДИКТ , содержащий результаты метрик.
Методы эпохального уровня (только обучение)
on_epoch_begin(self, epoch, logs=None)
Вызывается в начале эпохи во время тренировки.
on_epoch_end(self, epoch, logs=None)
Вызывается в конце эпохи во время тренировки.
Базовый пример
Давайте посмотрим на конкретный пример. Для начала давайте импортируем тензорный поток и определим простую модель Sequential Keras:
# Define the Keras model to add callbacks to
def get_model():
model = keras.Sequential()
model.add(keras.layers.Dense(1, input_dim=784))
model.compile(
optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
loss="mean_squared_error",
metrics=["mean_absolute_error"],
)
return model
Затем загрузите данные MNIST для обучения и тестирования из API наборов данных Keras:
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
Теперь определите простой настраиваемый обратный вызов, который регистрирует:
- Когда
fit
/evaluate
/predict
старты & концы - Когда каждая эпоха начинается и заканчивается
- Когда начинается и заканчивается каждая тренировочная партия
- Когда начинается и заканчивается каждая оценочная (тестовая) партия
- Когда начинается и заканчивается каждый пакет вывода (предсказания)
class CustomCallback(keras.callbacks.Callback):
def on_train_begin(self, logs=None):
keys = list(logs.keys())
print("Starting training; got log keys: {}".format(keys))
def on_train_end(self, logs=None):
keys = list(logs.keys())
print("Stop training; got log keys: {}".format(keys))
def on_epoch_begin(self, epoch, logs=None):
keys = list(logs.keys())
print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("End epoch {} of training; got log keys: {}".format(epoch, keys))
def on_test_begin(self, logs=None):
keys = list(logs.keys())
print("Start testing; got log keys: {}".format(keys))
def on_test_end(self, logs=None):
keys = list(logs.keys())
print("Stop testing; got log keys: {}".format(keys))
def on_predict_begin(self, logs=None):
keys = list(logs.keys())
print("Start predicting; got log keys: {}".format(keys))
def on_predict_end(self, logs=None):
keys = list(logs.keys())
print("Stop predicting; got log keys: {}".format(keys))
def on_train_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
def on_train_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
Давайте попробуем:
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=1,
verbose=0,
validation_split=0.5,
callbacks=[CustomCallback()],
)
res = model.evaluate(
x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: [] Start epoch 0 of training; got log keys: [] ...Training: start of batch 0; got log keys: [] ...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 1; got log keys: [] ...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 2; got log keys: [] ...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 3; got log keys: [] ...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 4; got log keys: [] ...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 5; got log keys: [] ...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 6; got log keys: [] ...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 7; got log keys: [] ...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] Start predicting; got log keys: [] ...Predicting: start of batch 0; got log keys: [] ...Predicting: end of batch 0; got log keys: ['outputs'] ...Predicting: start of batch 1; got log keys: [] ...Predicting: end of batch 1; got log keys: ['outputs'] ...Predicting: start of batch 2; got log keys: [] ...Predicting: end of batch 2; got log keys: ['outputs'] ...Predicting: start of batch 3; got log keys: [] ...Predicting: end of batch 3; got log keys: ['outputs'] ...Predicting: start of batch 4; got log keys: [] ...Predicting: end of batch 4; got log keys: ['outputs'] ...Predicting: start of batch 5; got log keys: [] ...Predicting: end of batch 5; got log keys: ['outputs'] ...Predicting: start of batch 6; got log keys: [] ...Predicting: end of batch 6; got log keys: ['outputs'] ...Predicting: start of batch 7; got log keys: [] ...Predicting: end of batch 7; got log keys: ['outputs'] Stop predicting; got log keys: []
Использование logs
Dict
logs
ДИКТ содержит значение потери, и все показатели в конце партии или эпохи. Пример включает потерю и среднюю абсолютную ошибку.
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_test_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_epoch_end(self, epoch, logs=None):
print(
"The average loss for epoch {} is {:7.2f} "
"and mean absolute error is {:7.2f}.".format(
epoch, logs["loss"], logs["mean_absolute_error"]
)
)
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=2,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
x_test,
y_test,
batch_size=128,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
Up to batch 0, the average loss is 30.79. Up to batch 1, the average loss is 459.11. Up to batch 2, the average loss is 314.68. Up to batch 3, the average loss is 237.97. Up to batch 4, the average loss is 191.76. Up to batch 5, the average loss is 160.95. Up to batch 6, the average loss is 138.74. Up to batch 7, the average loss is 124.85. The average loss for epoch 0 is 124.85 and mean absolute error is 6.00. Up to batch 0, the average loss is 5.13. Up to batch 1, the average loss is 4.66. Up to batch 2, the average loss is 4.71. Up to batch 3, the average loss is 4.66. Up to batch 4, the average loss is 4.69. Up to batch 5, the average loss is 4.56. Up to batch 6, the average loss is 4.77. Up to batch 7, the average loss is 4.77. The average loss for epoch 1 is 4.77 and mean absolute error is 1.75. Up to batch 0, the average loss is 5.73. Up to batch 1, the average loss is 5.04. Up to batch 2, the average loss is 5.10. Up to batch 3, the average loss is 5.14. Up to batch 4, the average loss is 5.37. Up to batch 5, the average loss is 5.24. Up to batch 6, the average loss is 5.22. Up to batch 7, the average loss is 5.16.
Использование self.model
атрибута
В дополнении к получению информации журнала , когда один из их методов называются, обратные вызовы имеют доступ к модели , связанные с текущим раундом подготовки / оценок / вывода: self.model
.
Вот из немногих вещей , которые вы можете сделать с self.model
в обратном вызове:
- Набор
self.model.stop_training = True
для тренировки сразу прерывания. - Мутируют гиперпараметры оптимизатора (доступны в качестве
self.model.optimizer
), такие какself.model.optimizer.learning_rate
. - Сохраняйте модель через определенные промежутки времени.
- Запишите вывод
model.predict()
на нескольких испытательных образцов в конце каждой эпохи, чтобы использовать в качестве проверки вменяемости во время тренировки. - Извлекайте визуализации промежуточных функций в конце каждой эпохи, чтобы отслеживать, что модель изучает с течением времени.
- и Т. Д.
Давайте посмотрим на это в действии на нескольких примерах.
Примеры приложений обратного вызова Keras
Ранняя остановка с минимальными потерями
Это первый пример показывает создание Callback
, который останавливает обучение , когда минимум потерь было достигнуто путем установки атрибута self.model.stop_training
(логическое). При желании, вы можете указать аргумент patience
, чтобы указать , сколько эпох мы должны ждать до остановки после достижения локального минимума.
tf.keras.callbacks.EarlyStopping
обеспечивает более полную и общую реализацию.
import numpy as np
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
"""Stop training when the loss is at its min, i.e. the loss stops decreasing.
Arguments:
patience: Number of epochs to wait after min has been hit. After this
number of no improvement, training stops.
"""
def __init__(self, patience=0):
super(EarlyStoppingAtMinLoss, self).__init__()
self.patience = patience
# best_weights to store the weights at which the minimum loss occurs.
self.best_weights = None
def on_train_begin(self, logs=None):
# The number of epoch it has waited when loss is no longer minimum.
self.wait = 0
# The epoch the training stops at.
self.stopped_epoch = 0
# Initialize the best as infinity.
self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get("loss")
if np.less(current, self.best):
self.best = current
self.wait = 0
# Record the best weights if current results is better (less).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print("Restoring model weights from the end of the best epoch.")
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
steps_per_epoch=5,
epochs=30,
verbose=0,
callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
Up to batch 0, the average loss is 34.62. Up to batch 1, the average loss is 405.62. Up to batch 2, the average loss is 282.27. Up to batch 3, the average loss is 215.95. Up to batch 4, the average loss is 175.32. The average loss for epoch 0 is 175.32 and mean absolute error is 8.59. Up to batch 0, the average loss is 8.86. Up to batch 1, the average loss is 7.31. Up to batch 2, the average loss is 6.51. Up to batch 3, the average loss is 6.71. Up to batch 4, the average loss is 6.24. The average loss for epoch 1 is 6.24 and mean absolute error is 2.06. Up to batch 0, the average loss is 4.83. Up to batch 1, the average loss is 5.05. Up to batch 2, the average loss is 4.71. Up to batch 3, the average loss is 4.41. Up to batch 4, the average loss is 4.48. The average loss for epoch 2 is 4.48 and mean absolute error is 1.68. Up to batch 0, the average loss is 5.84. Up to batch 1, the average loss is 5.73. Up to batch 2, the average loss is 7.24. Up to batch 3, the average loss is 10.34. Up to batch 4, the average loss is 15.53. The average loss for epoch 3 is 15.53 and mean absolute error is 3.20. Restoring model weights from the end of the best epoch. Epoch 00004: early stopping <keras.callbacks.History at 0x7fd0843bf510>
Планирование скорости обучения
В этом примере мы показываем, как пользовательский обратный вызов может использоваться для динамического изменения скорости обучения оптимизатора в процессе обучения.
См callbacks.LearningRateScheduler
для более общих реализаций.
class CustomLearningRateScheduler(keras.callbacks.Callback):
"""Learning rate scheduler which sets the learning rate according to schedule.
Arguments:
schedule: a function that takes an epoch index
(integer, indexed from 0) and current learning rate
as inputs and returns a new learning rate as output (float).
"""
def __init__(self, schedule):
super(CustomLearningRateScheduler, self).__init__()
self.schedule = schedule
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, "lr"):
raise ValueError('Optimizer must have a "lr" attribute.')
# Get the current learning rate from model's optimizer.
lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
# Call schedule function to get the scheduled learning rate.
scheduled_lr = self.schedule(epoch, lr)
# Set the value back to the optimizer before this epoch starts
tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
LR_SCHEDULE = [
# (epoch to start, learning rate) tuples
(3, 0.05),
(6, 0.01),
(9, 0.005),
(12, 0.001),
]
def lr_schedule(epoch, lr):
"""Helper function to retrieve the scheduled learning rate based on epoch."""
if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
return lr
for i in range(len(LR_SCHEDULE)):
if epoch == LR_SCHEDULE[i][0]:
return LR_SCHEDULE[i][1]
return lr
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
steps_per_epoch=5,
epochs=15,
verbose=0,
callbacks=[
LossAndErrorPrintingCallback(),
CustomLearningRateScheduler(lr_schedule),
],
)
Epoch 00000: Learning rate is 0.1000. Up to batch 0, the average loss is 26.55. Up to batch 1, the average loss is 435.15. Up to batch 2, the average loss is 298.00. Up to batch 3, the average loss is 225.91. Up to batch 4, the average loss is 182.66. The average loss for epoch 0 is 182.66 and mean absolute error is 8.16. Epoch 00001: Learning rate is 0.1000. Up to batch 0, the average loss is 7.30. Up to batch 1, the average loss is 6.22. Up to batch 2, the average loss is 6.76. Up to batch 3, the average loss is 6.37. Up to batch 4, the average loss is 5.98. The average loss for epoch 1 is 5.98 and mean absolute error is 2.01. Epoch 00002: Learning rate is 0.1000. Up to batch 0, the average loss is 4.23. Up to batch 1, the average loss is 4.56. Up to batch 2, the average loss is 4.81. Up to batch 3, the average loss is 4.63. Up to batch 4, the average loss is 4.67. The average loss for epoch 2 is 4.67 and mean absolute error is 1.73. Epoch 00003: Learning rate is 0.0500. Up to batch 0, the average loss is 6.24. Up to batch 1, the average loss is 5.62. Up to batch 2, the average loss is 5.48. Up to batch 3, the average loss is 5.09. Up to batch 4, the average loss is 4.68. The average loss for epoch 3 is 4.68 and mean absolute error is 1.77. Epoch 00004: Learning rate is 0.0500. Up to batch 0, the average loss is 3.38. Up to batch 1, the average loss is 3.83. Up to batch 2, the average loss is 3.53. Up to batch 3, the average loss is 3.64. Up to batch 4, the average loss is 3.76. The average loss for epoch 4 is 3.76 and mean absolute error is 1.54. Epoch 00005: Learning rate is 0.0500. Up to batch 0, the average loss is 3.62. Up to batch 1, the average loss is 3.79. Up to batch 2, the average loss is 3.75. Up to batch 3, the average loss is 3.83. Up to batch 4, the average loss is 4.37. The average loss for epoch 5 is 4.37 and mean absolute error is 1.65. Epoch 00006: Learning rate is 0.0100. Up to batch 0, the average loss is 6.73. Up to batch 1, the average loss is 6.13. Up to batch 2, the average loss is 5.11. Up to batch 3, the average loss is 4.57. Up to batch 4, the average loss is 4.21. The average loss for epoch 6 is 4.21 and mean absolute error is 1.61. Epoch 00007: Learning rate is 0.0100. Up to batch 0, the average loss is 3.37. Up to batch 1, the average loss is 3.83. Up to batch 2, the average loss is 3.80. Up to batch 3, the average loss is 3.50. Up to batch 4, the average loss is 3.31. The average loss for epoch 7 is 3.31 and mean absolute error is 1.42. Epoch 00008: Learning rate is 0.0100. Up to batch 0, the average loss is 5.33. Up to batch 1, the average loss is 4.84. Up to batch 2, the average loss is 4.02. Up to batch 3, the average loss is 3.87. Up to batch 4, the average loss is 3.85. The average loss for epoch 8 is 3.85 and mean absolute error is 1.53. Epoch 00009: Learning rate is 0.0050. Up to batch 0, the average loss is 1.84. Up to batch 1, the average loss is 2.75. Up to batch 2, the average loss is 3.16. Up to batch 3, the average loss is 3.52. Up to batch 4, the average loss is 3.34. The average loss for epoch 9 is 3.34 and mean absolute error is 1.43. Epoch 00010: Learning rate is 0.0050. Up to batch 0, the average loss is 2.36. Up to batch 1, the average loss is 2.91. Up to batch 2, the average loss is 2.63. Up to batch 3, the average loss is 2.93. Up to batch 4, the average loss is 3.17. The average loss for epoch 10 is 3.17 and mean absolute error is 1.36. Epoch 00011: Learning rate is 0.0050. Up to batch 0, the average loss is 3.32. Up to batch 1, the average loss is 3.02. Up to batch 2, the average loss is 2.96. Up to batch 3, the average loss is 2.80. Up to batch 4, the average loss is 2.92. The average loss for epoch 11 is 2.92 and mean absolute error is 1.32. Epoch 00012: Learning rate is 0.0010. Up to batch 0, the average loss is 4.11. Up to batch 1, the average loss is 3.70. Up to batch 2, the average loss is 3.89. Up to batch 3, the average loss is 3.76. Up to batch 4, the average loss is 3.45. The average loss for epoch 12 is 3.45 and mean absolute error is 1.44. Epoch 00013: Learning rate is 0.0010. Up to batch 0, the average loss is 3.38. Up to batch 1, the average loss is 3.34. Up to batch 2, the average loss is 3.26. Up to batch 3, the average loss is 3.56. Up to batch 4, the average loss is 3.62. The average loss for epoch 13 is 3.62 and mean absolute error is 1.44. Epoch 00014: Learning rate is 0.0010. Up to batch 0, the average loss is 2.48. Up to batch 1, the average loss is 2.38. Up to batch 2, the average loss is 2.76. Up to batch 3, the average loss is 2.63. Up to batch 4, the average loss is 2.66. The average loss for epoch 14 is 2.66 and mean absolute error is 1.29. <keras.callbacks.History at 0x7fd08446c290>
Встроенные обратные вызовы Keras
Будьте уверены , чтобы проверить существующие функции обратного вызова Keras, читая API Docs . Приложения включают в себя ведение журнала в CSV, сохранение модели, визуализацию показателей в TensorBoard и многое другое!