A multi-worker tf.distribute strategy with parameter servers.
tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver, variable_partitioner=None
)
Parameter server training is a common data-parallel method to scale up a machine learning model on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers, and workers read and update them in each step. By default, workers read and update these variables independently, without synchronizing with each other. This configuration is known as asynchronous training.
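To make "asynchronous" concrete, here is a pure-Python toy (not TensorFlow code, and not how TensorFlow implements parameter servers). Threads stand in for workers, and ToyParameterServer, worker_loop and train_async are hypothetical names. Each worker pulls the current value of one shared variable w, computes the gradient of (w - 3)^2, and pushes an update without waiting for the other workers, so some updates are computed from stale values.

```python
import threading


class ToyParameterServer:
  """Hypothetical stand-in for a parameter server holding one variable."""

  def __init__(self, initial_value):
    self._value = initial_value
    self._lock = threading.Lock()  # Makes each single read/update atomic.

  def read(self):
    with self._lock:
      return self._value

  def apply_gradient(self, grad, learning_rate):
    with self._lock:
      self._value -= learning_rate * grad


def worker_loop(ps, target, steps, learning_rate):
  for _ in range(steps):
    w = ps.read()                            # Pull a possibly stale value.
    grad = 2.0 * (w - target)                # Gradient of (w - target) ** 2.
    ps.apply_gradient(grad, learning_rate)   # Push without waiting for peers.


def train_async(num_workers=4, steps=200, learning_rate=0.05, target=3.0):
  ps = ToyParameterServer(initial_value=0.0)
  workers = [
      threading.Thread(
          target=worker_loop, args=(ps, target, steps, learning_rate))
      for _ in range(num_workers)
  ]
  for t in workers:
    t.start()
  for t in workers:
    t.join()
  return ps.read()
```

Calling train_async() should return a value close to the target 3.0. The lock only makes each individual update atomic; it does not make workers wait for each other's steps, which is the difference between asynchronous and synchronous training.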
In TensorFlow 2, we recommend an architecture based on central coordination
for parameter server training. Each worker and parameter server runs a
tf.distribute.Server, and on top of that, a coordinator task is responsible
for creating resources on workers and parameter servers, dispatching
functions, and coordinating the training. The coordinator uses a
tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the
cluster, and a
tf.distribute.experimental.ParameterServerStrategy to define
variables on parameter servers and computation on workers.
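Each task in such a cluster needs to know the cluster layout and its own role. One common way to provide this is the TF_CONFIG environment variable, which tf.distribute.cluster_resolver.TFConfigClusterResolver reads. The sketch below builds one TF_CONFIG value per task. The host:port addresses are made up, make_tf_config is a hypothetical helper, and listing the coordinator under the "chief" job is an assumption about the naming convention for this setup.

```python
import json

# Hypothetical addresses: one coordinator ("chief"), two workers, one ps.
CLUSTER = {
    "chief": ["coordinator.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    "ps": ["ps0.example.com:2222"],
}


def make_tf_config(cluster, task_type, task_index):
  """Hypothetical helper: returns the TF_CONFIG JSON string for one task."""
  if not 0 <= task_index < len(cluster.get(task_type, [])):
    raise ValueError(f"no {task_type!r} task with index {task_index}")
  return json.dumps({
      "cluster": cluster,
      "task": {"type": task_type, "index": task_index},
  })


# The second worker's process would set os.environ["TF_CONFIG"] to this value
# before creating its cluster resolver.
print(make_tf_config(CLUSTER, "worker", 1))
```

Every process, including the coordinator, receives the same "cluster" dictionary. Only the "task" entry differs.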
For the training to work, the coordinator dispatches
tf.functions to be
executed on remote workers. Upon receiving requests from the coordinator, a
worker executes the
tf.function by reading the variables from parameter
servers, executing the ops, and updating the variables on the parameter
servers. Each worker only processes requests from the coordinator and
communicates with parameter servers, without direct interactions with
other workers in the cluster.
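The dispatch model described above can be mimicked in plain Python. In this model the coordinator schedules a function, a worker runs it against variables held on a parameter server, and the coordinator later waits for everything to finish. This is a single-process toy, not TensorFlow's implementation. ToyCoordinator, step_fn and the shared variables dict are hypothetical, and threads stand in for remote workers. The concurrent.futures.Future returned by schedule plays roughly the role of the RemoteValue that ClusterCoordinator.schedule returns.

```python
import threading
from concurrent.futures import ThreadPoolExecutor, wait


class ToyCoordinator:
  """Hypothetical toy: dispatches functions to a pool of "workers"."""

  def __init__(self, num_workers):
    self._pool = ThreadPoolExecutor(max_workers=num_workers)
    self._pending = []

  def schedule(self, fn, args=()):
    # Returns immediately; the Future stands in for a remote return value.
    future = self._pool.submit(fn, *args)
    self._pending.append(future)
    return future

  def join(self):
    # Blocks until every scheduled function has finished.
    wait(self._pending)
    for future in self._pending:
      future.result()  # Re-raises an exception from a failed function, if any.
    self._pending.clear()


# Toy "parameter server" state that the workers read and update.
variables = {"step_count": 0}
ps_lock = threading.Lock()


def step_fn(increment):
  # A worker reads the variable, "computes", and writes the update back.
  with ps_lock:
    variables["step_count"] += increment
  return increment


coordinator = ToyCoordinator(num_workers=3)
results = [coordinator.schedule(step_fn, args=(1,)) for _ in range(10)]
coordinator.join()
print(variables["step_count"])  # 10
```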
As a result, failures of some workers do not prevent the cluster from continuing its work, which allows the cluster to train on instances that may occasionally be unavailable (e.g. preemptible or spot instances). The coordinator and parameter servers, however, must be available at all times for the cluster to make progress.
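The reason a lost worker is tolerable can be illustrated with a pure-Python toy in which the coordinator re-dispatches a function whose worker was preempted. WorkerUnavailableError, ToyWorker and run_with_retries are hypothetical names. In TensorFlow, handling worker failures is ClusterCoordinator's job, and this sketch does not reproduce its actual logic.

```python
class WorkerUnavailableError(Exception):
  """Hypothetical error raised when a toy worker has been preempted."""


class ToyWorker:
  def __init__(self, name, available=True):
    self.name = name
    self.available = available

  def run(self, fn, *args):
    if not self.available:
      raise WorkerUnavailableError(self.name)
    return fn(*args)


def run_with_retries(workers, fn, *args):
  """Tries `fn` on each worker in turn, skipping preempted ones."""
  for worker in workers:
    try:
      return worker.name, worker.run(fn, *args)
    except WorkerUnavailableError:
      continue  # Only this attempt is lost; the cluster keeps going.
  raise RuntimeError("no worker available")


workers = [ToyWorker("worker0", available=False), ToyWorker("worker1")]
print(run_with_retries(workers, lambda x: x * x, 4))  # ('worker1', 16)
```

The retry loop has no equivalent fallback for the state that functions read and write. This mirrors why the coordinator and parameter servers must stay up.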
Note that the coordinator is not one of the training workers. Instead, it
creates resources such as variables and datasets, dispatches tf.functions,
saves checkpoints, and so on. In addition to workers, parameter servers and
the coordinator, an optional evaluator can be run on the side that
periodically reads the checkpoints saved by the coordinator and runs
evaluations against each checkpoint.
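The evaluator's job can be sketched as polling the checkpoint directory and evaluating each checkpoint it has not seen yet. In this pure-Python toy, checkpoints are JSON files named ckpt-<step>.json. That is a made-up format, not TensorFlow's checkpoint format, and save_checkpoint, evaluate_new_checkpoints and mean are hypothetical helpers. In TensorFlow, tf.train.checkpoints_iterator provides similar polling over a real checkpoint directory.

```python
import json
import os
import re
import tempfile

# Hypothetical checkpoint naming scheme used only in this sketch.
_CKPT_NAME = re.compile(r"^ckpt-(\d+)\.json$")


def save_checkpoint(directory, step, weights):
  """Coordinator side (toy): write one checkpoint file per save."""
  with open(os.path.join(directory, f"ckpt-{step}.json"), "w") as f:
    json.dump({"step": step, "weights": weights}, f)


def evaluate_new_checkpoints(directory, seen_steps, eval_fn):
  """Evaluator side (toy): evaluate each unseen checkpoint, oldest first.

  A real evaluator would call this repeatedly, sleeping in between.
  """
  steps = []
  for name in os.listdir(directory):
    match = _CKPT_NAME.match(name)
    if match:
      steps.append(int(match.group(1)))
  results = {}
  for step in sorted(steps):
    if step in seen_steps:
      continue
    with open(os.path.join(directory, f"ckpt-{step}.json")) as f:
      results[step] = eval_fn(json.load(f)["weights"])
    seen_steps.add(step)
  return results


def mean(weights):
  # Hypothetical "evaluation": the mean of the stored weights.
  return sum(weights) / len(weights)


with tempfile.TemporaryDirectory() as ckpt_dir:
  seen = set()
  save_checkpoint(ckpt_dir, 100, [1.0, 3.0])
  print(evaluate_new_checkpoints(ckpt_dir, seen, mean))  # {100: 2.0}
  save_checkpoint(ckpt_dir, 200, [2.0, 4.0])
  print(evaluate_new_checkpoints(ckpt_dir, seen, mean))  # {200: 3.0}
```

The evaluator only reads the directory. It never affects training, so it can lag behind or restart without slowing the coordinator.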
ParameterServerStrategy is supported with two training APIs: Custom
Training Loop (CTL)
and Keras Training API, also known as
Model.fit. CTL is recommended
when users prefer to define the details of their training loop, and
Model.fit is recommended when users prefer a high-level abstraction and
handling of training.
When using a CTL,
ParameterServerStrategy has to work in conjunction with a
tf.distribute.experimental.coordinator.ClusterCoordinator object.
Example code for coordinator
This section provides code snippets that are intended to be run on the single
task designated as the coordinator. Note that the cluster_resolver,
variable_partitioner, and dataset_fn arguments are explained in the
following "Cluster setup", "Variable partitioning", and "Dataset preparation"
sections.
With a CTL,
from absl import logging
import tensorflow as tf

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Prepare a distributed dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

with strategy.scope():
  model = ...
  optimizer, metrics = ...  # Keras optimizer/metrics are great choices
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  # `checkpoint_dir`, `num_epoch` and `steps_per_epoch` are user-provided.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  # `load_checkpoint` is a user-defined helper (not a TensorFlow API) that
  # infers the initial epoch from `optimizer.iterations`.
  initial_epoch = load_checkpoint(checkpoint_manager) or 0

@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # Calculate gradients, apply them, update metrics, etc.

  strategy.run(replica_fn, args=(next(iterator),))

for epoch in range(initial_epoch, num_epoch):
  distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
  for step in range(steps_per_epoch):

    # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
    # worker. This call returns immediately.
    coordinator.schedule(worker_fn, args=(distributed_iterator,))

  # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
  # returns, we can read the metrics and save checkpoints as needed.
  coordinator.join()
  logging.info('Metric result: %r', metrics.result())
  metrics.reset_states()
  checkpoint_manager.save()
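The CTL snippet relies on a load_checkpoint helper that it does not define. The core arithmetic of such a helper can be sketched under two assumptions. First, each scheduled worker_fn applies gradients exactly once, so optimizer.iterations advances by steps_per_epoch per completed epoch. Second, checkpoints are saved only after coordinator.join(). Under those assumptions, the epoch to resume from is an integer quotient. initial_epoch_from_iterations is a hypothetical name.

```python
def initial_epoch_from_iterations(iterations, steps_per_epoch):
  """Hypothetical helper: the epoch to resume from, given a step count.

  Assumes `optimizer.iterations` grows by exactly `steps_per_epoch` per
  completed epoch and that checkpoints are saved after `join()`. In
  TensorFlow, `iterations` could be obtained by calling
  `checkpoint.restore(checkpoint_manager.latest_checkpoint)` and then
  `int(optimizer.iterations.numpy())`.
  """
  if steps_per_epoch <= 0:
    raise ValueError("steps_per_epoch must be positive")
  return iterations // steps_per_epoch


print(initial_epoch_from_iterations(0, 100))    # 0: no progress saved yet
print(initial_epoch_from_iterations(300, 100))  # 3: epochs 0-2 are done
```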
With Model.fit,
import tensorflow as tf

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)

# A dataset function takes an `input_context` and returns a `Dataset`.
def dataset_fn(input_context):
  dataset = tf.data.Dataset.from_tensors(...)
  return dataset.repeat().shard(...).batch(...).prefetch(...)

# With `Model.fit`, a `DatasetCreator` needs to be used.
dataset_creator = tf.keras.utils.experimental.DatasetCreator(
    dataset_fn=dataset_fn)

with strategy.scope():
  model = ...  # Make sure the `Model` is created within scope.
# Pass any other compile arguments as keywords as well.
model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=...)

# Optional callbacks to checkpoint the model, back up the progress, etc.
callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]

# `steps_per_epoch` is required with `ParameterServerStrategy`.
model.fit(dataset_creator, epochs=..., steps_per_epoch=..., callbacks=callbacks)
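The snippet above elides the arguments to shard. One common choice, assumed here rather than taken from the snippet, is input_context.num_input_pipelines and input_context.input_pipeline_id, so that each worker reads a disjoint slice of the data. The pure-Python toy below mimics what Dataset.shard(n, i) keeps, namely the elements whose position mod n equals i, followed by batch. ToyInputContext and toy_dataset_fn are hypothetical, and the toy omits repeat and prefetch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyInputContext:
  """Hypothetical stand-in for two fields of tf.distribute.InputContext."""
  num_input_pipelines: int
  input_pipeline_id: int


def toy_dataset_fn(input_context, elements, batch_size):
  # Like `Dataset.shard(n, i)`: keep elements whose position mod n equals i.
  shard = [
      x for pos, x in enumerate(elements)
      if pos % input_context.num_input_pipelines
      == input_context.input_pipeline_id
  ]
  # Like `.batch(batch_size)`: consecutive groups; the last may be short.
  return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


data = list(range(10))
for worker_id in range(2):
  ctx = ToyInputContext(num_input_pipelines=2, input_pipeline_id=worker_id)
  print(worker_id, toy_dataset_fn(ctx, data, batch_size=2))
# 0 [[0, 2], [4, 6], [8]]
# 1 [[1, 3], [5, 7], [9]]
```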
Example code for worker and parameter servers
In addition to the coordinator, there should be tasks designated as "worker" or "ps". They should run the following code to start a TensorFlow server and wait for the coordinator's requests:
# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See below "Cluster setup" section.
cluster_resolver = ...

server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol="grpc")

# Blocking the process that sta