
Multi-worker training with Keras


Overview

This tutorial demonstrates how to perform multi-worker distributed training with a Keras model and the Model.fit API using the tf.distribute.Strategy API—specifically the tf.distribute.MultiWorkerMirroredStrategy class. With the help of this strategy, a Keras model that was designed to run on a single worker can seamlessly work on multiple workers with minimal code changes.

For those interested in a deeper understanding of tf.distribute.Strategy APIs, the Distributed training in TensorFlow guide is available for an overview of the distribution strategies TensorFlow supports.

To learn how to use the MultiWorkerMirroredStrategy with Keras and a custom training loop, refer to Custom training loop with Keras and MultiWorkerMirroredStrategy.

Note that the purpose of this tutorial is to demonstrate a minimal multi-worker example with two workers.

Setup

Start with some necessary imports:

import json
import os
import sys

Before importing TensorFlow, make a few changes to the environment:

  1. Disable all GPUs. This prevents errors caused by the workers all trying to use the same GPU. In a real-world application, each worker would be on a different machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  2. Reset the TF_CONFIG environment variable (you'll learn more about this later):
os.environ.pop('TF_CONFIG', None)
  3. Make sure that the current directory is on Python's path—this allows the notebook to import the files written by %%writefile later:
if '.' not in sys.path:
  sys.path.insert(0, '.')

Now import TensorFlow:

import tensorflow as tf
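
As a quick, optional sanity check (not part of the original notebook), you can confirm that no GPUs are visible to TensorFlow after disabling them above:

# With CUDA_VISIBLE_DEVICES set to "-1", TensorFlow should not see any GPUs.
print(tf.config.list_physical_devices('GPU'))  # expected: []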

Dataset and model definition

Next, create an mnist.py file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the [0, 255] range.
  # You need to convert them to float32 with values in the [0, 1] range.
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
Writing mnist.py

Model training on a single worker

Try training the model for a small number of epochs and observe the results of a single-worker run to make sure everything works correctly. As training progresses, the loss should drop and the accuracy should increase.

import mnist

batch_size = 64
single_worker_dataset = mnist.mnist_dataset(batch_size)
single_worker_model = mnist.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
2021-08-20 01:21:51.478839: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-08-20 01:21:51.478914: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:51.478928: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:51.479029: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.57.2
2021-08-20 01:21:51.479060: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.57.2
2021-08-20 01:21:51.479067: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 470.57.2
2021-08-20 01:21:51.480364: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch 1/3
 1/70 [..............................] - ETA: 26s - loss: 2.3067 - accuracy: 0.0469
2021-08-20 01:21:52.316481: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
70/70 [==============================] - 1s 12ms/step - loss: 2.2829 - accuracy: 0.1667
Epoch 2/3
70/70 [==============================] - 1s 12ms/step - loss: 2.2281 - accuracy: 0.3842
Epoch 3/3
70/70 [==============================] - 1s 12ms/step - loss: 2.1625 - accuracy: 0.5348
<keras.callbacks.History at 0x7f633d957390>

Multi-worker configuration

Now let's enter the world of multi-worker training.

A cluster with jobs and tasks

In TensorFlow, distributed training involves a 'cluster' with several jobs, and each of the jobs may have one or more 'task's.

You will need the TF_CONFIG configuration environment variable for training on multiple machines, each of which possibly has a different role. TF_CONFIG is a JSON string used to specify the cluster configuration for each worker that is part of the cluster.

There are two components of a TF_CONFIG variable: 'cluster' and 'task'.

  • A 'cluster' is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs, such as 'worker' or 'chief'.

    • In multi-worker training with tf.distribute.MultiWorkerMirroredStrategy, there is usually one 'worker' that takes on responsibilities, such as saving a checkpoint and writing a summary file for TensorBoard, in addition to what a regular 'worker' does. Such a 'worker' is referred to as the chief worker (with a job name 'chief').
    • It is customary for the worker with 'index' 0 to be appointed the 'chief' (in fact, this is how tf.distribute.Strategy is implemented).
  • A 'task' provides information about the current task and is different for each worker. It specifies the 'type' and 'index' of that worker.

Below is an example configuration:

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

Here is the same TF_CONFIG serialized as a JSON string:

json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'

Note that tf_config is just a local variable in Python. To be able to use it for a training configuration, this dict needs to be serialized as JSON and placed in a TF_CONFIG environment variable.

In the example configuration above, you set the task 'type' to 'worker' and the task 'index' to 0. Therefore, this machine is the first worker. It will be appointed as the 'chief' worker and do more work than the others.

For illustration purposes, this tutorial shows how you may set up a TF_CONFIG variable with two workers on a localhost.

In practice, you would create multiple workers on external IP addresses/ports and set a TF_CONFIG variable on each worker accordingly.

In this tutorial, you will use two workers:

  • The first ('chief') worker's TF_CONFIG is shown above.
  • For the second worker, you will set tf_config['task']['index'] = 1, as sketched below.
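
For illustration, here is a sketch of what the second worker's configuration would look like (using a hypothetical tf_config_worker_1 variable; only the 'task' index differs from the chief's):

tf_config_worker_1 = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 1}
}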

Environment variables and subprocesses in notebooks

Subprocesses inherit environment variables from their parent.

For example, you can set an environment variable in this Jupyter Notebook process as follows:

os.environ['GREETINGS'] = 'Hello TensorFlow!'

Then, you can access the environment variable from a subprocess:

%%bash
echo ${GREETINGS}
Hello TensorFlow!

In the next section, you'll use a similar method to pass the TF_CONFIG to the worker subprocesses. In a real-world scenario, you wouldn't launch your jobs this way, but it's sufficient in this example.
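
As a rough sketch of the same idea (not a step in this tutorial's workflow), a child process started with Python's subprocess module also sees variables set via os.environ in the parent:

import subprocess
import sys

# The child process inherits GREETINGS from this notebook's environment.
result = subprocess.run(
    [sys.executable, '-c', 'import os; print(os.environ.get("GREETINGS"))'],
    capture_output=True, text=True)
print(result.stdout)  # -> Hello TensorFlow!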

Choose the right strategy

In TensorFlow, there are two main forms of distributed training:

  • Synchronous training, where the steps of training are synced across the workers and replicas, and
  • Asynchronous training, where the training steps are not strictly synced (for example, parameter server training).

This tutorial demonstrates how to perform synchronous multi-worker training using an instance of tf.distribute.MultiWorkerMirroredStrategy.

MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

strategy = tf.distribute.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO

MultiWorkerMirroredStrategy provides multiple implementations via the CommunicationOptions parameter: 1) RING implements ring-based collectives using gRPC as the cross-host communication layer; 2) NCCL uses the NVIDIA Collective Communication Library to implement collectives; and 3) AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
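If you wanted to force a particular implementation rather than rely on AUTO, you could pass a CommunicationOptions instance to the constructor. The following is only a sketch (the strategy_with_ring name is illustrative); the default AUTO is used in the rest of this tutorial:

communication_options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.RING)
strategy_with_ring = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=communication_options)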

Train the model

With the integration of the tf.distribute.Strategy API into tf.keras, the only change you will make to distribute the training to multiple workers is enclosing the model building and model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where the variables are created, and in the case of MultiWorkerMirroredStrategy, the variables created are MirroredVariables, and they are replicated on each of the workers.

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist.build_and_compile_cnn_model()

To actually run with MultiWorkerMirroredStrategy, you'll need to run worker processes and pass a TF_CONFIG to them.

Like the mnist.py file written earlier, here is the main.py that each of the workers will run:

%%writefile main.py

import os
import json

import tensorflow as tf
import mnist

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist.build_and_compile_cnn_model()


multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
Writing main.py

In the code snippet above, note that the global_batch_size, which gets passed to Dataset.batch, is set to per_worker_batch_size * num_workers. This ensures that each worker processes batches of per_worker_batch_size examples regardless of the number of workers.
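
To make the arithmetic concrete, here is a quick illustration using this tutorial's two-worker setup:

# With 2 workers and a per-worker batch size of 64, the global batch is 128.
# MultiWorkerMirroredStrategy splits each global batch across the workers,
# so every worker still processes 64 examples per step.
per_worker_batch_size = 64
num_workers = 2
global_batch_size = per_worker_batch_size * num_workers
print(global_batch_size)  # 128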

The current directory now contains both Python files:

ls *.py
main.py
mnist.py

Next, serialize the TF_CONFIG as JSON and add it to the environment variables:

os.environ['TF_CONFIG'] = json.dumps(tf_config)

Now, you can launch a worker process that will run the main.py and use the TF_CONFIG:

# first kill any previous runs
%killbgscripts
All background processes were killed.
%%bash --bg
python main.py &> job_0.log

There are a few things to note about the above command:

  1. It uses %%bash, which is a notebook "magic", to run some bash commands.
  2. It uses the --bg flag to run the bash process in the background, because this worker process will not terminate: it waits for all the other workers to be ready before training starts.

The backgrounded worker process won't print output to this notebook, so &> redirects its output to a file, letting you inspect what happened in a log file later.

So, wait a few seconds for the process to start up:

import time
time.sleep(10)

Now, inspect what's been output to the worker's log file so far:

cat job_0.log
2021-08-20 01:21:57.459034: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-08-20 01:21:57.459133: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:57.459414: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:57.459531: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.57.2
2021-08-20 01:21:57.459575: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.57.2
2021-08-20 01:21:57.459586: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 470.57.2
2021-08-20 01:21:57.460413: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-20 01:21:57.466180: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2021-08-20 01:21:57.466667: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:12345

The last line of the log file should say: Started server with target: grpc://localhost:12345. The first worker is now ready, and is waiting for all the other worker(s) to be ready to proceed.

So update the tf_config for the second worker's process to pick up:

tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

Launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):

python main.py
Epoch 1/3
70/70 [==============================] - 6s 54ms/step - loss: 2.2796 - accuracy: 0.1292
Epoch 2/3
70/70 [==============================] - 4s 51ms/step - loss: 2.2285 - accuracy: 0.2898
Epoch 3/3
70/70 [==============================] - 4s 54ms/step - loss: 2.1706 - accuracy: 0.4835
2021-08-20 01:22:07.529925: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-08-20 01:22:07.529987: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:22:07.529996: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:22:07.530089: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.57.2
2021-08-20 01:22:07.530125: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.57.2
2021-08-20 01:22:07.530136: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 470.57.2
2021-08-20 01:22:07.530785: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-20 01:22:07.536395: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2021-08-20 01:22:07.536968: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:23456
2021-08-20 01:22:08.764867: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}

2021-08-20 01:22:08.983898: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2021-08-20 01:22:08.985655: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)

If you recheck the logs written by the first worker, you'll see that it participated in training the model:

cat job_0.log
2021-08-20 01:21:57.459034: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-08-20 01:21:57.459133: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:57.459414: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-2087993482
2021-08-20 01:21:57.459531: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.57.2
2021-08-20 01:21:57.459575: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.57.2
2021-08-20 01:21:57.459586: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 470.57.2
2021-08-20 01:21:57.460413: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-20 01:21:57.466180: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2021-08-20 01:21:57.466667: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:427] Started server with target: grpc://localhost:12345
2021-08-20 01:22:08.759563: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}

2021-08-20 01:22:08.976883: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2021-08-20 01:22:08.978435: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/3
70/70 [==============================] - 6s 54ms/step - loss: 2.2796 - accuracy: 0.1292
Epoch 2/3
70/70 [==============================] - 4s 52ms/step - loss: 2.2285 - accuracy: 0.2898
Epoch 3/3
70/70 [==============================] - 4s 54ms/step - loss: 2.1706 - accuracy: 0.4835

Unsurprisingly, this ran slower than the test run at the beginning of this tutorial.

Running multiple workers on a single machine only adds overhead.

The goal here was not to improve the training time, but only to give an example of multi-worker training.

# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.

Multi-worker training in depth

So far, you have learned how to perform a basic multi-worker setup.

During the rest of the tutorial, you will take a closer look at other factors that may be useful or important for real use cases.

Dataset sharding

In multi-worker training, dataset sharding is needed to ensure convergence and performance.

The example in the previous section relies on the default autosharding provided by the tf.distribute.Strategy API. You can control the sharding by setting the tf.data.experimental.AutoShardPolicy of the tf.data.experimental.DistributeOptions.

To learn more about auto-sharding, refer to the Distributed input guide.

Here is a quick example of how to turn auto-sharding off, so that each replica processes every example (not recommended):

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

global_batch_size = 64
multi_worker_dataset = mnist.mnist_dataset(batch_size=64)
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
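
Alternatively, if your data does not come from files, you can request DATA sharding explicitly instead of turning sharding off. This is a sketch under the same setup; with DATA sharding, every worker reads the whole dataset but keeps only its own shard of each batch:

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset_data_sharded = mnist.mnist_dataset(batch_size=64).with_options(options)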

Evaluation

If you pass validation_data into Model.fit, it will alternate between training and evaluation for each epoch. Evaluation using the validation_data is distributed across the same set of workers, and the evaluation results are aggregated and available to all workers.

Similar to training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set the validation_steps.

A repeated dataset is also recommended for evaluation.
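
Here is a minimal sketch of what this could look like, reusing the tutorial's mnist_dataset helper as the validation data purely for illustration:

# The validation dataset uses a global batch size and a fixed number of steps.
validation_dataset = mnist.mnist_dataset(global_batch_size)
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       validation_data=validation_dataset,
                       validation_steps=10)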

Alternatively, you can also create another task that periodically reads checkpoints and runs the evaluation. This is what Estimator does. But this is not a recommended way to perform evaluation and thus its details are omitted.

Performance

You now have a Keras model that is all set up to run on multiple workers with the MultiWorkerMirroredStrategy.

To tweak performance of multi-worker training, you can try the following:

  • tf.distribute.MultiWorkerMirroredStrategy provides multiple collective communication implementations:

    • RING implements ring-based collectives using gRPC as the cross-host communication layer.
    • NCCL uses the NVIDIA Collective Communication Library to implement collectives.
    • AUTO defers the choice to the runtime.

    The best choice of collective implementation depends upon the number of GPUs, the type of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify the communication_options parameter of MultiWorkerMirroredStrategy's constructor. For example:

    communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CollectiveCommunication.NCCL)
    
  • Cast the variables to tf.float if possible:

    • The official ResNet model includes an example of how this can be done.

Fault tolerance

In synchronous training, the cluster would fail if one of the workers fails and no failure-recovery mechanism exists.

Using Keras with tf.distribute.Strategy comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. You can do this by preserving the training state in the distributed file system of your choice, such that upon a restart of the instance that previously failed or preempted, the training state is recovered.

When a worker becomes unavailable, other workers will fail (possibly after a timeout). In such cases, the unavailable worker needs to be restarted, as well as other workers that have failed.

ModelCheckpoint callback

The ModelCheckpoint callback no longer provides fault tolerance functionality; use the BackupAndRestore callback instead.

The ModelCheckpoint callback can still be used to save checkpoints. But with this approach, if training is interrupted or finishes successfully, then in order to continue training from the checkpoint, the user is responsible for loading the model manually.

Optionally, users can choose to save and restore the model/weights outside of the ModelCheckpoint callback.
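
For example, here is a sketch of saving checkpoints with ModelCheckpoint and reloading the weights manually afterwards (the /tmp/model_ckpt path and variable names are hypothetical):

checkpoint_prefix = '/tmp/model_ckpt/weights.{epoch:02d}'  # hypothetical path
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True)
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70,
                       callbacks=[model_checkpoint_callback])

# After an interruption (or after training finishes), you are responsible for
# rebuilding the model and loading the latest weights manually:
restored_model = mnist.build_and_compile_cnn_model()
restored_model.load_weights(tf.train.latest_checkpoint('/tmp/model_ckpt'))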

Model saving and loading

To save your model using model.save or tf.saved_model.save, the saving destination needs to be different for each worker.

  • For non-chief workers, you will need to save the model to a temporary directory.
  • For the chief, you will need to save to the provided model directory.

The temporary directories on the worker need to be unique to prevent errors resulting from multiple workers trying to write to the same location.

The model saved in all the directories is identical, and typically only the model saved by the chief should be referenced for restoring or serving.

You should have some cleanup logic that deletes the temporary directories created by the workers once your training has completed.

The reason you need to save on the chief and the workers at the same time is that variables might be aggregated during checkpointing, which requires both the chief and the workers to participate in the all-reduce communication protocol. On the other hand, letting the chief and the workers save to the same model directory would result in errors due to contention.

With MultiWorkerMirroredStrategy, the program is run on every worker, and in order to know whether the current worker is the chief, it takes advantage of the cluster resolver object, which has task_type and task_id attributes:

  • task_type tells you what the current job is (e.g. 'worker').
  • task_id tells you the identifier of the worker.
  • The worker with task_id == 0 is designated as the chief worker.

In the code snippet below, the write_filepath function provides the file path to write, which depends on the worker's task_id:

  • For the chief worker (with task_id == 0), it writes to the original file path.
  • For other workers, it creates a temporary directory—temp_dir—with the task_id in the directory path to write in:
model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # Note: there are two possible `TF_CONFIG` configurations.
  #   1) In addition to `worker` tasks, a `chief` task type is used;
  #      in this case, this function should be modified to
  #      `return task_type == 'chief'`.
  #   2) Only `worker` task type is used; in this case, worker 0 is
  #      regarded as the chief. The implementation demonstrated here
  #      is for this case.
  # For the purpose of this Colab section, the `task_type is None` case
  # is added because it is effectively run with only a single worker.
  return (task_type == 'worker' and task_id == 0) or task_type is None

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)

With that, you're now ready to save:

multi_worker_model.save(write_model_path)
2021-08-20 01:22:24.305980: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras-model/assets
INFO:tensorflow:Assets written to: /tmp/keras-model/assets

As described above, later on, the model should only be loaded from the path the chief saved to, so let's remove the temporary ones the non-chief workers saved:

if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))

Now, when it's time to load, let's use the convenient tf.keras.models.load_model API and continue with further work.

Here, assume you are only using a single worker to load and continue training. In this case, you do not need to call tf.keras.models.load_model within another strategy.scope() (note that strategy = tf.distribute.MultiWorkerMirroredStrategy(), as defined earlier):

loaded_model = tf.keras.models.load_model(model_path)

# Now that the model is restored, you can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 1s 16ms/step - loss: 2.2960 - accuracy: 0.0000e+00
Epoch 2/2
20/20 [==============================] - 0s 15ms/step - loss: 2.2795 - accuracy: 0.0000e+00
<keras.callbacks.History at 0x7f633b103910>

Checkpoint saving and restoring

On the other hand, checkpointing allows you to save your model's weights and restore them without having to save the whole model.

Here, you'll create one tf.train.Checkpoint that tracks the model, which is managed by the tf.train.CheckpointManager, so that only the latest checkpoint is preserved:

checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

Once the CheckpointManager is set up, you're ready to save the checkpoint, and then remove the copies that the non-chief workers saved:

checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)

Now, when you need to restore the model, you can find the latest checkpoint saved using the convenient tf.train.latest_checkpoint function. After restoring the checkpoint, you can continue with training.

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
2021-08-20 01:22:26.176660: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}

2021-08-20 01:22:26.388321: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch 1/2
20/20 [==============================] - 3s 13ms/step - loss: 2.2948 - accuracy: 0.0000e+00
Epoch 2/2
20/20 [==============================] - 0s 13ms/step - loss: 2.2785 - accuracy: 0.0000e+00
<keras.callbacks.History at 0x7f635d404450>

BackupAndRestore callback

The tf.keras.callbacks.experimental.BackupAndRestore callback provides the fault tolerance functionality by backing up the model and the current epoch number in a temporary checkpoint file under the backup_dir argument of BackupAndRestore. This is done at the end of each epoch.

Once jobs get interrupted and restart, the callback restores the last checkpoint, and training continues from the beginning of the interrupted epoch. Any partial training already done in the unfinished epoch before interruption will be thrown away, so that it doesn't affect the final model state.

To use it, provide an instance of tf.keras.callbacks.experimental.BackupAndRestore at the Model.fit call.

With MultiWorkerMirroredStrategy, if a worker gets interrupted, the whole cluster pauses until the interrupted worker is restarted. Other workers will also restart, and the interrupted worker rejoins the cluster. Then, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then, the training continues.

The BackupAndRestore callback uses the CheckpointManager to save and restore the training state, which generates a file called checkpoint that tracks existing checkpoints together with the latest one. For this reason, backup_dir should not be re-used to store other checkpoints in order to avoid name collision.

Currently, the BackupAndRestore callback supports a single worker with no strategy, MirroredStrategy, and multi-worker training with MultiWorkerMirroredStrategy. Below are examples of both multi-worker training and single-worker training.

# Multi-worker training with MultiWorkerMirroredStrategy
# and the BackupAndRestore callback.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)
2021-08-20 01:22:29.530251: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}
Epoch 1/3
70/70 [==============================] - 3s 12ms/step - loss: 2.2759 - accuracy: 0.1625
Epoch 2/3
70/70 [==============================] - 1s 12ms/step - loss: 2.2146 - accuracy: 0.2761
Epoch 3/3
70/70 [==============================] - 1s 12ms/step - loss: 2.1456 - accuracy: 0.4344
<keras.callbacks.History at 0x7f635d2aac90>
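
The single-worker case looks much the same, except that no distribution strategy is needed. The following is a minimal sketch, reusing this tutorial's mnist helpers and a hypothetical separate backup directory:

# Single-worker training with the BackupAndRestore callback (no strategy).
single_worker_callbacks = [
    tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup_single')]
single_worker_model = mnist.build_and_compile_cnn_model()
single_worker_model.fit(mnist.mnist_dataset(64),
                        epochs=3,
                        steps_per_epoch=70,
                        callbacks=single_worker_callbacks)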

If you inspect the directory of the backup_dir you specified in BackupAndRestore, you may notice some temporarily generated checkpoint files. Those files are needed to recover the previously lost instances, and they will be removed by the library at the end of Model.fit once your training has exited successfully.

Additional resources

  1. The Distributed training in TensorFlow guide provides an overview of the available distribution strategies.
  2. The Custom training loop with Keras and MultiWorkerMirroredStrategy tutorial shows how to use the MultiWorkerMirroredStrategy with Keras and a custom training loop.
  3. Check out the official models, many of which can be configured to run multiple distribution strategies.
  4. The Better performance with tf.function guide provides information about other strategies and tools, such as the TensorFlow Profiler, which you can use to optimize the performance of your TensorFlow models.