A multi-worker tf.distribute strategy with parameter servers.

Inherits From: Strategy


Parameter server training is a common data-parallel method to scale up a machine learning model on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers, and workers read and update them in each step. By default, workers read and update these variables independently, without synchronizing with each other; this configuration is known as asynchronous training.
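To make "asynchronous" concrete, the toy sketch below runs the read/update cycle in a single process. It is not TensorFlow code: the `ParameterServer` class and `worker_loop` function are hypothetical, and threads stand in for worker machines. Each worker reads the shared weight, computes the gradient of `(w - target) ** 2`, and pushes an update without waiting for the other workers, so some updates are computed from stale values.

```python
import threading


class ParameterServer:
    """Hypothetical in-process stand-in for a parameter server holding one scalar."""

    def __init__(self, initial_value: float) -> None:
        self._value = initial_value
        # The lock protects a single read or a single update, not a whole step.
        self._lock = threading.Lock()

    def read(self) -> float:
        with self._lock:
            return self._value

    def apply_update(self, delta: float) -> None:
        with self._lock:
            self._value += delta


def worker_loop(ps: ParameterServer, target: float, lr: float, steps: int) -> None:
    """One hypothetical worker: read, compute a gradient, push an update, repeat."""
    for _ in range(steps):
        w = ps.read()                # May already be stale when the update lands.
        grad = 2.0 * (w - target)    # Gradient of the loss (w - target) ** 2.
        ps.apply_update(-lr * grad)  # No barrier with the other workers.


def train_async(num_workers: int = 4, steps: int = 200) -> float:
    ps = ParameterServer(initial_value=0.0)
    workers = [
        threading.Thread(target=worker_loop, args=(ps, 3.0, 0.05, steps))
        for _ in range(num_workers)
    ]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    return ps.read()


print(f"final weight: {train_async():.3f}")
```

The weight typically ends close to the target of 3.0; stale updates add noise but, with a small learning rate, do not stop convergence in this toy setting.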

In TensorFlow 2, we recommend an architecture based on central coordination for parameter server training. Each worker and parameter server runs a tf.distribute.Server, and on top of that, a coordinator task is responsible for creating resources on workers and parameter servers, dispatching functions, and coordinating the training. The coordinator uses a tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the cluster, and a tf.distribute.experimental.ParameterServerStrategy to define variables on parameter servers and computation on workers.

For training to work, the coordinator dispatches tf.functions to be executed on remote workers. Upon receiving a request from the coordinator, a worker executes the tf.function by reading the variables from parameter servers, executing the ops, and updating the variables on the parameter servers. Each worker only processes requests from the coordinator and communicates with parameter servers, without interacting directly with other workers in the cluster.

As a result, failures of some workers do not prevent the cluster from continuing its work, which allows the cluster to train with instances that can be occasionally unavailable (e.g. preemptible or spot instances). The coordinator and parameter servers, however, must be available at all times for the cluster to make progress.
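One way to picture why worker failures are tolerable is a coordinator that hands out self-contained functions and simply re-queues any function whose worker disappeared. The sketch below is a toy model of that idea, not ClusterCoordinator's implementation: `ToyWorker`, `WorkerUnavailableError`, and `dispatch_with_retries` are hypothetical names, and failures are simulated with a seeded random number generator.

```python
import random
from collections import deque


class WorkerUnavailableError(Exception):
    """Hypothetical error signalling that a simulated preemptible worker went away."""


class ToyWorker:
    """Hypothetical worker: runs whatever the coordinator sends and may fail at random."""

    def __init__(self, name, failure_rate, rng):
        self.name = name
        self.failure_rate = failure_rate
        self.rng = rng

    def execute(self, fn, args):
        # Workers never talk to each other; they only run functions they are handed.
        if self.rng.random() < self.failure_rate:
            raise WorkerUnavailableError(self.name)
        return fn(*args)


def dispatch_with_retries(tasks, workers, max_attempts=1000):
    """Round-robin tasks over workers, re-queueing any task whose worker failed."""
    queue = deque(tasks)
    results = []
    attempts = 0
    while queue:
        if attempts >= max_attempts:
            raise RuntimeError("gave up after too many failed attempts")
        fn, args = queue.popleft()
        worker = workers[attempts % len(workers)]
        attempts += 1
        try:
            results.append(worker.execute(fn, args))
        except WorkerUnavailableError:
            # The task is not lost: it goes back in the queue for another worker.
            queue.append((fn, args))
    return results


rng = random.Random(0)
workers = [
    ToyWorker("worker-0", failure_rate=0.5, rng=rng),  # flaky, like a spot instance
    ToyWorker("worker-1", failure_rate=1.0, rng=rng),  # permanently preempted
    ToyWorker("worker-2", failure_rate=0.0, rng=rng),  # always available
]
tasks = [(lambda x: x * x, (i,)) for i in range(10)]
print(sorted(dispatch_with_retries(tasks, workers)))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Results arrive out of order because retried tasks rejoin the back of the queue, but every task eventually completes as long as at least one worker stays up. Note that the toy coordinator itself is a single point of failure, matching the requirement above that the coordinator stay available.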

Note that the coordinator is not one of the training workers. Instead, it creates resources such as variables and datasets, dispatches tf.functions, saves checkpoints, and so on. In addition to workers, parameter servers, and the coordinator, an optional evaluator can run on the side; it periodically reads the checkpoints saved by the coordinator and runs evaluations against each checkpoint.

ParameterServerStrategy is supported with two training APIs: Custom Training Loop (CTL) and the Keras Training API, also known as Model.fit. CTL is recommended when users prefer to define the details of their training loop, and Model.fit is recommended when users prefer a high-level abstraction and handling of training.

When using a CTL, ParameterServerStrategy has to work in conjunction with a tf.distribute.experimental.coordinator.ClusterCoordinator object.

When using Model.fit, currently only the tf.keras.utils.experimental.DatasetCreator input type is supported.

Example code for coordinator

This section provides code snippets that are intended to be run on the one (and only) task designated as the coordinator. Note that the cluster_resolver, variable_partitioner, and dataset_fn arguments are explained in the following "Cluster setup", "Variable partitioning", and "Dataset preparation" sections.

With a CTL,

import logging

import tensorflow as tf

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Prepare a distribute dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

with strategy.scope():
  model = ...
  optimizer, metrics = ...  # Keras optimizer/metrics are great choices
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
  initial_epoch = load_checkpoint(checkpoint_manager) or 0

# Functions passed to `coordinator.schedule` must be `tf.function`s.
@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # calculate gradient, applying gradient, metrics update etc.
    pass

  strategy.run(replica_fn, args=(next(iterator),))

for epoch in range(initial_epoch, num_epoch):
  distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
  for step in range(steps_per_epoch):

    # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
    # worker. This call returns immediately.
    coordinator.schedule(worker_fn, args=(distributed_iterator,))

  # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
  # returns, we can read the metrics and save checkpoints as needed.
  coordinator.join()
  logging.info('Metric result: %r', metrics.result())
  checkpoint_manager.save()

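The loop above relies on two properties: schedule returns immediately, and join blocks until everything scheduled so far has finished. The sketch below mimics only that contract with Python's concurrent.futures so it can run without a cluster. `ToyCoordinator` is a hypothetical class, its Futures stand in loosely for the values schedule returns, and the shared counter stands in for a metric.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class ToyCoordinator:
    """Hypothetical stand-in that mimics the schedule/join contract only."""

    def __init__(self, num_workers):
        self._pool = ThreadPoolExecutor(max_workers=num_workers)
        self._pending = []

    def schedule(self, fn, args=()):
        # Returns a Future right away instead of waiting for `fn` to run.
        future = self._pool.submit(fn, *args)
        self._pending.append(future)
        return future

    def join(self):
        # Blocks until every function scheduled so far has finished;
        # `result()` re-raises the exception if a scheduled function failed.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def shutdown(self):
        self._pool.shutdown(wait=True)


steps_done = 0
steps_lock = threading.Lock()


def worker_fn(step):
    """Pretend training step that bumps a shared counter (a stand-in metric)."""
    global steps_done
    time.sleep(0.001)
    with steps_lock:
        steps_done += 1
    return step


coordinator = ToyCoordinator(num_workers=4)
for epoch in range(2):
    for step in range(10):
        coordinator.schedule(worker_fn, args=(step,))  # Returns immediately.
    coordinator.join()  # Safe to read the counter once this returns.
    print(f"epoch {epoch}: steps_done={steps_done}")
    # prints "epoch 0: steps_done=10", then "epoch 1: steps_done=20"
coordinator.shutdown()
```

Reading the counter before join returns could observe a partially finished epoch, which is why the CTL example reads metrics and saves checkpoints only after coordinator.join().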
With Model.fit,

import tensorflow as tf

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)

# A dataset function takes an `input_context` and returns a `Dataset`.
def dataset_fn(input_context):
  dataset = ...  # Build a `tf.data.Dataset` from your input source here.
  return dataset.repeat().shard(...).batch(...).prefetch(...)

# With `Model.fit`, a `DatasetCreator` needs to be used.
dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn=dataset_fn)

with strategy.scope():
  model = ...  # Make sure the `Model` is created within scope.
# Pass any other `compile` arguments you need as keywords.
model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=...)

# Optional callbacks to checkpoint the model, back up the progress, etc.
callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]

# `steps_per_epoch` is required with `ParameterServerStrategy`.
model.fit(dataset_creator, epochs=..., steps_per_epoch=..., callbacks=callbacks)

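The dataset_fn above receives an input context so that each worker can read a different slice of the data. The pure-Python sketch below illustrates the shard-then-batch logic under stated assumptions: `ToyInputContext` is a hypothetical mirror of three tf.distribute.InputContext attributes (num_input_pipelines, input_pipeline_id, num_replicas_in_sync) plus its get_per_replica_batch_size method, and `shard` follows the tf.data.Dataset.shard rule of keeping elements whose index modulo the shard count equals the shard index. Shuffling, repeat, and prefetch are omitted.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyInputContext:
    """Hypothetical mirror of a few tf.distribute.InputContext fields."""
    num_input_pipelines: int
    input_pipeline_id: int
    num_replicas_in_sync: int

    def get_per_replica_batch_size(self, global_batch_size: int) -> int:
        if global_batch_size % self.num_replicas_in_sync:
            raise ValueError("global batch size must divide evenly across replicas")
        return global_batch_size // self.num_replicas_in_sync


def shard(elements: Sequence[int], num_shards: int, index: int) -> List[int]:
    # Keep every num_shards-th element, starting at position `index`.
    return [e for i, e in enumerate(elements) if i % num_shards == index]


def batch(elements: Sequence[int], batch_size: int) -> List[List[int]]:
    return [list(elements[i:i + batch_size]) for i in range(0, len(elements), batch_size)]


def toy_dataset_fn(ctx: ToyInputContext, data: Sequence[int], global_batch_size: int):
    per_replica = ctx.get_per_replica_batch_size(global_batch_size)
    mine = shard(data, ctx.num_input_pipelines, ctx.input_pipeline_id)
    return batch(mine, per_replica)


for pipeline_id in range(2):
    ctx = ToyInputContext(num_input_pipelines=2, input_pipeline_id=pipeline_id,
                          num_replicas_in_sync=2)
    print(pipeline_id, toy_dataset_fn(ctx, list(range(8)), global_batch_size=4))
# prints "0 [[0, 2], [4, 6]]" then "1 [[1, 3], [5, 7]]"
```

Sharding before batching means the two pipelines see disjoint elements, so no example is processed twice per pass over the data.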
Example code for worker and parameter servers

In addition to the coordinator, there should be tasks designated as "worker" or "ps". They should run the following code to start a TensorFlow server and wait for the coordinator's requests:

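import tensorflow as tf

# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See the "Cluster setup" section below.
cluster_resolver = ...

server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or "grpc",
    start=True)

# Block the process that started the server from exiting.
server.join()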

Cluster setup

In order for the tasks in the cluster to know each other's addresses, a tf.distribute.cluster_resolver.ClusterResolver must be used in the coordinator, worker, and ps tasks. The ClusterResolver is responsible for providing the cluster information, as well as the task type and id of the current task. See tf.distribute.cluster_resolver.ClusterResolver for more information.

If the TF_CONFIG environment variable is set, use a tf.distribute.cluster_resolver.TFConfigClusterResolver, which reads the cluster information from that variable.
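TF_CONFIG is a JSON string with a "cluster" section, which is identical on every task, and a "task" section, which identifies the current task. The sketch below builds and reads such a string with only the standard library. The helpers `make_tf_config` and `parse_tf_config` are hypothetical and are not TensorFlow's parser, and the host:port addresses are placeholders. It uses the "chief", "ps", and "worker" task types that ParameterServerStrategy expects.

```python
import json
import os


def make_tf_config(cluster, task_type, task_index):
    """Serialize a cluster layout plus this task's identity as a TF_CONFIG string."""
    return json.dumps({
        "cluster": cluster,
        "task": {"type": task_type, "index": task_index},
    })


def parse_tf_config(raw):
    """Extract what a cluster resolver reports: cluster, task type, task id."""
    config = json.loads(raw)
    task = config.get("task", {})
    return config.get("cluster", {}), task.get("type"), task.get("index")


# Hypothetical addresses; every task in the job shares this "cluster" section.
cluster = {
    "chief": ["chief0.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    "ps": ["ps0.example.com:2222"],
}

# On the machine running the second worker, TF_CONFIG would normally be set
# before the process starts; it is set in-process here only for illustration.
os.environ["TF_CONFIG"] = make_tf_config(cluster, "worker", 1)

spec, task_type, task_id = parse_tf_config(os.environ["TF_CONFIG"])
print(task_type, task_id, len(spec["worker"]))  # worker 1 2
```

With TensorFlow installed, constructing tf.distribute.cluster_resolver.TFConfigClusterResolver() in that process would read the same environment variable, and its task_type and task_id would correspond to the "task" section.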

Since there are assumptions in tf.distribute.experimental.ParameterServerStrategy around the naming of the task types, "chief", "ps", and "worker" should be used in the