

A multi-worker tf.distribute strategy with parameter servers.

Inherits From: Strategy

Parameter server training is a common data-parallel method to scale up a machine learning model on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers and they are read and updated by workers in each step. By default, workers read and update these variables independently without synchronizing with each other. This configuration is known as asynchronous training.
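To make "asynchronous" concrete, here is a toy, single-process sketch in plain Python (not TensorFlow). Several worker threads read a shared parameter, compute the gradient of a hypothetical loss `(w - target) ** 2`, and push updates without waiting for each other, so a worker may apply a gradient computed from a value that other workers have already changed. `ParameterServer`, `worker`, and `train_async` are hypothetical names used only for illustration.

```python
import threading


class ParameterServer:
  """Toy stand-in for a parameter server holding one scalar variable."""

  def __init__(self, value):
    self._value = value
    self._lock = threading.Lock()  # makes each single read or update atomic

  def read(self):
    with self._lock:
      return self._value

  def apply_gradient(self, grad, lr):
    with self._lock:
      self._value -= lr * grad


def worker(ps, target, steps, lr):
  for _ in range(steps):
    w = ps.read()  # other workers may update w before this update lands
    grad = 2.0 * (w - target)  # gradient of the toy loss (w - target) ** 2
    ps.apply_gradient(grad, lr)  # no synchronization with other workers


def train_async(num_workers=4, steps=200, lr=0.05, target=3.0):
  ps = ParameterServer(0.0)
  threads = [
      threading.Thread(target=worker, args=(ps, target, steps, lr))
      for _ in range(num_workers)
  ]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
  return ps.read()


print(f"final w: {train_async():.4f}")
```

Despite occasionally stale reads, the small learning rate keeps the updates stable and the value settles near `target`.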

In TensorFlow 2, we recommend a central coordination-based architecture for parameter server training, where workers and parameter servers run a tf.distribute.Server and there is another task that creates resources on workers and parameter servers, dispatches functions, and coordinates the training. We refer to this task as the “coordinator”. The coordinator uses a tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the cluster, and a tf.distribute.experimental.ParameterServerStrategy to define variables on parameter servers and computation on workers.

For the training to work, the coordinator dispatches tf.functions to be executed on remote workers. Upon receiving requests from the coordinator, a worker executes the tf.function by reading the variables from parameter servers, executing the ops, and updating the variables on the parameter servers. Each worker only processes requests from the coordinator and communicates with parameter servers, without interacting directly with other workers in the cluster.

As a result, failures of some workers do not prevent the cluster from continuing the work, which allows the cluster to train with instances that can be occasionally unavailable (e.g. preemptible or spot instances). The coordinator and parameter servers, however, must be available at all times for the cluster to make progress.
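The worker fault tolerance described above can be illustrated with a toy, single-process model in plain Python: `schedule()` only queues a function and returns immediately, `join()` drains the queue, and a function whose worker turns out to be unavailable is re-queued for another worker. `ToyCoordinator` and `WorkerUnavailableError` are hypothetical names; this is not how `ClusterCoordinator` is implemented, and real dispatch happens concurrently on remote machines.

```python
from collections import deque


class WorkerUnavailableError(Exception):
  """Raised by a toy worker that is unavailable (e.g. preempted)."""


class ToyCoordinator:

  def __init__(self, workers):
    self._workers = list(workers)  # each worker is a callable(fn, args)
    self._pending = deque()

  def schedule(self, fn, args=()):
    # Only record the work; nothing runs yet, so this returns immediately.
    self._pending.append((fn, args))

  def join(self, max_attempts=100):
    results, attempts, turn = [], 0, 0
    while self._pending:
      attempts += 1
      if attempts > max_attempts:
        raise RuntimeError("no worker could make progress")
      fn, args = self._pending.popleft()
      worker = self._workers[turn % len(self._workers)]  # round-robin dispatch
      turn += 1
      try:
        results.append(worker(fn, args))
      except WorkerUnavailableError:
        self._pending.append((fn, args))  # retry later on another worker
    return results


def healthy_worker(fn, args):
  return fn(*args)


def preempted_worker(fn, args):
  raise WorkerUnavailableError()


coordinator = ToyCoordinator([healthy_worker, preempted_worker])
for x in range(5):
  coordinator.schedule(lambda v: v * v, args=(x,))
print(sorted(coordinator.join()))  # [0, 1, 4, 9, 16]
```

Every scheduled function still completes even though one of the two workers never succeeds; the `max_attempts` guard only exists so the toy stops if no worker is ever available.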

Note that the coordinator is not one of the training workers. Instead, it creates resources such as variables and datasets, dispatches tf.functions, saves checkpoints, and so on. In addition to workers, parameter servers, and the coordinator, an optional evaluator can be run on the side that periodically reads the checkpoints saved by the coordinator and runs evaluations against each checkpoint.

tf.distribute.experimental.ParameterServerStrategy has to work in conjunction with a tf.distribute.experimental.coordinator.ClusterCoordinator object. Standalone usage of tf.distribute.experimental.ParameterServerStrategy without central coordination is not supported at this time.

Example code for coordinator

Here's an example usage of the API, with a custom training loop to train a model. This code snippet is intended to be run on the single task that is designated as the coordinator. Note that the cluster_resolver, variable_partitioner, and dataset_fn arguments are explained in the following "Cluster setup", "Variable partitioning", and "Dataset preparation" sections.

```python
import logging
import os

import tensorflow as tf

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a short-term workaround.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Prepare a distributed dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

# `checkpoint_dir`, `num_epoch`, `steps_per_epoch`, and `load_checkpoint`
# are user-defined.
with strategy.scope():
  model = ...
  optimizer, metrics = ...  # Keras optimizer/metrics are great choices
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
  initial_epoch = load_checkpoint(checkpoint_manager) or 0

@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # calculate gradient, applying gradient, metrics update etc.

  strategy.run(replica_fn, args=(next(iterator),))

for epoch in range(initial_epoch, num_epoch):
  distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
  for step in range(steps_per_epoch):

    # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
    # worker. This call returns immediately.
    coordinator.schedule(worker_fn, args=(distributed_iterator,))

  # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
  # returns, we can read the metrics and save checkpoints as needed.
  coordinator.join()
  logging.info('Metric result: %r', metrics.result())
  checkpoint_manager.save()
```

Example code for worker and parameter servers

In addition to the coordinator, there should be tasks designated as "worker" or "ps". They should run the following code to start a TensorFlow server and wait for the coordinator's requests:

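```python
import os

import tensorflow as tf

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See below "Cluster setup" section.
cluster_resolver = ...

# Start a server for this task, identified by its job name and index.
server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol="grpc")

# Block the process that started the server from exiting.
server.join()
```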

Cluster setup

In order for the tasks in the cluster to know each other's addresses, a tf.distribute.cluster_resolver.ClusterResolver must be used by the coordinator, the workers, and the parameter servers. The tf.distribute.cluster_resolver.ClusterResolver is responsible for providing the cluster information, as well as the task type and id of the current task. See tf.distribute.cluster_resolver.ClusterResolver for more information.

If the TF_CONFIG environment variable is set, a tf.distribute.cluster_resolver.TFConfigClusterResolver should be used as well. Note that for legacy reasons, on some platforms, "chief" is used as the task type for the coordinator, as the following example demonstrates. Here we set TF_CONFIG for the task designated as a parameter server (task type "ps") with index 1 (the second task), in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that TF_CONFIG needs to be set before tf.distribute.cluster_resolver.TFConfigClusterResolver is used.

Example code for cluster setup:

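```python
import os

# The host:port addresses are elided in this example; fill in each task's address.
os.environ['TF_CONFIG'] = '''
{
  "cluster": {
    "chief": [""],
    "ps": ["", ""],
    "worker": ["", "", ""]
  },
  "task": {
    "type": "ps",
    "index": 1
  }
}
'''
```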

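Writing TF_CONFIG by hand is error-prone, since a single missing brace or bracket makes the JSON invalid. A sketch that builds the same setting with `json.dumps` instead; the helper `make_tf_config` and the `example.com` host names are hypothetical.

```python
import json
import os


def make_tf_config(cluster, task_type, task_index):
  """Build a TF_CONFIG JSON string for one task of `cluster` (hypothetical)."""
  if task_type not in cluster:
    raise ValueError(f"unknown task type: {task_type!r}")
  if not 0 <= task_index < len(cluster[task_type]):
    raise ValueError(f"{task_type} index {task_index} is out of range")
  return json.dumps({
      "cluster": cluster,
      "task": {"type": task_type, "index": task_index},
  })


# Hypothetical addresses: 1 chief, 2 parameter servers, and 3 workers.
cluster = {
    "chief": ["chief0.example.com:2222"],
    "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
    "worker": [
        "worker0.example.com:2222",
        "worker1.example.com:2222",
        "worker2.example.com:2222",
    ],
}

# The second parameter server: task type "ps", index 1.
os.environ["TF_CONFIG"] = make_tf_config(cluster, "ps", 1)
```

Each task would run the same code with its own task type and index; validating the index up front catches a mismatch before any server starts.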
If you prefer to run the same binary for all tasks, you will need to let the binary branch into different roles at the beginning of the program:

```python
import os

import tensorflow as tf

os.environ["GRPC_FAIL_FAST"] = "use_caller"
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

# If coordinator, create a strategy and start the training program.
if cluster_resolver.task_type == 'chief':
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  ...

# If worker/ps, create a server.
elif cluster_resolver.task_type in ("worker", "ps"):
  server = tf.distribute.Server(...)
  ...
```
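The branching decision depends only on the task type recorded in TF_CONFIG, so it can be unit-tested without starting any servers. A plain-Python sketch, assuming the legacy "chief" task type for the coordinator as in the example above; `task_role` is a hypothetical helper that only mimics the `task_type` value `TFConfigClusterResolver` reports, and it rejects other task types (such as an evaluator) instead of guessing how to handle them.

```python
import json
import os


def task_role(env=None):
  """Map the task type in TF_CONFIG to a role (hypothetical helper)."""
  env = os.environ if env is None else env
  config = json.loads(env.get("TF_CONFIG", "{}"))
  task_type = config.get("task", {}).get("type")
  if task_type == "chief":
    return "coordinator"  # create the strategy and run the training program
  if task_type in ("worker", "ps"):
    return "server"  # start a tf.distribute.Server and block on it
  raise ValueError(f"unsupported task type: {task_type!r}")


example_env = {"TF_CONFIG": json.dumps({"task": {"type": "ps", "index": 1}})}
print(task_role(example_env))  # server
```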

Alternatively, you can start a set of TensorFlow servers in advance and connect to them later. The coordinator can be in the same cluster or on any machine that has connectivity to the workers and parameter servers. This is covered in the parameter server training guide and tutorial.

Variable creation with strategy.scope()

tf.distribute.experimental.ParameterServerStrategy follows the tf.distribute API contract where variable creation is expected to be inside the context manager returned by strategy.scope(), in order to be correctly placed on parameter servers in a round-robin manner:

```python
import tensorflow as tf

# In this example, we assume there are 3 ps tasks.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Variables should be created inside scope to be placed on parameter servers.
# If created outside scope such as `v1` here, it would be placed on the
# coordinator.
v1 = tf.Variable(initial_value=0.0)

with strategy.scope():
  v2 = tf.Variable(initial_value=1.0)
  v3 = tf.Variable(initial_value=2.0)
  v4 = tf.Variable(initial_value=3.0)
  v5 = tf.Variable(initial_value=4.0)

# v2 through v5 are created in scope and are distributed on parameter servers.
# Default placement is round-robin but the order should not be relied on.
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
```
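The device strings in the assertions above follow a simple modulo pattern: the i-th variable created in scope goes to ps task `i % num_ps`. A plain-Python sketch reproducing that pattern; `round_robin_devices` is a hypothetical helper, and as noted above, the actual placement order should not be relied on.

```python
def round_robin_devices(num_vars, num_ps):
  """Device string for each of `num_vars` variables, in creation order."""
  return [
      f"/job:ps/replica:0/task:{i % num_ps}/device:CPU:0"
      for i in range(num_vars)
  ]


for name, device in zip(["v2", "v3", "v4", "v5"], round_robin_devices(4, 3)):
  print(name, device)
```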

See tf.distribute.Strategy.scope for more information.

Variable partitioning

Having dedicated servers to store variables means being able to divide up, or "shard", the variables across the ps. Partitioning large variables among ps is a commonly used technique to boost training throughput and mitigate memory constraints. It enables parallel computations and updates on different shards of a variable, and often yields better load balancing across parameter servers. Without sharding, models with large variables (e.g., embeddings) that can't fit into one machine's memory would otherwise be unable to train.
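A back-of-the-envelope calculation shows why sharding matters for large embeddings. The sketch assumes a hypothetical float32 embedding table of 10,000,000 rows by 128 columns, split into 4 shards along the first axis; the largest shard bounds how much memory any single ps needs for this variable. `shard_memory` is a hypothetical helper.

```python
def shard_memory(rows, cols, num_shards, bytes_per_element=4):
  """Return (total_bytes, largest_shard_bytes) for a 2-D table sharded by rows."""
  total_bytes = rows * cols * bytes_per_element
  largest_shard_rows = -(-rows // num_shards)  # ceil(rows / num_shards)
  return total_bytes, largest_shard_rows * cols * bytes_per_element


total, per_shard = shard_memory(10_000_000, 128, num_shards=4)
print(f"total: {total / 1e9:.2f} GB, largest shard: {per_shard / 1e9:.2f} GB")
# total: 5.12 GB, largest shard: 1.28 GB
```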

With tf.distribute.experimental.ParameterServerStrategy, if a variable_partitioner is provided to __init__ and certain conditions are satisfied, the resulting variables created in scope are sharded across the parameter servers in a round-robin fashion. The variable reference returned from tf.Variable becomes a type that serves as the container of the sharded variables. One can access the variables attribute of this container for the actual variable components. If the model is built with tf.Module or Keras, the variable components are collected in the variables-like attributes.

```python
import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

  def __call__(self, x):
    return x * self.w

# Partition the dense layer into 2 shards.
variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
        num_shards=2))
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
with strategy.scope():
  dense = Dense()
assert len(dense.variables) == 2
assert isinstance(dense.variables[0], tf.Variable)
assert isinstance(dense.variables[1], tf.Variable)
# Variable names carry an output suffix such as ":0", so match the prefix.
assert dense.variables[0].name.startswith("w/part_0")
assert dense.variables[1].name.startswith("w/part_1")
```

The sharded variable container can be converted to a Tensor via tf.convert_to_tensor. This means the container can be used directly in most Python ops where such Tensor conversion happens automatically. For example, in the above code snippet, x * self.w would implicitly apply this tensor conversion. Note that such conversion can be expensive, as the variable components need to be transferred from multiple parameter servers to where the value is used.

tf.nn.embedding_lookup, on the other hand, doesn't apply the tensor conversion, and performs parallel lookups on the variable components instead. This is crucial to scale up embedding lookups when the embedding table variable is large.
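The difference between converting the whole container and looking up rows per shard can be sketched in plain Python, assuming rows are split contiguously across non-empty shards (the div strategy described under variable_partitioner). `to_full_table` and `sharded_lookup` are hypothetical helpers: the first must touch every shard, while the second only reads the shards that own the requested ids. This mirrors the idea, not the implementation of tf.nn.embedding_lookup.

```python
import bisect


def to_full_table(shards):
  """Concatenate every shard into one table (all shards must be fetched)."""
  return [row for shard in shards for row in shard]


def sharded_lookup(shards, ids):
  """Gather rows for `ids`, reading only the shards that own them.

  Returns the gathered rows and the set of shard indices that were read.
  """
  starts, offset = [], 0
  for shard in shards:
    starts.append(offset)  # global id of the first row in this shard
    offset += len(shard)

  rows, touched = [], set()
  for i in ids:
    s = bisect.bisect_right(starts, i) - 1  # last shard starting at or before i
    touched.add(s)
    rows.append(shards[s][i - starts[s]])
  return rows, touched


# Hypothetical 13-row table split 3/3/3/2/2; row i simply holds the value 10 * i.
shards = [[0, 10, 20], [30, 40, 50], [60, 70, 80], [90, 100], [110, 120]]
rows, touched = sharded_lookup(shards, [4, 11])
print(rows, sorted(touched))  # [40, 110] [1, 4]
print(len(to_full_table(shards)))  # 13
```

Looking up two ids reads only two of the five shards, whereas the full conversion always moves all 13 rows.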

When a partitioned variable is saved to a SavedModel, it is saved as if it were one single variable. This improves serving efficiency by eliminating a number of ops that handle the partition aspects.

Known limitations of variable partitioning:

  • The number of partitions must not change across Checkpoint save/load.

  • After saving partitioned variables to a SavedModel, the SavedModel can't be loaded via tf.saved_model.load.

  • Partitioned variables don't directly work with tf.GradientTape; use the variables attributes to get the actual variable components and use them in gradient APIs instead.

Dataset preparation

With tf.distribute.experimental.ParameterServerStrategy, a dataset is created in each of the workers to be used for training. This is done by creating a dataset_fn that takes no argument and returns a tf.data.Dataset, and passing the dataset_fn into tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset. We recommend that the dataset be shuffled and repeated so that the examples run through the training as evenly as possible.

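```python
import tensorflow as tf

def dataset_fn():
  filenames = ...
  dataset = tf.data.Dataset.from_tensor_slices(filenames)

  # Dataset is recommended to be shuffled, and repeated.
  return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=...)
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
```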


Args

cluster_resolver: A tf.distribute.cluster_resolver.ClusterResolver object.
variable_partitioner: A tf.distribute.experimental.partitioners.Partitioner that specifies how to partition variables. If None, variables will not be partitioned.

  • Predefined partitioners in tf.distribute.experimental.partitioners can be used for this argument. A commonly used partitioner is MinSizePartitioner(min_shard_bytes=256 << 10, max_shards=num_ps), which allocates at least 256 KB per shard, and each ps gets at most one shard.

  • variable_partitioner will be called for each variable created under strategy scope to instruct how the variable should be partitioned. Variables that have only one partition along the partitioning axis (i.e., no need for partition) will be created as normal tf.Variable.

  • Only the first / outermost axis partitioning is supported.

  • The div partition strategy is used to partition variables. Assuming we assign consecutive integer ids along the first axis of a variable, ids are assigned to shards in a contiguous manner, while attempting to keep each shard size identical. If the number of ids is not evenly divisible by the number of shards, each of the first several shards is assigned one more id. For instance, a variable whose first dimension is 13 has 13 ids, and they are split across 5 shards as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].

  • Variables created under strategy.extended.colocate_vars_with will not be partitioned.
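The div partition rule above can be written out in a few lines of plain Python to check the 13-ids-over-5-shards example; `div_partition` is a hypothetical helper that computes the id assignment only, not TensorFlow's partitioning code.

```python
def div_partition(num_ids, num_shards):
  """Split ids 0..num_ids-1 into contiguous shards of near-equal size."""
  base, extra = divmod(num_ids, num_shards)
  shards, start = [], 0
  for s in range(num_shards):
    size = base + (1 if s < extra else 0)  # first `extra` shards get one more id
    shards.append(list(range(start, start + size)))
    start += size
  return shards


print(div_partition(13, 5))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```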

Attributes

cluster_resolver: Returns the cluster resolver associated with this strategy.

In general, when using a multi-worker tf.distribute strategy such as tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy, there is a tf.distribute.cluster_resolver.ClusterResolver associated with the strategy used, and such an instance is returned by this property.

Strategies that intend to have an associated tf.distribute.cluster_resolver.ClusterResolver must set the relevant attribute, or override this property; otherwise, None is returned by default. Those strategies should also provide information r