Parameter Server Training



Parameter server training is a common data-parallel method to scale up model training on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers and they are read and updated by workers in each step. By default, workers read and update these variables independently without synchronizing with each other. This is why sometimes parameter server-style training is called asynchronous training.

TensorFlow 2 parameter server training uses a central coordinator through the tf.distribute.experimental.coordinator.ClusterCoordinator class.

In this implementation the worker and parameter server tasks run tf.distribute.Servers that listen for requests from the coordinator. The coordinator creates resources, dispatches training tasks, writes checkpoints, and deals with task failures.

We believe this architecture and the new ClusterCoordinator class provide a more flexible and simpler programming model.


The ClusterCoordinator class needs to work in conjunction with a tf.distribute.Strategy object. The strategy object passes the cluster information and is used to define a training step, as we have seen in custom training with MirroredStrategy. The ClusterCoordinator object then dispatches the execution of these training steps to remote workers. Currently, the ClusterCoordinator only works with tf.distribute.experimental.ParameterServerStrategy.

The most important API provided by the ClusterCoordinator object is schedule. The schedule method enqueues a tf.function and immediately returns a future-like RemoteValue. The queued functions are dispatched to remote workers in background threads, and their RemoteValues are filled asynchronously. Since schedule doesn't require worker assignment, the tf.function passed in can be executed on any available worker. If the worker it is executed on becomes unavailable before the function completes, the function is retried on another available worker. Because of this, and because function execution is not atomic, a function may be executed more than once.
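The semantics just described (enqueue, return a handle immediately, retry on another worker, possibly run a function twice) can be illustrated with a deliberately simplified, single-threaded pure-Python model. ToyCoordinator, ToyRemoteValue and WorkerUnavailable are hypothetical names, not TensorFlow APIs; the real ClusterCoordinator dispatches from background threads and detects failures over gRPC.

```python
import itertools


class WorkerUnavailable(Exception):
    """Hypothetical stand-in for a worker that drops out mid-function."""


class ToyRemoteValue:
    """Future-like handle that is filled in after the function runs."""

    def __init__(self):
        self._value = None
        self._done = False

    def _set(self, value):
        self._value, self._done = value, True

    def fetch(self):
        if not self._done:
            raise RuntimeError("value not ready yet; call join() first")
        return self._value


class ToyCoordinator:
    """Single-threaded model of schedule/join with retry on worker failure."""

    def __init__(self, workers):
        # Each worker is a callable taking the function to run; it may raise
        # WorkerUnavailable. Workers are tried in round-robin order.
        self._workers = itertools.cycle(workers)
        self._queue = []
        self.attempts = 0

    def schedule(self, fn):
        # Enqueue and return a handle immediately; nothing runs yet.
        rv = ToyRemoteValue()
        self._queue.append((fn, rv))
        return rv

    def join(self):
        # Drain the queue; a failed attempt is retried on the next worker.
        while self._queue:
            fn, rv = self._queue.pop(0)
            while True:
                worker = next(self._workers)
                self.attempts += 1
                try:
                    rv._set(worker(fn))
                    break
                except WorkerUnavailable:
                    continue


# Usage: a worker that runs the function partway and then disappears.
counter = {"runs": 0}


def step():
    counter["runs"] += 1  # side effect, not rolled back on failure
    return counter["runs"]


def healthy_worker(fn):
    return fn()


def flaky_worker(fn):
    fn()                       # the function body already executed...
    raise WorkerUnavailable()  # ...but the result never reaches the coordinator


coord = ToyCoordinator([flaky_worker, healthy_worker])
rv = coord.schedule(step)
coord.join()
print(rv.fetch(), counter["runs"], coord.attempts)  # 2 2 2
```

The step's side effect happened twice even though only one result was delivered, which is why scheduled functions should tolerate being executed more than once.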

In addition to dispatching remote functions, the ClusterCoordinator also helps to create datasets on all the workers and rebuild these datasets when a worker recovers from failure.

Tutorial Setup

# This tutorial requires TensorFlow 2.4
pip install -q --pre -U tensorflow

pip install -q portpicker
import json
import multiprocessing
import os
import random
import portpicker
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers.experimental.preprocessing as kpl

Cluster Setup

As mentioned above, a parameter server training cluster requires a coordinator task that runs your training program, one or more worker and parameter server tasks that run TensorFlow servers (tf.distribute.Server), and possibly an additional evaluation task that runs side-car evaluation. The requirements to set them up are:

  • The coordinator task needs to know the addresses and ports of all other TensorFlow servers except the evaluator.
  • The workers and parameter servers need to know which port to listen on. For the sake of simplicity, we usually pass in the complete cluster information when we create TensorFlow servers on these tasks.
  • The evaluator task doesn’t have to know the setup of the training cluster. If it does, it should not attempt to connect to the training cluster.
  • Workers and parameter servers should have the task types “worker” and “ps”, respectively. The coordinator should use “chief” as its task type for legacy reasons.

In this tutorial, we will create an in-process cluster so that the whole parameter server training can be run in Colab. We will introduce how to set up real clusters in a later section.

In-process cluster

In this tutorial, we will start a bunch of TensorFlow servers in advance and connect to them later:

def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  cluster_dict["worker"] = ["localhost:%s" % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]

  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for i in range(num_workers):
    tf.distribute.Server(
        cluster_spec, job_name="worker", task_index=i, config=worker_config,
        protocol="grpc")

  for i in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name="ps", task_index=i, protocol="grpc")

  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")
  return cluster_resolver

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

NUM_WORKERS = 3
NUM_PS = 2
cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)

Training with Custom Training Loop

Custom training loops with tf.distribute.Strategy provide great flexibility to define training loops. Currently, only custom training loops are supported for parameter server training in TensorFlow 2. Here we use ParameterServerStrategy to define a training step and then use ClusterCoordinator to dispatch the execution of training steps to remote workers.

Create the ParameterServerStrategy

To write a training step in custom training loop, the first step is to create a ParameterServerStrategy. We will explain the variable_partitioner later.

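variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
        num_shards=NUM_PS))

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=variable_partitioner)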
INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:CPU:0'], variable_device = '/job:localhost/replica:0/task:0/device:CPU:0'
INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'ps': ['localhost:16710', 'localhost:24338'], 'worker': ['localhost:22086', 'localhost:19715', 'localhost:24687']})
INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'ps': ['localhost:16710', 'localhost:24338'], 'worker': ['localhost:22086', 'localhost:19715', 'localhost:24687']})

Then you will create a model, define a dataset and a step function, as we have seen in the training loop with other tf.distribute.Strategy objects. You can find more details in this tutorial. Let’s create these components in the following steps:

Set up the data

First, write a function that creates a dataset that includes preprocessing logic implemented by Keras preprocessing layers. We will create these layers outside the dataset_fn but apply the transformation inside the dataset_fn, since you will wrap the dataset_fn into a tf.function, which doesn't allow variables to be created inside it.

feature_vocab = [
    "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
    "wonder_woman"
]
label_vocab = ["yes", "no"]

with strategy.scope():
  feature_lookup_layer = kpl.StringLookup(vocabulary=feature_vocab)

  label_lookup_layer = kpl.StringLookup(vocabulary=label_vocab,
                                        num_oov_indices=0,
                                        mask_token=None)

  raw_feature_input = keras.layers.Input(
      shape=(3,), dtype=tf.string, name="feature")
  feature_id_input = feature_lookup_layer(raw_feature_input)
  feature_preprocess_stage = keras.Model(
      {"features": raw_feature_input}, feature_id_input)

  raw_label_input = keras.layers.Input(
      shape=(1,), dtype=tf.string, name="label")
  label_id_input = label_lookup_layer(raw_label_input)
  label_preprocess_stage = keras.Model({"label": raw_label_input}, label_id_input)
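As a rough, pure-Python picture of what the two lookup layers produce, the sketch below assumes TensorFlow 2.4's StringLookup defaults: the mask token "" takes id 0, a single OOV token "[UNK]" takes id 1, and vocabulary terms start at id 2 (later TensorFlow releases changed some of these defaults). build_lookup is a hypothetical helper, not a Keras API, and the three-word vocabulary is only an example. It shows why the label layer passes num_oov_indices=0 and mask_token=None: “yes” and “no” then map directly to 0 and 1.

```python
def build_lookup(vocab, mask_token="", num_oov_indices=1, oov_token="[UNK]"):
    """Return (table, lookup) mimicking the id layout assumed above.

    Toy model: only 0 or 1 OOV index is supported, and the mask token is not
    treated specially on lookup.
    """
    if num_oov_indices not in (0, 1):
        raise ValueError("this toy model supports 0 or 1 OOV index only")
    table = []
    if mask_token is not None:
        table.append(mask_token)        # reserved id for padding
    oov_id = None
    if num_oov_indices == 1:
        oov_id = len(table)
        table.append(oov_token)         # bucket for unknown strings
    offset = len(table)
    term_ids = {term: offset + i for i, term in enumerate(vocab)}
    table.extend(vocab)                 # real terms follow, in order

    def lookup(term):
        if term in term_ids:
            return term_ids[term]
        if oov_id is None:
            raise KeyError("%r is out of vocabulary and there is no OOV id" % term)
        return oov_id

    return table, lookup


# Feature-style layer: defaults, i.e. a mask token plus one OOV id.
feature_table, feature_lookup = build_lookup(["avenger", "ironman", "batman"])
# Label-style layer: num_oov_indices=0 and mask_token=None, as in the code above.
label_table, label_lookup = build_lookup(["yes", "no"], mask_token=None,
                                         num_oov_indices=0)

print(feature_table)                                      # ['', '[UNK]', 'avenger', 'ironman', 'batman']
print(feature_lookup("avenger"), feature_lookup("thor"))  # 2 1
print(label_lookup("yes"), label_lookup("no"))            # 0 1
```

Under these defaults the feature table has len(vocab) + 2 entries, which is why the model sizes its Embedding layer with len(feature_lookup_layer.get_vocabulary()) rather than len(feature_vocab).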

Generate toy examples in a dataset:

def feature_and_label_gen(num_examples=200):
  examples = {"features": [], "label": []}
  for _ in range(num_examples):
    features = random.sample(feature_vocab, 3)
    label = ["yes"] if "avenger" in features else ["no"]
    examples["features"].append(features)
    examples["label"].append(label)
  return examples

examples = feature_and_label_gen()

Then we create the training dataset wrapped in a dataset_fn:

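def dataset_fn(_):
  raw_dataset = tf.data.Dataset.from_tensor_slices(examples)

  train_dataset = raw_dataset.map(
      lambda x: (
          {"features": feature_preprocess_stage(x["features"])},
          label_preprocess_stage(x["label"])
      )).shuffle(200).batch(32).repeat()
  return train_dataset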

Build the model

Second, we create the model and other objects. Make sure to create all variables under strategy.scope.

# These variables created under the `strategy.scope` will be placed on parameter
# servers in a round-robin fashion.
with strategy.scope():
  # Create the model. The input needs to be compatible with KPLs.
  model_input = keras.layers.Input(
      shape=(3,), dtype=tf.int64, name="model_input")

  emb_layer = keras.layers.Embedding(
      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=20)
  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)
  dense_output = keras.layers.Dense(units=1, activation="sigmoid")(emb_output)
  model = keras.Model({"features": model_input}, dense_output)

  optimizer = keras.optimizers.RMSprop(learning_rate=0.1)
  accuracy = keras.metrics.Accuracy()

Define the training step

Third, create the training step wrapped into a tf.function:

@tf.function
def step_fn(iterator):

  def replica_fn(iterator):
    batch_data, labels = next(iterator)
    with tf.GradientTape() as tape:
      pred = model(batch_data, training=True)
      per_example_loss = keras.losses.BinaryCrossentropy(
              reduction=tf.keras.losses.Reduction.NONE)(labels, pred)
      loss = tf.nn.compute_average_loss(per_example_loss)
      gradients = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
    accuracy.update_state(labels, actual_pred)
    return loss

  losses = strategy.run(replica_fn, args=(iterator,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)

In the above step function, calling strategy.run and strategy.reduce in step_fn is useful for supporting GPUs or multiple replicas per worker in the future, although they have a trivial implementation at the moment.
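The loss computation in replica_fn relies on a convention worth spelling out: per-example losses are computed with reduction NONE and then divided by the global batch size, so that summing the per-replica results with ReduceOp.SUM yields the mean over the whole batch. The pure-Python sketch below assumes that convention for tf.nn.compute_average_loss; the helper names mirror the TensorFlow ones only for readability, and the labels and predictions are made up.

```python
import math


def binary_crossentropy(label, pred):
    """Per-example binary cross-entropy (without Keras's epsilon clipping)."""
    return -(label * math.log(pred) + (1 - label) * math.log(1 - pred))


def compute_average_loss(per_example_loss, global_batch_size):
    """Sum of per-example losses divided by the global batch size."""
    return sum(per_example_loss) / global_batch_size


# Hypothetical batch of four examples.
labels = [1, 0, 1, 0]
preds = [0.9, 0.2, 0.6, 0.5]
per_example = [binary_crossentropy(y, p) for y, p in zip(labels, preds)]

# One replica sees the whole batch.
single_replica = compute_average_loss(per_example, global_batch_size=4)

# Two replicas each see half the batch but still divide by the global size,
# so summing their results (as ReduceOp.SUM does) recovers the same mean.
per_replica = [compute_average_loss(per_example[:2], 4),
               compute_average_loss(per_example[2:], 4)]
summed = sum(per_replica)

print(math.isclose(single_replica, summed))  # True
```

Dividing each replica's loss by its local batch size instead would make the SUM reduction scale with the number of replicas.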

Dispatch training steps to remote workers

After all the computations are defined by ParameterServerStrategy, we will use the ClusterCoordinator class to create resources and distribute the training steps to remote workers.

Let’s first create a ClusterCoordinator object and pass in the strategy object:

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

Then we create a per-worker dataset and an iterator. In the per_worker_dataset_fn below, wrapping the dataset_fn into strategy.distribute_datasets_from_function is optional, but it will allow efficient prefetching to GPUs in the future, once GPUs are supported by ParameterServerStrategy.

def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.string, name='feature'), name='feature', description="created by layer 'feature'"), but it was called on an input with incompatible shape (3,).

The final step is to distribute the computation to remote workers using schedule. The schedule method enqueues a tf.function and immediately returns a future-like RemoteValue. The queued functions are dispatched to remote workers in background threads, and the RemoteValue is filled asynchronously. The join method can be used to wait until all scheduled functions have been executed.

num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:1/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:1/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:1/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:1/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/job:ps/replica:0/task:0/device:CPU:0',).
Finished epoch 0, accuracy is 0.650000.
Finished epoch 1, accuracy is 1.000000.
Finished epoch 2, accuracy is 1.000000.
Finished epoch 3, accuracy is 1.000000.

Here is how you can fetch the result of a RemoteValue:

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())
Final loss is 0.002676

Alternatively, you can launch all steps and do something while waiting for completion:

import time

total_steps = num_epoches * steps_per_epoch  # For example.
for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

For the complete training and serving workflow for this particular example, please check out this test.

More about dataset creation

The dataset in the above code is created using the create_per_worker_dataset API. It creates one dataset per worker and returns a container object. You can call iter on it to create a per-worker iterator. The per-worker iterator contains one iterator per worker; before a function passed to schedule is executed on a particular worker, that worker's slice is substituted for the per-worker iterator in the function's input arguments.

Currently, the schedule method assumes workers are equivalent and thus assumes the datasets on different workers are the same, except that they may be shuffled differently if they contain a dataset.shuffle operation. Because of this, we also recommend repeating the datasets indefinitely and scheduling a finite number of steps instead of relying on the OutOfRangeError from a dataset.

Another important note is that datasets don’t support implicit serialization and deserialization across task boundaries. So it is important to create the whole dataset inside the function passed to create_per_worker_dataset.
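A pure-Python model of the per-worker arrangement described in this section: every worker builds its own copy of the pipeline from the same function, the per-worker iterator holds one iterator per worker, and the right one is substituted when a function runs on a given worker. ToyPerWorkerIterator, run_on_worker, toy_dataset_fn and toy_step_fn are hypothetical names for illustration only; the real container comes from create_per_worker_dataset and iter.

```python
import itertools
import random


def toy_dataset_fn(worker_index, examples):
    """Hypothetical per-worker pipeline: shuffle with a worker-specific seed,
    then repeat forever. Simplification: the same order is reused on each pass,
    whereas tf.data reshuffles every epoch by default."""
    order = list(examples)
    random.Random(worker_index).shuffle(order)
    return itertools.cycle(order)


class ToyPerWorkerIterator:
    """One iterator per worker, each built from the full pipeline function."""

    def __init__(self, num_workers, dataset_fn, examples):
        # The whole pipeline is constructed per worker from dataset_fn; no
        # dataset object is shipped from the coordinator.
        self._slices = [dataset_fn(w, examples) for w in range(num_workers)]

    def slice_for(self, worker_index):
        return self._slices[worker_index]


def run_on_worker(fn, per_worker_iterator, worker_index):
    # The container argument is replaced by the chosen worker's own slice.
    return fn(per_worker_iterator.slice_for(worker_index))


def toy_step_fn(iterator):
    return next(iterator)


examples = list(range(6))
iterators = ToyPerWorkerIterator(3, toy_dataset_fn, examples)

# Schedule a fixed number of steps rather than waiting for end-of-data.
seen = [run_on_worker(toy_step_fn, iterators, step % 3) for step in range(30)]
print(len(seen))  # 30
```

Because each slice repeats indefinitely, 30 steps succeed even though every worker holds only six distinct examples, and the workers see those examples in different orders.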

Variable sharding

Variable sharding refers to splitting a variable into multiple smaller variables. We call these smaller variables shards. Variable sharding may be useful to distribute the network load when accessing these shards. It is also useful to distribute computation and storage of a normal variable across multiple parameter servers.

To enable variable sharding, you can pass in a variable_partitioner when constructing a ParameterServerStrategy object. The variable_partitioner is invoked every time a variable is created, and it is expected to return the number of shards along each dimension of the variable. Some out-of-the-box variable_partitioners are provided, such as tf.distribute.experimental.partitioners.FixedShardsPartitioner.

In the above example, we use the FixedShardsPartitioner, which splits all variables into two shards and assigns each shard to a different parameter server:

assert len(emb_layer.weights) == 2
assert emb_layer.weights[0].shape == (5, 20)
assert emb_layer.weights[1].shape == (4, 20)
assert emb_layer.weights[0].device == "/job:ps/replica:0/task:0/device:CPU:0"
assert emb_layer.weights[1].device == "/job:ps/replica:0/task:1/device:CPU:0"
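The asserts above imply a 9-row embedding table (5 + 4 rows). The pure-Python sketch below reproduces that arithmetic under two assumptions: that FixedShardsPartitioner shards only the first axis into min(num_shards, rows) pieces, and that leftover rows go to the earliest shards. The round-robin placer mirrors the placement shown by the device asserts. All names here are hypothetical helpers, not TensorFlow APIs.

```python
def fixed_shards_partitions(shape, num_shards):
    """Partition counts per dimension: shard axis 0 only, never more shards than rows."""
    partitions = [1] * len(shape)
    partitions[0] = min(num_shards, shape[0])
    return partitions


def shard_shapes(shape, num_partitions):
    """Split axis 0 as evenly as possible; earlier shards absorb the remainder."""
    base, extra = divmod(shape[0], num_partitions)
    rest = tuple(shape[1:])
    return [(base + (1 if i < extra else 0),) + rest for i in range(num_partitions)]


class RoundRobinPlacer:
    """Hands out parameter server devices in turn."""

    def __init__(self, num_ps):
        self.num_ps = num_ps
        self._next = 0

    def place(self):
        device = "/job:ps/replica:0/task:%d/device:CPU:0" % self._next
        self._next = (self._next + 1) % self.num_ps
        return device


placer = RoundRobinPlacer(num_ps=2)
emb_shape = (9, 20)  # 5 + 4 rows, output_dim=20
num_parts = fixed_shards_partitions(emb_shape, num_shards=2)[0]
shards = [(s, placer.place()) for s in shard_shapes(emb_shape, num_parts)]
for shape, device in shards:
    print(shape, device)
# (5, 20) /job:ps/replica:0/task:0/device:CPU:0
# (4, 20) /job:ps/replica:0/task:1/device:CPU:0
```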

When a variable_partitioner is passed in and you create a variable directly under strategy.scope(), it becomes a container type with a variables property, which provides access to the list of shards. In most cases, this container is automatically converted to a Tensor by concatenating all the shards. As a result, it can be used as a normal variable. On the other hand, some TensorFlow methods such as tf.nn.embedding_lookup provide efficient implementations for this container type, and in these methods automatic concatenation is avoided.
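The difference between the two paths can be sketched in pure Python: concatenating all shards before indexing, versus routing each id to the shard that holds it. The sketch assumes each shard stores a contiguous block of rows, as in the 5 + 4 split shown by the asserts above. concat_shards and sharded_lookup are illustrative names, and this is not how tf.nn.embedding_lookup is implemented internally.

```python
import bisect


def concat_shards(shards):
    """What automatic conversion amounts to: one table with every shard's rows."""
    return [row for shard in shards for row in shard]


def sharded_lookup(shards, ids):
    """Gather rows by id, touching only the shard that owns each id.

    Assumes each shard holds a contiguous range of rows, in order.
    """
    starts, total = [], 0
    for shard in shards:
        starts.append(total)
        total += len(shard)
    rows = []
    for i in ids:
        if not 0 <= i < total:
            raise IndexError("id %d out of range for %d rows" % (i, total))
        k = bisect.bisect_right(starts, i) - 1  # shard whose range contains i
        rows.append(shards[k][i - starts[k]])
    return rows


# A 9-row table (row r is [r, r]) split 5 + 4, like the embedding above.
table = [[r, r] for r in range(9)]
shards = [table[:5], table[5:]]
ids = [0, 4, 5, 8]

print(sharded_lookup(shards, ids))  # [[0, 0], [4, 4], [5, 5], [8, 8]]
```

Both paths return the same rows, but the sharded path never materializes the full table, which is what makes it attractive for large embeddings spread over several parameter servers.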

Please see the API docstring of ParameterServerStrategy for more details.


Evaluation

There is more than one way to define and run an evaluation loop in distributed training.

Side-car evaluation

One method is called side-car evaluation, in which a dedicated evaluator task repeatedly reads checkpoints and runs evaluation on the latest checkpoint. Following is a possible side-car evaluation loop:

checkpoint_dir = ...
eval_model = ...
eval_data = ...
checkpoint = tf.train.Checkpoint(model=eval_model)

for latest_checkpoint in tf.train.checkpoints_iterator(
    checkpoint_dir):
  try:
    checkpoint.restore(latest_checkpoint).expect_partial()
  except (tf.errors.OpError,) as e:
    # checkpoint may be deleted by training when it is about to read it.
    continue

  # Optionally add callbacks to write summaries.
  eval_model.evaluate(eval_data)

  # Evaluation finishes when it has evaluated the last epoch.
  if latest_checkpoint.endswith('-{}'.format(train_epoches)):
    break

Inline evaluation

In this method, the coordinator alternates between training and evaluation, so we call it inline evaluation. Inline evaluation has several benefits. For example, it can support large evaluation models and evaluation datasets that a single task cannot hold. As another example, the evaluation results can be used to make decisions for training the next epoch.

There are two ways to implement inline evaluation:

  • Direct evaluation - For small models and evaluation datasets the coordinator can run evaluation directly on the distributed model with the evaluation dataset on the coordinator:
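eval_dataset = tf.data.Dataset.from_tensor_slices(
      feature_and_label_gen(num_examples=16)).map(
          lambda x: (
              {"features": feature_preprocess_stage(x["features"])},
              label_preprocess_stage(x["label"])
          )).batch(8)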
eval_accuracy = keras.metrics.Accuracy()
for batch_data, labels in eval_dataset:
  pred = model(batch_data, training=False)
  actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
  eval_accuracy.update_state(labels, actual_pred)

print ("Evaluation accuracy: %f" % eval_accuracy.result())
WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.string, name='feature'), name='feature', description="created by layer 'feature'"), but it was called on an input with incompatible shape (3,).
Evaluation accuracy: 1.000000

  • Distributed evaluation - For large models or datasets that are infeasible to run directly on the coordinator, the coordinator task can distribute evaluation tasks to the workers via the schedule/join methods.
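When evaluation is distributed, each worker only sees its own share of the data, so the coordinator has to combine partial results. Here is a pure-Python sketch of one way to do this for accuracy, assuming each scheduled function returns a (correct, count) pair. Pooling counts and dividing once gives the exact accuracy, whereas averaging per-worker accuracies is biased when workers evaluate different numbers of examples. partial_accuracy, combine and the data are hypothetical.

```python
def partial_accuracy(labels, preds, threshold=0.5):
    """Run on one worker: count correct predictions on its share of eval data."""
    correct = sum(int((p > threshold) == bool(y)) for y, p in zip(labels, preds))
    return correct, len(labels)


def combine(partials):
    """Run on the coordinator after join(): pool counts, then divide once."""
    correct = sum(c for c, _ in partials)
    total = sum(n for _, n in partials)
    return correct / total


# Hypothetical evaluation shards of unequal size on two workers.
worker_shards = [
    ([1, 0, 1, 1], [0.9, 0.1, 0.4, 0.8]),  # 3 of 4 correct
    ([0, 1], [0.7, 0.6]),                  # 1 of 2 correct
]
partials = [partial_accuracy(y, p) for y, p in worker_shards]
print(combine(partials))  # 0.6666666666666666
```

The pooled result is 4/6, while naively averaging the two per-worker accuracies (0.75 and 0.5) would give 0.625.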

Clusters in the real world

In a real production environment, you will run all tasks in different processes on different machines. The simplest way to configure cluster information on each task is to set the "TF_CONFIG" environment variable and use a TFConfigClusterResolver to parse it. For a general description of the "TF_CONFIG" environment variable, please see the distributed training guide.

If you start your training tasks using Kubernetes or other configuration templates, it is very likely that these templates have already set “TF_CONFIG” for you.

Set “TF_CONFIG” environment variable

If you have three workers and two parameter servers, the “TF_CONFIG” of worker 1 can be:

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"],
        "chief": ["host6:port"]
    },
    "task": {"type": "worker", "index": 1}
})

The “TF_CONFIG” of the evaluator can be:

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "evaluator": ["host7:port"]
    },
    "task": {"type": "evaluator", "index": 0}
})

The “cluster” part in the above “TF_CONFIG” string for the evaluator is optional.
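Since every training task shares the same “cluster” section and differs only in “task”, the per-task strings can be generated from one dictionary. A minimal sketch using only the standard library; tf_config_for is a hypothetical helper, and how each string reaches its process (shell, Kubernetes manifest, launcher) is outside its scope.

```python
import json


def tf_config_for(task_type, task_index, cluster=None):
    """Return the TF_CONFIG JSON string for one task (cluster part optional)."""
    if task_type not in ("chief", "worker", "ps", "evaluator"):
        raise ValueError("unexpected task type %r" % task_type)
    config = {"task": {"type": task_type, "index": task_index}}
    if cluster is not None:
        config["cluster"] = cluster
    return json.dumps(config)


# The training cluster from the example above; hosts and ports are placeholders.
cluster = {
    "worker": ["host1:port", "host2:port", "host3:port"],
    "ps": ["host4:port", "host5:port"],
    "chief": ["host6:port"],
}

configs = {}
for task_type, addresses in cluster.items():
    for index in range(len(addresses)):
        configs[(task_type, index)] = tf_config_for(task_type, index, cluster)

# The evaluator does not need to know the training cluster.
configs[("evaluator", 0)] = tf_config_for("evaluator", 0)

print(len(configs))                                # 7
print(json.loads(configs[("worker", 1)])["task"])  # {'type': 'worker', 'index': 1}
```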

If you use the same binary for all tasks

If you prefer to run all these tasks using a single binary, you will need to let your program branch into different roles at the very beginning:

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if cluster_resolver.task_type in ("worker", "ps"):
  # Start a TensorFlow server and wait.
  ...
elif cluster_resolver.task_type == "evaluator":
  # Run side-car evaluation.
  ...
else:
  # Run the coordinator.
  ...
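To test the branching logic without starting any servers, the same decision can be sketched with the standard library by reading TF_CONFIG directly. role_for is hypothetical; in the program above, TFConfigClusterResolver does the parsing, and passing the environment in as a dictionary (rather than reading os.environ) is a choice made here to keep the function easy to test.

```python
import json


def role_for(environ):
    """Decide which branch a shared binary takes, from TF_CONFIG in environ."""
    tf_config = json.loads(environ.get("TF_CONFIG", "{}"))
    task_type = tf_config.get("task", {}).get("type")
    if task_type in ("worker", "ps"):
        return "server"       # start a TensorFlow server and wait
    if task_type == "evaluator":
        return "evaluator"    # run side-car evaluation
    return "coordinator"      # e.g. "chief": run the training program


ps_env = {"TF_CONFIG": json.dumps({"task": {"type": "ps", "index": 0}})}
chief_env = {"TF_CONFIG": json.dumps({"task": {"type": "chief", "index": 0}})}
print(role_for(ps_env), role_for(chief_env))  # server coordinator
```

In this sketch a missing TF_CONFIG falls through to the coordinator branch, matching the final else clause of the dispatch above.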

The following code starts a TensorFlow server and waits:

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

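cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or "grpc",
    start=True)
server.join()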

Handling Task Failure

Worker failure

As mentioned above, the ClusterCoordinator has built-in fault tolerance for worker failure. Upon worker recovery, the corresponding slice of datasets created by create_per_worker_dataset that are still in scope will be recreated by invoking its original dataset_fn passed to create_per_worker_dataset.

Parameter server or the coordinator failure

However, when the coordinator sees a parameter server error, it raises an UnavailableError or AbortedError immediately. You can restart the coordinator in this case. The coordinator itself can also become unavailable. Therefore, to avoid losing much of the training progress, it is important to checkpoint the model variables periodically and, before training starts, to load the model variables from a checkpoint if one exists. The training progress can be inferred approximately from optimizer.iterations if the optimizer is checkpointed.

checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model, optimizer=optimizer),
    checkpoint_dir,
    max_to_keep=3)  # Keep the three most recent checkpoints.
if checkpoint_manager.latest_checkpoint:
  checkpoint = checkpoint_manager.checkpoint
  checkpoint.restore(
      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()

global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch

for _ in range(starting_epoch, num_epoches):
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  coordinator.join()
  checkpoint_manager.save()
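The word “approximately” matters here. Floor division restarts at the beginning of a partially completed epoch, so its first steps run again; and since a scheduled function may execute more than once, optimizer.iterations can exceed the number of steps actually scheduled. A pure-Python sketch of the resume arithmetic, with hypothetical helper names and numbers:

```python
def resume_plan(global_steps, steps_per_epoch, num_epochs):
    """Epoch indices still to run, using floor division as in the loop above."""
    starting_epoch = global_steps // steps_per_epoch
    return list(range(starting_epoch, num_epochs))


def steps_repeated(global_steps, steps_per_epoch):
    """Steps of a partially completed epoch that will be scheduled again."""
    return global_steps % steps_per_epoch


# Hypothetical restore: optimizer.iterations read 12 with 5 steps per epoch.
print(resume_plan(12, steps_per_epoch=5, num_epochs=4))  # [2, 3]
print(steps_repeated(12, steps_per_epoch=5))             # 2
```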

Fetching a RemoteValue

Fetching a RemoteValue is guaranteed to succeed if a function is executed successfully. This is because currently the return value is immediately copied to the coordinator after a function is executed. If there is any worker failure during the copy, the function will be retried on another available worker. Therefore, if you want to optimize for performance, you can schedule functions without a return value.

Error Reporting

Once the coordinator sees an error such as UnavailableError from parameter servers, or other application errors such as an InvalidArgumentError from tf.debugging.check_numerics, it will cancel all pending and queued functions before raising the error. Fetching their corresponding RemoteValues will raise a CancelledError.

After an error is raised, the coordinator will not raise the same error or any error from cancelled functions.
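The cancellation rule can be modelled in a few lines of pure Python: functions run in queue order, and the first failure marks every function still waiting as cancelled. CancelledError, ToyHandle and run_queue are hypothetical stand-ins that only mimic the behavior described above; they ignore concurrency and the retrying of worker failures, which the real coordinator handles separately.

```python
class CancelledError(Exception):
    """Stand-in for tf.errors.CancelledError in this toy model."""


class ToyHandle:
    """Minimal future-like result holder."""

    def __init__(self):
        self._result = None
        self._error = None

    def fetch(self):
        if self._error is not None:
            raise self._error
        return self._result


def run_queue(functions):
    """Run queued functions in order; on the first error, cancel the rest.

    Returns (handles, first_error). Only the first error is reported; the
    cancelled functions never surface errors of their own.
    """
    handles = [ToyHandle() for _ in functions]
    for i, fn in enumerate(functions):
        try:
            handles[i]._result = fn()
        except Exception as error:  # e.g. an UnavailableError from a ps
            handles[i]._error = error
            for handle in handles[i + 1:]:
                handle._error = CancelledError("cancelled after an earlier error")
            return handles, error
    return handles, None


def ok():
    return 1


def bad():
    # Hypothetical application error, standing in for InvalidArgumentError.
    raise ValueError("check_numerics failed")


handles, first_error = run_queue([ok, bad, ok, ok])
print(type(first_error).__name__)  # ValueError
print(handles[0].fetch())          # 1
```

Fetching either of the last two handles raises CancelledError, while the function that completed before the failure keeps its result.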

Performance Improvement

There are several possible reasons why you may see performance issues when training with ParameterServerStrategy and ClusterCoordinator.

One common reason is that parameter servers have unbalanced load and some heavily loaded parameter servers have reached capacity. There can also be multiple root causes. Some simple methods to mitigate this issue are to:

  1. shard your large model variables via specifying a variable_partitioner when constructing a ParameterServerStrategy.
  2. avoid creating a hotspot variable that is required by all parameter servers in a single step if possible. For example, use a constant learning rate or subclass tf.keras.optimizers.schedules.LearningRateSchedule in optimizers since the default behavior is that the learning rate will become a variable placed on a particular parameter server and requested by all other parameter servers in each step.
  3. shuffle your large vocabularies before passing them to Keras preprocessing layers.
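To see why the first suggestion helps, here is a back-of-the-envelope pure-Python model of round-robin placement that measures load as bytes stored per parameter server. The variable sizes are hypothetical, ps_load is an illustrative helper, and real load also depends on how often each variable is read and updated (the hotspot issue in the second suggestion), which this sketch does not model.

```python
def ps_load(variables, num_ps, num_shards=1):
    """Bytes held by each parameter server under round-robin placement.

    variables: list of (rows, bytes_per_row); each variable is split along
    rows into min(num_shards, rows) near-equal shards before placement.
    """
    load = [0] * num_ps
    next_ps = 0
    for rows, row_bytes in variables:
        parts = min(num_shards, rows)
        base, extra = divmod(rows, parts)
        for i in range(parts):
            load[next_ps] += (base + (1 if i < extra else 0)) * row_bytes
            next_ps = (next_ps + 1) % num_ps
    return load


# Hypothetical model: one large embedding table plus two small dense variables.
variables = [
    (100_000, 64 * 4),  # 100k x 64 float32 embedding
    (64, 4),            # 64 x 1 float32 dense kernel
    (1, 4),             # dense bias
]
print(ps_load(variables, num_ps=2))                # [25600004, 256]
print(ps_load(variables, num_ps=2, num_shards=2))  # [12800132, 12800128]
```

Without sharding, one parameter server holds almost the entire model; with two shards per variable, the same total is split nearly evenly.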

Another possible reason for performance issues is the coordinator. Our first implementation of schedule/join is Python-based and thus may have threading overhead. Also the latency between the coordinator and the workers can be large. If this is the case, you can pack multiple steps into a single tf.function:

steps_per_invocation = 10
@tf.function
def step_fn(iterator):
  for _ in range(steps_per_invocation):
    features, labels = next(iterator)
    def replica_fn(features, labels):
      ...

    strategy.run(replica_fn, args=(features, labels))
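The benefit of packing can be estimated with a crude cost model in which every schedule call pays a fixed coordinator overhead on top of the compute time of the steps it runs. Both the model and the numbers are assumptions for illustration; estimated_time_ms is a hypothetical helper that ignores the fact that different workers run in parallel, and only shows how the per-call cost is amortized over steps_per_invocation.

```python
import math


def estimated_time_ms(total_steps, steps_per_invocation, dispatch_overhead_ms,
                      step_time_ms):
    """One dispatch overhead per schedule() call plus the compute for every step."""
    invocations = math.ceil(total_steps / steps_per_invocation)
    return invocations * dispatch_overhead_ms + total_steps * step_time_ms


# Hypothetical numbers: 5 ms coordinator overhead per call, 2 ms compute per step.
print(estimated_time_ms(1000, 1, 5, 2))   # 7000
print(estimated_time_ms(1000, 10, 5, 2))  # 2500
```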

We will keep optimizing the coordinator and hopefully most users don’t have to manually pack steps in the future.

In addition, a small trick for improving performance is to schedule functions without a return value, as explained in the Handling Task Failure section above.