tf.distribute.experimental.MultiWorkerMirroredStrategy

A distribution strategy for synchronous training on multiple workers.

Inherits From: MultiWorkerMirroredStrategy, Strategy

This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to tf.distribute.MirroredStrategy, it replicates all variables and computations to each local device. The difference is that it uses a distributed collective implementation (e.g. all-reduce), so that multiple workers can work together.

You need to launch your program on each worker and configure cluster_resolver correctly. For example, if you are using tf.distribute.cluster_resolver.TFConfigClusterResolver, each worker needs to have its corresponding task_type and task_id set in the TF_CONFIG environment variable. An example TF_CONFIG on worker-0 of a two worker cluster is:

TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
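The corresponding TF_CONFIG on worker-1 of the same cluster differs only in the task index:

TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 1} }'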

Your program runs on each worker as-is. Note that collectives require each worker to participate. Both tf.distribute and non-tf.distribute APIs may use collectives internally, e.g. checkpointing and saving, since reading a tf.Variable with tf.VariableSynchronization.ON_READ all-reduces the value. Therefore it's recommended to run exactly the same program on each worker; dispatching based on the task_type or task_id of the worker is error-prone.

cluster_resolver.num_accelerators() determines the number of GPUs the strategy uses. If it's zero, the strategy uses the CPU. All workers need to use the same number of devices, otherwise the behavior is undefined.
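As a quick sanity check after construction, you can inspect what the strategy picked up (a minimal sketch; worker_devices lists the local devices the strategy runs replicas on):

strategy = tf.distribute.MultiWorkerMirroredStrategy()
# Total number of replicas across all workers.
print("Replicas in sync:", strategy.num_replicas_in_sync)
# Local compute devices used by this worker.
print("Local devices:", strategy.extended.worker_devices)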

This strategy is not intended for TPU. Use tf.distribute.TPUStrategy instead.

After setting up TF_CONFIG, using this strategy is similar to using tf.distribute.MirroredStrategy and tf.distribute.TPUStrategy.

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# The dataset repeats indefinitely, so bound the epoch explicitly.
model.fit(dist_dataset, epochs=1, steps_per_epoch=2)

You can also write your own training loop:

@tf.function
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  strategy.run(step_fn, args=(next(iterator),))

iterator = iter(dist_dataset)
NUM_STEPS = 10  # e.g. run 10 training steps
for _ in range(NUM_STEPS):
  train_step(iterator)

See Multi-worker training with Keras for a detailed tutorial.

Saving

You need to save and checkpoint on all workers instead of just one. This is because variables with synchronization=ON_READ trigger aggregation during saving, which requires all workers to participate. It's recommended to save to a different path on each worker to avoid race conditions; each worker saves the same thing. See the Multi-worker training with Keras tutorial for examples.
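A minimal sketch of per-worker save paths, assuming the default TFConfigClusterResolver so that task_id identifies the worker (the directory layout is illustrative, not prescribed by the API):

# Every worker must call save(); only the target paths differ.
task_id = strategy.cluster_resolver.task_id  # 0, 1, ... per TF_CONFIG
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint.save(f'/tmp/ckpt_worker_{task_id}/ckpt')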

Args

communication optional tf.distribute.experimental.CommunicationImplementation. This is a hint on the preferred collective communication implementation. Possible values include AUTO, RING, and NCCL.
cluster_resolver optional tf.distribute.cluster_resolver.ClusterResolver. If None, tf.distribute.cluster_resolver.TFConfigClusterResolver is used.
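For example, to hint that NCCL should be used for collectives (a sketch; NCCL requires a GPU on every worker):

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CommunicationImplementation.NCCL)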

Attributes

cluster_resolver Returns the cluster resolver associated with this strategy.

As a multi-worker strategy, tf.distribute.experimental.MultiWorkerMirroredStrategy provides the associated tf.distribute.cluster_resolver.ClusterResolver. If the user provides one in __init__, that instance is returned; if the user does not, a default TFConfigClusterResolver is provided.

extended tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync Returns number of replicas over which gradients are aggregated.

Methods

distribute_datasets_from_function

Distributes tf.data.Dataset instances created by calls to dataset_fn.

The argument dataset_fn that users pass in is an input function that has a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. The dataset returned by dataset_fn is expected to be already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded. tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function. dataset_fn is called on the CPU device of each worker, and each call generates a dataset from which every replica on that worker dequeues one batch of inputs (i.e. if a worker has two replicas, two batches are dequeued from the Dataset every step).

This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method can be used to shard the dataset manually (avoiding the slow fallback behavior in experimental_distribute_dataset). In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.

The dataset_fn should take a tf.distribute.InputContext instance, through which information about batching and input replication can be accessed.

You can use element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature property of a tf.function. Follow tf.distribute.DistributedDataset.element_spec to see an example.
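For example, assuming dist_dataset was created with this API, element_spec can serve as the input signature so the tf.function is traced once for the distributed element structure (a minimal sketch):

@tf.function(input_signature=[dist_dataset.element_spec])
def distributed_step(inputs):
  # `inputs` has the same per-replica structure the iterator yields.
  return strategy.run(lambda batch: batch, args=(inputs,))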

For more on the usage and properties of this method, refer to the tutorial on distributed input. If you are interested in last partial batch handling, read that tutorial's section on partial batches.
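A minimal sketch of such a dataset_fn, assuming a global batch size of 16 (the dataset contents are placeholders):

global_batch_size = 16

def dataset_fn(input_context):
  # Derive the per-replica batch size from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.range(100)
  # Give each input pipeline (one per worker) a disjoint shard.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)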

Args
dataset_fn A function taking a tf.distribute.InputContext instance and returning a tf.data.Dataset.
options tf.distribute.InputOptions used to control options on how this dataset is distributed.

Returns
A tf.distribute.DistributedDataset.

experimental_distribute_dataset

Creates tf.distribute.DistributedDataset from tf.data.Dataset.

The returned tf.distribute.DistributedDataset can be iterated over like a regular dataset. NOTE: you cannot add any more transformations to a tf.distribute.DistributedDataset. You can only create an iterator or examine the tf.TypeSpec of the data generated by it. See the API docs of tf.distribute.DistributedDataset to learn more.

The following is an example:

global_batch_size = 2
# Passing the devices is optional.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
# Create a dataset
dataset = tf.data.Dataset.range(4).batch(global_batch_size)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def replica_fn(input):
  return input*2
result = []
# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  result.append(strategy.run(replica_fn, args=(x,)))
print(result)
[PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
}]

Three key actions happening under the hood of this method are batching, sharding, and prefetching.

In the code snippet above, dataset is batched by global_batch_size, and calling experimental_distribute_dataset on it rebatches the dataset to a new batch size equal to the global batch size divided by the number of replicas in sync. You can iterate through it with a Pythonic for loop, where x is a tf.distribute.DistributedValues containing data for all replicas, and each replica gets data of the new batch size. tf.distribute.Strategy.run takes care of feeding the right per-replica data in x to the right replica_fn.
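As a quick check of that arithmetic, reusing the two-replica setup above (a sketch):

# 2 replicas in sync and global_batch_size = 2, so each replica sees batches
# of size 1, matching the shape=(1,) tensors in the printed output.
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync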