An asynchronous multi-worker parameter server tf.distribute strategy.
tf.compat.v1.distribute.experimental.ParameterServerStrategy( cluster_resolver=None )
This strategy requires two roles: workers and parameter servers. Variables and updates to those variables are assigned to parameter servers, and other operations are assigned to workers.
When each worker has more than one GPU, operations will be replicated on all GPUs. Even though operations may be replicated, variables are not, and each worker shares a common view of which parameter server a variable is assigned to.
By default it uses
TFConfigClusterResolver to detect configurations for
multi-worker training. This requires a 'TF_CONFIG' environment variable and
the 'TF_CONFIG' must have a cluster spec.
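As a rough illustration of what such a 'TF_CONFIG' value could look like, the sketch below builds one with the standard library only. The host names, ports, and the helper `make_tf_config` are hypothetical and exist only for this example; the "cluster" and "task" keys follow the usual TF_CONFIG layout.

```python
import json
import os

# Hypothetical host:port addresses, used only for illustration.
cluster_spec = {
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    "ps": ["ps0.example.com:2222"],
}


def make_tf_config(cluster, task_type, task_index):
    """Serialize a TF_CONFIG value for one task in the cluster (hypothetical helper)."""
    return json.dumps({
        "cluster": cluster,
        "task": {"type": task_type, "index": task_index},
    })


# Each process sets its own task entry before the strategy is created.
os.environ["TF_CONFIG"] = make_tf_config(cluster_spec, "worker", 0)

# Read it back to check that a cluster spec is present.
parsed = json.loads(os.environ["TF_CONFIG"])
print(sorted(parsed["cluster"]))
print(parsed["task"])
```

Every task in the cluster would share the same "cluster" entry and differ only in its "task" entry.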
This class assumes each worker is running the same code independently, but parameter servers are running a standard server. This means that while each worker will synchronously compute a single gradient update across all GPUs, updates between workers proceed asynchronously. Operations that occur only on the first replica (such as incrementing the global step) will occur on the first replica of every worker.
You are expected to call
call_for_each_replica(fn, ...) for any
operations that can potentially be replicated across replicas (i.e. multiple
GPUs), even if there is only a CPU or one GPU. When defining
fn, extra caution needs to be taken:
1) It is generally not recommended to open a device scope under the strategy's
scope. A device scope (i.e. calling
tf.device) will be merged with or
override the device for operations but will not change the device for
variables.
2) It is also not recommended to open a colocation scope (i.e. calling
tf.compat.v1.colocate_with) under the strategy's scope. For colocating
variables, use
strategy.extended.colocate_vars_with instead. Colocation of
ops will possibly create device assignment conflicts.
strategy = tf.distribute.experimental.ParameterServerStrategy()
# Pass the strategy to the Estimator through RunConfig's train_distribute argument.
run_config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(config=run_config)
tf.estimator.train_and_evaluate(estimator,...)
Returns the cluster resolver associated with this strategy.
In general, when using a multi-worker
Strategies that intend to have an associated
Single-worker strategies usually do not have a
For more information, please see
Returns the number of replicas over which gradients are aggregated.
experimental_distribute_dataset( dataset, options=None )
The following is an example:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Create a dataset and batch it by the global batch size.
# GLOBAL_BATCH_SIZE is expected to be defined by the caller.
dataset = tf.data.TFRecordDataset([
    "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]).batch(GLOBAL_BATCH_SIZE)

# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements; replica_fn is the per-replica step function
  strategy.run(replica_fn, args=(x,))
In the code snippet above,
dist_dataset is batched by
GLOBAL_BATCH_SIZE, and we iterate through it using
for x in dist_dataset.
Each x contains data for all replicas, which aggregates to a batch of
GLOBAL_BATCH_SIZE.
tf.distribute.Strategy.run will take care of feeding
the right per-replica data in
x to the right
replica_fn executed on each
replica.
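The relationship between one global batch and the per-replica data can be sketched in plain Python. This is only an illustration of the idea, not how TensorFlow implements it: the helper `split_global_batch`, the value chosen for GLOBAL_BATCH_SIZE, and the assumption that the batch divides evenly among replicas are all made up for this example.

```python
def split_global_batch(global_batch, num_replicas):
    """Split one global batch into equally sized per-replica batches.

    Assumes the global batch size is divisible by num_replicas.
    """
    if len(global_batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    per_replica = len(global_batch) // num_replicas
    return [
        global_batch[i * per_replica:(i + 1) * per_replica]
        for i in range(num_replicas)
    ]


GLOBAL_BATCH_SIZE = 8  # illustrative value
x = list(range(GLOBAL_BATCH_SIZE))  # stands in for one element of dist_dataset

per_replica_batches = split_global_batch(x, num_replicas=2)
for replica_id, batch in enumerate(per_replica_batches):
    print(replica_id, batch)
```

The per-replica batches together add up to the global batch, which is the property the paragraph above describes.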
What happens under the hood of this method when we say the
dataset gets distributed? It depends on how you set the auto-shard policy in
tf.data.experimental.DistributeOptions. By default, it is set to
tf.data.experimental.AutoShardPolicy.AUTO. In a multi-worker setting, we
will first attempt to distribute
dataset by detecting whether
it is being created out of reader datasets (e.g.
tf.data.TextLineDataset, etc.) and if so, try to shard the input files.
Note that there has to be at least one input file per worker. If you have
fewer than one input file per worker, we suggest that you disable dataset
sharding across workers through
tf.data.experimental.DistributeOptions.auto_shard_policy.
If the attempt to shard by file is unsuccessful (i.e. the dataset is not
read from files), we will shard the dataset evenly at the end by appending a
.shard operation to the end of the processing pipeline. This
will cause the entire preprocessing pipeline for all the data to be run on
every worker, and each worker will do redundant work. We will print a
warning if this route is selected.
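The difference between the two routes can be sketched in plain Python. This is a conceptual illustration rather than TensorFlow's implementation: `shard_by_file`, `shard_by_element`, and `preprocess` are hypothetical helpers, and the sketch assumes that sharding keeps every num_workers-th item starting at the worker's index.

```python
def shard_by_file(files, num_workers, worker_index):
    """File-based sharding: each worker reads only its own subset of files."""
    return files[worker_index::num_workers]


def shard_by_element(elements, num_workers, worker_index):
    """Fallback sharding: applied at the end of the pipeline, element by element."""
    for position, element in enumerate(elements):
        if position % num_workers == worker_index:
            yield element


files = ["/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]
worker0_files = shard_by_file(files, num_workers=2, worker_index=0)
worker1_files = shard_by_file(files, num_workers=2, worker_index=1)

# Count preprocessing calls to make the redundant work of the fallback visible.
calls = {"count": 0}


def preprocess(record):
    calls["count"] += 1
    return record.upper()


records = ["a", "b", "c", "d"]
# The whole pipeline runs first; the shard step only discards elements afterwards.
kept = list(shard_by_element((preprocess(r) for r in records), 2, 0))

print(worker0_files, worker1_files)
print(kept, calls["count"])
```

In the fallback, the worker keeps half of the records but still preprocesses all of them, which is the redundant work the warning refers to.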
As mentioned before, within each worker, we will also split the data among all the worker devices (if more than one are present). This will happen even if multi-worker sharding is disabled.
If the above batch splitting and dataset sharding logic is undesirable,
use experimental_distribute_datasets_from_function
instead, which does not do any automatic splitting or sharding.
You can also use the
element_spec property of the
tf.distribute.DistributedDataset instance returned by this API to query
the tf.TypeSpec of the elements returned
by the iterator. This can be used to set the
input_signature of a tf.function.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Create a dataset
dataset = tf.data.TFRecordDataset([
    "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])

# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Use the element spec of the distributed dataset as the input signature.
@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(inputs):
  # train model with inputs
  return

# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  strategy.run(train_step, args=(x,))
experimental_distribute_datasets_from_function( dataset_fn, options=None )
Distributes tf.data.Dataset instances created by calls to
dataset_fn.
dataset_fn will be called once for each worker in the strategy. Each
replica on that worker will dequeue one batch of inputs from the local
Dataset (i.e. if a worker has two replicas, two batches will be dequeued
from the Dataset every step).
This method can be used for several purposes. For example, where
experimental_distribute_dataset is unable to shard the input files, this
method might be used to manually shard the dataset (avoiding the slow
fallback behavior in
experimental_distribute_dataset). In cases where the
dataset is infinite, this sharding can be done by creating dataset replicas
that differ only in their random seed.
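For the infinite-dataset case, the idea of replicas that differ only in their random seed can be sketched with the standard library. The generator `infinite_dataset`, the base seed, and the choice of deriving each worker's seed as base seed plus worker index are assumptions made for this illustration, not something the API prescribes.

```python
import itertools
import random


def infinite_dataset(base_seed, worker_index):
    """Yield an endless stream of samples; only the seed differs per worker."""
    rng = random.Random(base_seed + worker_index)
    while True:
        yield rng.random()


# Take the first few samples each worker would see.
worker0 = list(itertools.islice(infinite_dataset(42, 0), 3))
worker1 = list(itertools.islice(infinite_dataset(42, 1), 3))
worker0_again = list(itertools.islice(infinite_dataset(42, 0), 3))

print(worker0 != worker1)        # workers draw different samples
print(worker0 == worker0_again)  # each worker's stream is reproducible
```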
experimental_distribute_dataset may also sometimes fail to split the
batch across replicas on a worker. In that case, this method can be used
where that limitation does not exist.
dataset_fn should take a
tf.distribute.InputContext instance where
information about batching and input replication can be accessed.
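A minimal sketch of how a dataset_fn might use that information for manual sharding is shown below. To stay runnable without TensorFlow it uses `StandInInputContext`, a hypothetical stand-in that mirrors a few fields of tf.distribute.InputContext, and the function returns a plain dict describing the shard instead of a tf.data.Dataset. The factory `make_dataset_fn`, the file names, and the batch size are likewise assumptions for illustration.

```python
from dataclasses import dataclass


@dataclass
class StandInInputContext:
    """Hypothetical stand-in mirroring a few fields of tf.distribute.InputContext."""
    num_input_pipelines: int
    input_pipeline_id: int
    num_replicas_in_sync: int

    def get_per_replica_batch_size(self, global_batch_size):
        if global_batch_size % self.num_replicas_in_sync != 0:
            raise ValueError("global batch size must divide evenly among replicas")
        return global_batch_size // self.num_replicas_in_sync


def make_dataset_fn(files, global_batch_size):
    """Build a dataset_fn that takes only the input context (hypothetical factory)."""
    def dataset_fn(input_context):
        # Manual sharding: keep only the files that belong to this input pipeline.
        shard = files[input_context.input_pipeline_id::input_context.num_input_pipelines]
        batch_size = input_context.get_per_replica_batch_size(global_batch_size)
        # A real dataset_fn would return a tf.data.Dataset built from these values.
        return {"files": shard, "batch_size": batch_size}
    return dataset_fn


dataset_fn = make_dataset_fn(
    ["/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"], global_batch_size=64)

# Two workers with two replicas each; dataset_fn is called once per worker.
contexts = [StandInInputContext(2, worker, 4) for worker in range(2)]
results = [dataset_fn(ctx) for ctx in contexts]
for result in results:
    print(result)
```

Because the function batches by the per-replica size, each replica on a worker can dequeue one batch per step from the worker's local dataset.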