
Migrate the fault tolerance mechanism


Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program/machine failure during training.

This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by specifying checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways: if you use the Keras Model.fit API, you can pass it the tf.keras.callbacks.experimental.BackupAndRestore callback; and if you use a custom training loop, you can save checkpoints manually with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

Both of these methods will back up and restore the training states in checkpoint files.
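
To illustrate the underlying idea before diving into the migration, here is a minimal, self-contained sketch of saving and restoring a trackable object with tf.train.Checkpoint (the variable names and the checkpoint path are illustrative only):

import tensorflow as tf

step = tf.Variable(0, name='step')          # Any trackable object: variables, models, optimizers.
ckpt = tf.train.Checkpoint(step=step)

step.assign_add(1)
save_path = ckpt.save('/tmp/ft_demo/ckpt')  # Periodically persist the state.

step.assign(999)                            # Simulate losing the in-memory state.
ckpt.restore(save_path)                     # Recover the last saved state after a failure.
print(step.numpy())                         # 1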

Setup

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can configure tf.estimator.Estimator to save a checkpoint at every step by passing a suitably configured tf.estimator.RunConfig.

In this example, start by writing a hook that artificially raises an error a few steps into training:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Next, configure the tf.estimator.Estimator to save a checkpoint at every step and to use the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpk5_4cfvv', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_13774/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_13774/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Begin training the model. An artificial exception will be raised by the hook you defined earlier.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:907: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 117.26719, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpk5_4cfvv/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:971: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption
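
Optionally, before resuming, you can check which checkpoint the estimator will restore from by inspecting the model directory (a small aside, not required for the migration):

# Optional: print the most recent checkpoint prefix in the model directory.
print(tf.train.latest_checkpoint(path))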

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpk5_4cfvv', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpk5_4cfvv/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1078: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 103.11247, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 86.68358.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f9e820cff50>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.experimental.BackupAndRestore callback to add the fault tolerance functionality.

To help demonstrate this, start by defining a callback class that artificially raises an error at the end of the fifth epoch:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, logs=None):
    if epoch == 4:  # Epochs are zero-indexed, so this is the end of the fifth epoch.
      raise RuntimeError('Interruption')

Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.experimental.BackupAndRestore callback that will save the checkpoints in a temporary directory:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir = log_dir
)

Now, start training the model with Model.fit. During training, checkpoints will be saved thanks to the backup_restore_callback defined above, while the InterruptingCallback will raise an artificial exception to simulate a failure.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2167 - accuracy: 0.9352 - val_loss: 0.0944 - val_accuracy: 0.9725
Epoch 2/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0965 - accuracy: 0.9703 - val_loss: 0.0823 - val_accuracy: 0.9735
Epoch 3/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0684 - accuracy: 0.9780 - val_loss: 0.0727 - val_accuracy: 0.9756
Epoch 4/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0542 - accuracy: 0.9829 - val_loss: 0.0676 - val_accuracy: 0.9790
Epoch 5/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0442 - accuracy: 0.9855 - val_loss: 0.0634 - val_accuracy: 0.9807
RuntimeError:Interruption
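
Optionally, you can confirm that the callback wrote a backup checkpoint before the interruption by listing the backup directory (the exact file layout is an implementation detail of the callback):

# Optional: list whatever BackupAndRestore wrote to its backup directory.
print(tf.io.gfile.listdir(log_dir))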

Next, instantiate a fresh Keras model, call Model.compile, and continue training with Model.fit. The BackupAndRestore callback restores the training state from the backup checkpoint, so training resumes where it left off:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0367 - accuracy: 0.9876 - val_loss: 0.0725 - val_accuracy: 0.9794
Epoch 7/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0313 - accuracy: 0.9894 - val_loss: 0.0787 - val_accuracy: 0.9779
Epoch 8/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0288 - accuracy: 0.9905 - val_loss: 0.0820 - val_accuracy: 0.9782
Epoch 9/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0252 - accuracy: 0.9914 - val_loss: 0.0639 - val_accuracy: 0.9830
Epoch 10/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0235 - accuracy: 0.9924 - val_loss: 0.0788 - val_accuracy: 0.9803
<keras.callbacks.History at 0x7f9e82536990>

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to manually save checkpoints of trackable objects with a tf.train.Checkpoint and manage multiple checkpoints with a tf.train.CheckpointManager.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that tracks two objects with trackable states (the model and the optimizer), as well as a CheckpointManager that saves several checkpoints in a temporary directory and keeps only the most recent ones.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)
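
In a real fault-tolerance setup, a restarted program would typically also try to restore the latest checkpoint once at startup, before training begins. A minimal sketch of that restore-or-initialize pattern, using the checkpoint and checkpoint_manager defined above:

# Restore the latest checkpoint if one exists; otherwise start from scratch.
if checkpoint_manager.latest_checkpoint:
  checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"Restored from {checkpoint_manager.latest_checkpoint}")
else:
  print("Initializing from scratch")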

Now, implement a custom training loop that, at the start of every epoch after the first, restores the state from the latest checkpoint:

for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer states from the latest checkpoint.
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    # Compute and apply gradients outside the tape context.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-1
Training loss at step 0: 2.4602103233337402
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-2
Training loss at step 1: 2.4579155445098877
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-3
Training loss at step 2: 2.4571962356567383
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-4
Training loss at step 3: 2.456108570098877
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-5
Training loss at step 4: 2.4541022777557373

Start of epoch 1
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-6
Training loss at step 0: 2.4518723487854004
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-7
Training loss at step 1: 2.451997995376587
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-8
Training loss at step 2: 2.450746774673462
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-9
Training loss at step 3: 2.4489808082580566
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-10
Training loss at step 4: 2.4467883110046387

Start of epoch 2
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-11
Training loss at step 0: 2.445439100265503
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-12
Training loss at step 1: 2.442873477935791
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-13
Training loss at step 2: 2.443373680114746
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-14
Training loss at step 3: 2.4398140907287598
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-15
Training loss at step 4: 2.4389309883117676

Start of epoch 3
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-16
Training loss at step 0: 2.437243938446045
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-17
Training loss at step 1: 2.4370715618133545
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-18
Training loss at step 2: 2.435986042022705
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-19
Training loss at step 3: 2.4329538345336914
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-20
Training loss at step 4: 2.431180953979492

Start of epoch 4
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-21
Training loss at step 0: 2.4317142963409424
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-22
Training loss at step 1: 2.43074631690979
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-23
Training loss at step 2: 2.428147077560425
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-24
Training loss at step 3: 2.4258265495300293
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-25
Training loss at step 4: 2.4255685806274414
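
Note that the loop above restores only the model and optimizer states. If a restarted program should also know which epoch to resume from, you can track a counter as part of the checkpoint. A minimal sketch, reusing the model, optimizer, and epochs defined above (the epoch_counter variable and the fresh checkpoint directory are introduced here for illustration):

# Track an epoch counter alongside the model and optimizer so a restarted
# program can resume from the right epoch.
epoch_counter = tf.Variable(0, trainable=False, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=epoch_counter)
checkpoint_dir = tempfile.mkdtemp()  # Illustrative; use a persistent path in practice.
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)

if checkpoint_manager.latest_checkpoint:
  checkpoint.restore(checkpoint_manager.latest_checkpoint)

for epoch in range(int(epoch_counter.numpy()), epochs):
  # ... run the training steps and save checkpoints as in the loop above ...
  epoch_counter.assign(epoch + 1)
  checkpoint_manager.save()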

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, check out the API documentation for tf.keras.callbacks.experimental.BackupAndRestore, tf.train.Checkpoint, and tf.train.CheckpointManager, as well as the Training checkpoints guide.

You may also find material related to distributed training useful, such as the Multi-worker training with Keras tutorial, which demonstrates fault tolerance in a multi-worker setup.