Migrate the fault tolerance mechanism



Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program/machine failure during training.

This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by specifying checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:

- If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.BackupAndRestore callback to it.
- If you use a custom training loop (with tf.GradientTape), you can save checkpoints at arbitrary points using the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

Both of these methods will back up and restore the training states in checkpoint files.
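
Both approaches build on TensorFlow's checkpointing machinery: trackable objects are written to checkpoint files and read back after a failure. As a rough sketch of the underlying idea (the directory name here is arbitrary):

import tensorflow as tf

# Any trackable object (variables, Keras models, optimizers) can be attached to
# a tf.train.Checkpoint; a tf.train.CheckpointManager writes and rotates the
# checkpoint files on disk.
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(checkpoint, '/tmp/ckpt_sketch', max_to_keep=3)

step.assign_add(1)
save_path = manager.save()        # periodically save the training state
checkpoint.restore(save_path)     # after a failure, recover the saved state
print('Restored step:', int(step.numpy()))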

Setup

Install tf-nightly, because the save_freq argument in tf.keras.callbacks.BackupAndRestore, which enables saving checkpoints at a chosen step frequency, was introduced in TensorFlow 2.10:

pip install tf-nightly
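
To confirm that the installed build is new enough, a quick version check along these lines can help (a minimal sketch):

import tensorflow as tf

# `save_freq` in tf.keras.callbacks.BackupAndRestore requires TensorFlow 2.10+.
major, minor = (int(part) for part in tf.__version__.split('.')[:2])
assert (major, minor) >= (2, 10), 'save_freq requires TensorFlow 2.10 or later'
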
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
2022-07-19 01:25:29.201393: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-07-19 01:25:29.920946: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-07-19 01:25:29.925996: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-07-19 01:25:29.926012: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can have a tf.estimator.Estimator save checkpoints at every step by configuring tf.estimator.RunConfig.

In this example, start by writing a hook that artificially throws an error during the fifth checkpoint:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Next, configure tf.estimator.Estimator to save a checkpoint at every step and use the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpuslrce_s', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_9776/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmpfs/tmp/ipykernel_9776/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Begin training the model. An artificial exception will be raised by the hook you defined earlier.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
2022-07-19 01:25:35.400081: W tensorflow/core/common_runtime/forward_type_inference.cc:332] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 118.68688, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1064: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpuslrce_s', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpuslrce_s/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1173: get_checkpoint_mtimes (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 102.9736, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmpuslrce_s/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 93.64633.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7efe5c69a5e0>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.BackupAndRestore callback to add the fault tolerance functionality.
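
In its simplest form, the pattern is just a backup directory plus the callback passed to Model.fit. A minimal sketch, assuming a compiled Keras model named `model` and an arbitrary backup directory:

# BackupAndRestore saves temporary checkpoints to `backup_dir` during training
# and, if the same `fit` call is run again after an interruption, restores the
# model and resumes from the last saved state.
backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')
model.fit(x_train, y_train, epochs=10, callbacks=[backup_callback])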

To help demonstrate this, start by defining a Keras Callback class that artificially throws an error at the end of the fourth epoch (epoch index 3):

class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')

Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that will save the checkpoints in a temporary directory at epoch boundaries:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir = log_dir)

Start training the model with Model.fit. During training, checkpoints will be saved by the tf.keras.callbacks.BackupAndRestore callback instantiated above, while the InterruptAtEpoch callback will raise an artificial exception to simulate a failure at the end of the fourth epoch.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 1s 10ms/step - loss: 0.4583 - accuracy: 0.8716 - val_loss: 0.2223 - val_accuracy: 0.9374
Epoch 2/10
100/100 [==============================] - 1s 7ms/step - loss: 0.2019 - accuracy: 0.9435 - val_loss: 0.1553 - val_accuracy: 0.9537
Epoch 3/10
100/100 [==============================] - 1s 7ms/step - loss: 0.1470 - accuracy: 0.9580 - val_loss: 0.1223 - val_accuracy: 0.9633
Epoch 4/10
 96/100 [===========================>..] - ETA: 0s - loss: 0.1158 - accuracy: 0.9672RuntimeError:Interruption

Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback])
Epoch 5/10
100/100 [==============================] - 1s 14ms/step - loss: 0.0960 - accuracy: 0.9725 - val_loss: 0.0889 - val_accuracy: 0.9742
Epoch 6/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0815 - accuracy: 0.9764 - val_loss: 0.0812 - val_accuracy: 0.9754
Epoch 7/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0683 - accuracy: 0.9801 - val_loss: 0.0753 - val_accuracy: 0.9774
Epoch 8/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0595 - accuracy: 0.9830 - val_loss: 0.0700 - val_accuracy: 0.9790
Epoch 9/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0515 - accuracy: 0.9851 - val_loss: 0.0658 - val_accuracy: 0.9807
Epoch 10/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0453 - accuracy: 0.9872 - val_loss: 0.0652 - val_accuracy: 0.9804
<keras.callbacks.History at 0x7efe5c1c93d0>

Define another Callback class that artificially throws an error during the 140th step:

class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')

To make sure the checkpoints are saved every 30 steps, set save_freq in the BackupAndRestore callback to 30. The InterruptAtStep callback will raise an artificial exception to simulate a failure at step 40 of the second epoch (total step count 140). The checkpoint was last saved at step 20 of the second epoch (total step count 120).
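
As a quick sketch of the arithmetic behind those numbers:

# With save_freq=30, checkpoints land at total steps 30, 60, 90, 120, ...;
# the last save before the interruption at total step 140 is at total step 120,
# which is step 20 of the second epoch (100 steps per epoch).
save_freq, steps_per_epoch, interrupting_step = 30, 100, 140
last_saved_total_step = (interrupting_step // save_freq) * save_freq          # 120
epoch_number = last_saved_total_step // steps_per_epoch + 1                   # 2
step_in_epoch = last_saved_total_step - (epoch_number - 1) * steps_per_epoch  # 20
print(epoch_number, step_in_epoch)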

log_dir_2 = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir = log_dir_2, save_freq=30
)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 1s 10ms/step - loss: 0.4614 - accuracy: 0.8707 - val_loss: 0.2200 - val_accuracy: 0.9372
Epoch 2/10
 20/100 [=====>........................] - ETA: 0s - loss: 0.2394 - accuracy: 0.9320
Interrupting at step count 140
RuntimeError:Interruption

Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint. Notice that the training starts from epoch 2 and step 21.

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback])
Epoch 2/10
100/100 [==============================] - 1s 13ms/step - loss: 0.1908 - accuracy: 0.9457 - val_loss: 0.1537 - val_accuracy: 0.9559
Epoch 3/10
100/100 [==============================] - 0s 5ms/step - loss: 0.1457 - accuracy: 0.9587 - val_loss: 0.1234 - val_accuracy: 0.9642
Epoch 4/10
100/100 [==============================] - 0s 4ms/step - loss: 0.1150 - accuracy: 0.9675 - val_loss: 0.1044 - val_accuracy: 0.9697
Epoch 5/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0931 - accuracy: 0.9740 - val_loss: 0.0902 - val_accuracy: 0.9729
Epoch 6/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0804 - accuracy: 0.9766 - val_loss: 0.0817 - val_accuracy: 0.9767
Epoch 7/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0677 - accuracy: 0.9807 - val_loss: 0.0761 - val_accuracy: 0.9773
Epoch 8/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0592 - accuracy: 0.9826 - val_loss: 0.0708 - val_accuracy: 0.9779
Epoch 9/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0508 - accuracy: 0.9854 - val_loss: 0.0687 - val_accuracy: 0.9787
Epoch 10/10
100/100 [==============================] - 0s 4ms/step - loss: 0.0442 - accuracy: 0.9877 - val_loss: 0.0649 - val_accuracy: 0.9796
<keras.callbacks.History at 0x7efe4c688a60>

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to:

- Use a tf.train.Checkpoint object to manually save a checkpoint, where the trackable objects you want to save are set as attributes.
- Use a tf.train.CheckpointManager to manage multiple checkpoints.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager for logging and keeping several checkpoints in a temporary directory.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)

Now, implement a custom training loop in which, at the start of every epoch after the first one, the last saved checkpoint is restored:

for epoch in range(epochs):
  # At the start of every epoch after the first, restore the latest checkpoint.
  if epoch > 0:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    # Record the forward pass so gradients can be computed.
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Save a checkpoint after every training step.
    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-1
Training loss at step 0: 2.3609695434570312
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-2
Training loss at step 1: 2.358873128890991
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-3
Training loss at step 2: 2.3592422008514404
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-4
Training loss at step 3: 2.3570125102996826
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-5
Training loss at step 4: 2.3560245037078857

Start of epoch 1
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-6
Training loss at step 0: 2.35449481010437
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-7
Training loss at step 1: 2.3526718616485596
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-8
Training loss at step 2: 2.3518872261047363
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-9
Training loss at step 3: 2.349724292755127
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-10
Training loss at step 4: 2.3485031127929688

Start of epoch 2
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-11
Training loss at step 0: 2.3468377590179443
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-12
Training loss at step 1: 2.3468360900878906
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-13
Training loss at step 2: 2.3430659770965576
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-14
Training loss at step 3: 2.3434078693389893
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-15
Training loss at step 4: 2.3411684036254883

Start of epoch 3
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-16
Training loss at step 0: 2.3389217853546143
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-17
Training loss at step 1: 2.338501453399658
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-18
Training loss at step 2: 2.338675022125244
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-19
Training loss at step 3: 2.335449695587158
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-20
Training loss at step 4: 2.335257053375244

Start of epoch 4
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-21
Training loss at step 0: 2.332885503768921
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-22
Training loss at step 1: 2.3322277069091797
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-23
Training loss at step 2: 2.3321328163146973
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-24
Training loss at step 3: 2.3299806118011475
Checkpoint saved to /tmpfs/tmp/tmprgkli2sv/ckpt-25
Training loss at step 4: 2.3271031379699707
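
After a real interruption, you would re-create the same trackable objects, attach them to a tf.train.Checkpoint, and restore the latest checkpoint before continuing the loop. A minimal sketch, reusing the create_model function and checkpoint_manager defined above:

# Re-create the trackable objects with the same structure, then restore the
# most recent checkpoint written by the CheckpointManager above.
restored_model = create_model()
restored_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
restored_checkpoint = tf.train.Checkpoint(model=restored_model,
                                          optimizer=restored_optimizer)
restored_checkpoint.restore(checkpoint_manager.latest_checkpoint)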

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:

- The tf.keras.callbacks.BackupAndRestore callback API docs.
- The tf.train.Checkpoint and tf.train.CheckpointManager API docs.
- The Training checkpoints guide.

You may also find the following material related to distributed training useful:

- The Fault tolerance section of the Multi-worker training with Keras tutorial.
- The Handling task failure section of the Parameter server training tutorial.