Migrate the fault tolerance mechanism


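Fault tolerance is a mechanism that periodically saves the state of trackable objects, such as parameters and models. If the program or machine fails during training, you can then recover that state instead of starting over.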
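As a minimal sketch of this idea (not part of the migration code), the snippet below saves two trackable `tf.Variable` objects with `tf.train.Checkpoint`, simulates progress lost in a crash, and then restores the saved values. The variable names `step` and `weights` are illustrative assumptions.

```python
import os
import tempfile

import tensorflow as tf

# Two pieces of trackable state: a step counter and some "parameters".
step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0])
checkpoint = tf.train.Checkpoint(step=step, weights=weights)

# Periodic save: record the state after 10 steps of (pretend) training.
step.assign(10)
save_path = checkpoint.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Progress made after the save is lost when the program crashes.
step.assign(25)
weights.assign([0.0, 0.0])

# Recovery: restore the last saved state into the existing variables.
checkpoint.restore(save_path)
print(int(step.numpy()), weights.numpy().tolist())  # 10 [1.0, 2.0]
```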

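This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by configuring checkpoint saving with tf.estimator.RunConfig. You will then learn how to implement fault-tolerant training in TensorFlow 2 in two ways: with the tf.keras.callbacks.BackupAndRestore callback if you train with the Keras Model.fit API, or with the tf.train.Checkpoint and tf.train.CheckpointManager APIs if you use a custom training loop.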

Both approaches back up and restore the training state in checkpoint files.

Setup

Install tf-nightly, because the save_freq argument of tf.keras.callbacks.BackupAndRestore, which sets how often (in steps) checkpoints are saved, was introduced in TensorFlow 2.10:

pip install tf-nightly
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
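Because the step-level save_freq option needs TensorFlow 2.10 or later, you may want to check the installed version before running that part of the guide. The helper below, `supports_step_level_backup`, is a hypothetical sketch that compares only the leading "major.minor" of a version string; in practice you would pass it `tf.__version__`.

```python
def supports_step_level_backup(version: str) -> bool:
    """Hypothetical helper: True if a TensorFlow version string is 2.10 or newer.

    Only the leading "major.minor" part is compared, so nightly strings such as
    "2.12.0-dev20221214" are treated as 2.12.
    """
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (2, 10)


print(supports_step_level_backup("2.9.3"))               # False
print(supports_step_level_backup("2.10.0"))              # True
print(supports_step_level_backup("2.12.0-dev20221214"))  # True
```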

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can configure a tf.estimator.Estimator to save a checkpoint at every step by passing it a tf.estimator.RunConfig.
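Saving at every step does not mean every checkpoint stays on disk: the RunConfig printed in this section's log shows '_keep_checkpoint_max': 5, so only the newest five are kept. The plain-Python sketch below models that retention rule under a simplifying assumption (keep_checkpoint_every_n_hours is ignored); `retained_checkpoints` is a hypothetical helper, not an Estimator API.

```python
from collections import deque


def retained_checkpoints(saved_steps, keep_max=5):
    """Simplified retention model: only the newest `keep_max` checkpoints
    remain on disk (keep_checkpoint_every_n_hours is ignored)."""
    kept = deque(maxlen=keep_max)  # appending past maxlen drops the oldest entry
    for step in saved_steps:
        kept.append(step)
    return list(kept)


# With save_checkpoints_steps=1, checkpoints 0..6 are written before the
# interruption in this section's log.
print(retained_checkpoints(range(7)))  # [2, 3, 4, 5, 6]
```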

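In this example, start by writing a hook that artificially raises an error after the sixth training step (when its zero-based step counter reaches 5):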

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')
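To see why the hook fires after the sixth step, the sketch below copies its counting logic into plain Python and calls it in SessionRunHook order (begin once, then before_run and after_run around each step). The class and driver function are hypothetical simulations, not the real MonitoredSession; the result agrees with the log in this section, where checkpoint 6 is saved just before the RuntimeError.

```python
class SimulatedInterruptHook:
    """Plain-Python copy of InterruptHook's counting logic (no TensorFlow)."""

    def begin(self):
        self._step = -1

    def before_run(self):
        self._step += 1

    def after_run(self):
        if self._step == 5:
            raise RuntimeError("Interruption")


def steps_completed_before_interrupt(hook, max_steps=10):
    """Drive the hook like a session loop and count finished training steps."""
    hook.begin()
    completed = 0
    try:
        for _ in range(max_steps):
            hook.before_run()
            completed += 1  # the training step itself would run here
            hook.after_run()
    except RuntimeError:
        pass
    return completed


print(steps_completed_before_interrupt(SimulatedInterruptHook()))  # 6
```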

Next, configure the tf.estimator.Estimator to save a checkpoint at every step and to use the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/314197976.py:1: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/314197976.py:2: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/314197976.py:7: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp7e95s18u', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/314197976.py:17: numpy_input_fn (from tensorflow_estimator.python.estimator.inputs.numpy_io) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.

Start training the model. The hook you defined earlier will raise an artificial exception.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_25509/2587623597.py:3: object.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1414: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1417: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1454: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 120.54729, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1067: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption

Rebuild the tf.estimator.Estimator with the same model_dir so that it restores the last saved checkpoint, then continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp7e95s18u', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp7e95s18u/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1176: get_checkpoint_mtimes (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 99.52451, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp7e95s18u/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 98.06565.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f8add4c35e0>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you train with the Keras Model.fit API, you can pass it the tf.keras.callbacks.BackupAndRestore callback to add fault tolerance.

To help demonstrate this, first define a Keras Callback class that artificially raises an error at the end of the fourth epoch (epoch index 3):

class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')

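Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that saves checkpoints to a temporary directory at epoch boundaries: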

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir)

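Start training the model with Model.fit. During training, the tf.keras.callbacks.BackupAndRestore callback instantiated above saves checkpoints, while the InterruptAtEpoch callback raises an artificial exception to simulate a failure after the fourth epoch.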

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 2s 11ms/step - loss: 0.4629 - accuracy: 0.8704 - val_loss: 0.2217 - val_accuracy: 0.9375
Epoch 2/10
100/100 [==============================] - 1s 8ms/step - loss: 0.2015 - accuracy: 0.9429 - val_loss: 0.1621 - val_accuracy: 0.9527
Epoch 3/10
100/100 [==============================] - 1s 8ms/step - loss: 0.1474 - accuracy: 0.9585 - val_loss: 0.1228 - val_accuracy: 0.9636
Epoch 4/10
 91/100 [==========================>...] - ETA: 0s - loss: 0.1182 - accuracy: 0.9661RuntimeError:Interruption
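Why does the resumed run below start at Epoch 5 rather than Epoch 4? A plausible explanation, assuming Model.fit invokes callbacks in the order they are passed, is that the backup callback (listed first) saves epoch index 3 before InterruptAtEpoch raises. The plain-Python toy below models only that ordering; every function in it is a hypothetical stand-in, not Keras code.

```python
def run_epochs(callbacks, num_epochs, start_epoch=0):
    """Toy stand-in for Model.fit: calls each epoch-end hook in list order."""
    for epoch in range(start_epoch, num_epochs):
        # ... one epoch of training would happen here ...
        for on_epoch_end in callbacks:
            on_epoch_end(epoch)


def make_backup(saved_epochs):
    """Hypothetical stand-in for BackupAndRestore: records each backed-up epoch."""
    def on_epoch_end(epoch):
        saved_epochs.append(epoch)
    return on_epoch_end


def interrupt_at_epoch(epoch, interrupting_epoch=3):
    if epoch == interrupting_epoch:
        raise RuntimeError("Interruption")


def resume_epoch(callback_order):
    """Return the 0-based epoch index training would resume from."""
    saved_epochs = []
    hooks = {"backup": make_backup(saved_epochs), "interrupt": interrupt_at_epoch}
    try:
        run_epochs([hooks[name] for name in callback_order], num_epochs=10)
    except RuntimeError:
        pass
    return saved_epochs[-1] + 1


# Backup listed first (as in this guide): resume at index 4, shown as "Epoch 5/10".
print(resume_epoch(["backup", "interrupt"]))  # 4
# Interrupt listed first: the epoch-3 backup never happens.
print(resume_epoch(["interrupt", "backup"]))  # 3
```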

Next, re-instantiate the Keras model, call Model.compile, and continue training with Model.fit from the previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 5/10
100/100 [==============================] - 2s 19ms/step - loss: 0.0956 - accuracy: 0.9733 - val_loss: 0.0925 - val_accuracy: 0.9727
Epoch 6/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0801 - accuracy: 0.9775 - val_loss: 0.0824 - val_accuracy: 0.9759
Epoch 7/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0680 - accuracy: 0.9810 - val_loss: 0.0747 - val_accuracy: 0.9775
Epoch 8/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0599 - accuracy: 0.9829 - val_loss: 0.0736 - val_accuracy: 0.9768
Epoch 9/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0521 - accuracy: 0.9853 - val_loss: 0.0710 - val_accuracy: 0.9783
Epoch 10/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0463 - accuracy: 0.9866 - val_loss: 0.0643 - val_accuracy: 0.9791
<keras.callbacks.History at 0x7f8a54329d90>

Define another Callback class that artificially raises an error at the 140th step:

class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')

Note: This section uses features that are only available in tf-nightly until TensorFlow 2.10 is released.

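To save a checkpoint every 30 steps, set save_freq to 30 in the BackupAndRestore callback. InterruptAtStep raises an artificial exception to simulate a failure at step 40 of the second epoch (epoch index 1, total step count 140). The last checkpoint is therefore saved at step 20 of the second epoch (total step count 120).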

log_dir_2 = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir_2, save_freq=30
)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 2s 11ms/step - loss: 0.4761 - accuracy: 0.8646 - val_loss: 0.2292 - val_accuracy: 0.9344
Epoch 2/10
 27/100 [=======>......................] - ETA: 0s - loss: 0.2342 - accuracy: 0.9328
Interrupting at step count 140
RuntimeError:Interruption
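The step arithmetic above can be checked with a short sketch. `step_backup_plan` is a hypothetical helper; it assumes steps are counted globally from 1, epochs from 0, and that the backup callback runs before the interrupting one (so a save due on the interrupting step itself still happens).

```python
def step_backup_plan(save_freq, steps_per_epoch, interrupt_step):
    """Hypothetical helper for step-level BackupAndRestore arithmetic.

    Returns (last_saved_global_step, epoch_index, step_in_epoch).
    A last_saved_global_step of 0 means no backup was written yet.
    Training resumes at step_in_epoch + 1 of epoch_index
    (Keras displays the epoch as epoch_index + 1).
    """
    last_saved = (interrupt_step // save_freq) * save_freq
    epoch_index, step_in_epoch = divmod(last_saved, steps_per_epoch)
    return last_saved, epoch_index, step_in_epoch


# Saves happen at total steps 30, 60, 90, 120; the failure hits at 140.
print(step_backup_plan(save_freq=30, steps_per_epoch=100, interrupt_step=140))
# (120, 1, 20): resume at step 21 of epoch index 1, displayed as "Epoch 2/10"
```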

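Next, re-instantiate the Keras model, call Model.compile, and continue training with Model.fit from the previously saved checkpoint. Note that training resumes at step 21 of the second epoch (displayed by Keras as Epoch 2/10).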

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 2/10
100/100 [==============================] - 2s 18ms/step - loss: 0.1969 - accuracy: 0.9439 - val_loss: 0.1629 - val_accuracy: 0.9544
Epoch 3/10
100/100 [==============================] - 0s 5ms/step - loss: 0.1568 - accuracy: 0.9555 - val_loss: 0.1271 - val_accuracy: 0.9632
Epoch 4/10
100/100 [==============================] - 0s 5ms/step - loss: 0.1187 - accuracy: 0.9663 - val_loss: 0.1053 - val_accuracy: 0.9685
Epoch 5/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0977 - accuracy: 0.9724 - val_loss: 0.0952 - val_accuracy: 0.9710
Epoch 6/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0822 - accuracy: 0.9763 - val_loss: 0.0864 - val_accuracy: 0.9741
Epoch 7/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0716 - accuracy: 0.9799 - val_loss: 0.0795 - val_accuracy: 0.9751
Epoch 8/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0608 - accuracy: 0.9824 - val_loss: 0.0719 - val_accuracy: 0.9776
Epoch 9/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0513 - accuracy: 0.9857 - val_loss: 0.0704 - val_accuracy: 0.9790
Epoch 10/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0493 - accuracy: 0.9858 - val_loss: 0.0677 - val_accuracy: 0.9793
<keras.callbacks.History at 0x7f8a5c2a5670>

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement fault tolerance with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

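This example demonstrates how to use a tf.train.Checkpoint object to manually create a checkpoint, with the trackable objects you want to save set as its attributes, and how to use a tf.train.CheckpointManager to manage multiple checkpoints.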

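First, define and instantiate the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable state (the model and the optimizer), and a CheckpointManager that records multiple checkpoints and saves them in a temporary directory.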

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)
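To see what max_to_keep=2 does in isolation, the self-contained sketch below saves a single tracked variable five times (the variable name `counter` is an illustrative assumption). The manager numbers checkpoints ckpt-1, ckpt-2, and so on, deletes the oldest ones beyond the limit, and exposes the newest through `latest_checkpoint`.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
counter = tf.Variable(0)
ckpt = tf.train.Checkpoint(counter=counter)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=2)

for i in range(5):
    counter.assign(i)
    manager.save()  # writes ckpt-1 ... ckpt-5, keeping only the newest two

print([os.path.basename(p) for p in manager.checkpoints])   # ['ckpt-4', 'ckpt-5']
print(os.path.basename(tf.train.latest_checkpoint(ckpt_dir)))  # ckpt-5
```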

Now, implement a custom training loop that, after the first epoch, restores the latest checkpoint at the start of every new epoch:

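for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer state from the most recent checkpoint.
    # (tf.train.load_checkpoint only returns a reader; it does not restore.)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      # Forward pass over the training data.
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    # Compute and apply gradients outside the tape context.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Save a checkpoint after every step.
    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")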
Start of epoch 0
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-1
Training loss at step 0: 2.4203763008117676
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-2
Training loss at step 1: 2.420546770095825
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-3
Training loss at step 2: 2.4176888465881348
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-4
Training loss at step 3: 2.4155921936035156
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-5
Training loss at step 4: 2.4153852462768555

Start of epoch 1
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-6
Training loss at step 0: 2.4146769046783447
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-7
Training loss at step 1: 2.4105751514434814
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-8
Training loss at step 2: 2.4090170860290527
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-9
Training loss at step 3: 2.407325029373169
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-10
Training loss at step 4: 2.406435489654541

Start of epoch 2
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-11
Training loss at step 0: 2.4057834148406982
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-12
Training loss at step 1: 2.4041085243225098
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-13
Training loss at step 2: 2.401327610015869
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-14
Training loss at step 3: 2.4010281562805176
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-15
Training loss at step 4: 2.398888111114502

Start of epoch 3
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-16
Training loss at step 0: 2.3979201316833496
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-17
Training loss at step 1: 2.396275043487549
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-18
Training loss at step 2: 2.3937087059020996
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-19
Training loss at step 3: 2.393911361694336
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-20
Training loss at step 4: 2.3919384479522705

Start of epoch 4
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-21
Training loss at step 0: 2.389833927154541
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-22
Training loss at step 1: 2.3890221118927
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-23
Training loss at step 2: 2.3855605125427246
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-24
Training loss at step 3: 2.3858296871185303
Checkpoint saved to /tmpfs/tmp/tmpevop81wu/ckpt-25
Training loss at step 4: 2.3846724033355713
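In a real failure the process restarts, so a common pattern is to restore once at start-up and keep the step counter inside the checkpoint. The sketch below assumes that pattern with a hypothetical `train` function that minimizes w**2 by plain gradient descent; to stay small it tracks only the step and the parameter, whereas a real loop would also track the optimizer as the example above does. A run that crashes at step 3 and is restarted ends with the same weight as an uninterrupted run.

```python
import tempfile

import tensorflow as tf


def train(total_steps, ckpt_dir, crash_at=None, learning_rate=0.1):
    """Hypothetical resumable loop: minimizes w**2 with plain gradient descent."""
    step = tf.Variable(0, dtype=tf.int64)
    w = tf.Variable(5.0)
    checkpoint = tf.train.Checkpoint(step=step, w=w)
    manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=2)

    # Restore once at start-up; on a fresh run there is nothing to restore.
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)

    while int(step.numpy()) < total_steps:
        if crash_at is not None and int(step.numpy()) == crash_at:
            raise RuntimeError("Interruption")  # simulated failure
        with tf.GradientTape() as tape:
            loss = tf.square(w)
        grad = tape.gradient(loss, w)
        w.assign_sub(learning_rate * grad)
        step.assign_add(1)
        manager.save()  # checkpoint after every step
    return int(step.numpy()), float(w.numpy())


ckpt_dir = tempfile.mkdtemp()
try:
    train(total_steps=6, ckpt_dir=ckpt_dir, crash_at=3)
except RuntimeError as e:
    print("First run:", e)

steps, w_resumed = train(total_steps=6, ckpt_dir=ckpt_dir)
_, w_uninterrupted = train(total_steps=6, ckpt_dir=tempfile.mkdtemp())
print(steps, abs(w_resumed - w_uninterrupted) < 1e-6)  # 6 True
```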

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:

You may also find the following material related to distributed training useful: