Additional APIs for algorithms that need to be distribution-aware.

Inherits From: StrategyExtended

Lower-level concepts:

  • Wrapped values: To represent values that exist in parallel across devices (either replicas or the devices associated with a particular value), we wrap them in a "PerReplica" or "Mirrored" object that contains a map from replica id to values. "PerReplica" is used when the value may differ across replicas, and "Mirrored" when the values are the same.
  • Unwrapping and merging: Consider calling a function fn on multiple replicas, like experimental_run_v2(fn, args=[w]) with an argument w that is a wrapped value. This means w will have a map taking replica id 0 to w0, replica id 1 to w1, etc. experimental_run_v2() unwraps w before calling fn, so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can result in wrapped values. For example, say fn() returns a tuple with three components: (x, a, v0) from replica 0, (x, b, v1) from replica 1, etc. If the first component is the same object x from every replica, then the first component of the merged result will also be x. If the second component is different (a, b, ...) on each replica, then the merged value will have a wrapped map from replica device to the different values. If the third components are the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (v).
  • Worker devices vs. parameter devices: Most replica computations will happen on worker devices. Since we don't yet support model parallelism, there will be one worker device per replica. When using parameter servers or central storage, the set of devices holding variables may be different, otherwise the parameter devices might match the worker devices.
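The merge rule in the list above can be modeled without TensorFlow. This is a toy sketch: ToyPerReplica, ToyMirroredVariable, merge_component, and merge_results are hypothetical names, not tf.distribute APIs, and the wrapped map is keyed by replica id for simplicity.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class ToyMirroredVariable:
    """Hypothetical stand-in for a mirrored variable: one component per replica."""
    components: list = field(default_factory=list)


@dataclass
class ToyPerReplica:
    """Hypothetical stand-in for a "PerReplica" wrapped value: replica id -> value."""
    values: dict


def merge_component(values, mirrored_vars):
    """Merge one output component; values[i] is what replica i returned."""
    first = values[0]
    # The same object from every replica merges to that object.
    if all(v is first for v in values):
        return first
    # The members of a mirrored variable merge to the variable itself.
    for var in mirrored_vars:
        if len(var.components) == len(values) and all(
            v is c for v, c in zip(values, var.components)
        ):
            return var
    # Anything else becomes a wrapped per-replica map.
    return ToyPerReplica(dict(enumerate(values)))


def merge_results(per_replica_outputs, mirrored_vars=()):
    """per_replica_outputs[i] is the tuple that fn returned on replica i."""
    return tuple(
        merge_component(list(column), mirrored_vars)
        for column in zip(*per_replica_outputs)
    )


# Two replicas return (x, a, v0) and (x, b, v1), as in the example above.
x = object()
v = ToyMirroredVariable(components=[object(), object()])
outputs = [(x, "a", v.components[0]), (x, "b", v.components[1])]
merged = merge_results(outputs, mirrored_vars=[v])
```

The identity check runs first, so a value that is the very same object on every replica passes through unchanged before the mirrored-variable check is tried.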

Replica context vs. Cross-replica context

A replica context applies when we are in some function that is being called once for each replica. Otherwise we are in a cross-replica context, which is useful for calling tf.distribute.Strategy methods that operate across the replicas (like reduce_to()). By default you start in a replica context (the "default single replica context"), and some methods can switch you back and forth. There is a third mode, called update context, used when updating variables.

In a replica context, you may freely read the values of variables, but you may only update their value if they specify a way to aggregate the update using the aggregation parameter in the variable's constructor. In a cross-replica context, you may read or write variables (writes may need to be broadcast to all copies of the variable if it is mirrored).

Sync on read variables

In some cases, such as a metric, we want to accumulate a bunch of updates on each replica independently and only aggregate when reading. This can be a big performance win when the value is read only rarely (maybe the value is only read at the end of an epoch or when checkpointing). These are variables created by passing synchronization=ON_READ to the variable's constructor (and some value for aggregation).

The strategy may choose to put the variable on multiple devices, like mirrored variables, but unlike mirrored variables we don't synchronize the updates to them to make sure they have the same value. Instead, the synchronization is performed when reading in cross-replica context. In a replica context, reads and writes are performed on the local copy (we allow reads so you can write code like v = 0.9*v + 0.1*update). We don't allow operations like v.assign_add in a cross-replica context for sync on read variables; right now we don't have a use case for such updates and depending on the aggregation mode such updates may not be sensible.
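A toy model of the sync-on-read behavior just described, with no TensorFlow involved. ToySyncOnReadVariable and its methods are hypothetical; in TensorFlow the variable itself would be created with tf.Variable(..., synchronization=tf.VariableSynchronization.ON_READ, aggregation=tf.VariableAggregation.SUM). The RuntimeError raised for a cross-replica assign_add is an assumption of this sketch, not TensorFlow's exact error.

```python
class ToySyncOnReadVariable:
    """Hypothetical model of a variable created with synchronization=ON_READ."""

    def __init__(self, num_replicas, aggregation="SUM"):
        if aggregation not in ("SUM", "MEAN"):
            raise ValueError(f"unsupported aggregation: {aggregation}")
        self.aggregation = aggregation
        self._local = [0.0] * num_replicas  # one independent copy per replica

    # Replica context: reads and writes touch only this replica's copy.
    def read_local(self, replica_id):
        return self._local[replica_id]

    def assign_add_local(self, replica_id, delta):
        self._local[replica_id] += delta

    # Cross-replica context: the copies are aggregated only when read.
    def read(self):
        total = sum(self._local)
        return total if self.aggregation == "SUM" else total / len(self._local)

    def assign_add(self, delta):
        # Cross-replica updates are rejected, matching the rule described above.
        raise RuntimeError("assign_add is not allowed in cross-replica context")


# A metric counting examples: replica 0 sees three batches of 32, replica 1 sees one.
count = ToySyncOnReadVariable(num_replicas=2, aggregation="SUM")
for _ in range(3):
    count.assign_add_local(0, 32.0)
count.assign_add_local(1, 32.0)
total_examples = count.read()  # the only point where the copies are combined

rejected = None
try:
    count.assign_add(1.0)
except RuntimeError as err:
    rejected = err
```

No communication happens during the four local updates; the cost of combining copies is paid once, at the read.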


Locality

Depending on how a value is produced, it will have a type that will determine how it may be used.

"Per-replica" values exist on the worker devices, with a different value for each replica. They are produced by iterating through a "distributed Dataset" returned by tf.distribute.Strategy.experimental_distribute_dataset or tf.distribute.Strategy.experimental_distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.experimental_run_v2. You typically can't use a per-replica value directly in a cross-replica context without first resolving how to aggregate the values across replicas, for instance by using tf.distribute.Strategy.reduce.

"Mirrored" values are like per-replica values, except we know that the values on all replicas are the same. We can safely read a mirrored value in a cross-replica context by using the value on any replica. You can convert a per-replica value into a mirrored value by using tf.distribute.ReplicaContext.all_reduce.

Values can also have the same locality as a variable, which is a mirrored value but residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. This is useful, for example, for "slot" variables used by an optimizer for keeping track of statistics used to update a primary/model variable. You may convert a per-replica value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.
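The three kinds of values described above can be modeled as plain dictionaries. In this toy sketch a per-replica value is {replica_id: number}, a mirrored value has the same number under every replica id, and a value with a variable's locality is keyed by that variable's devices. toy_reduce, toy_all_reduce, and toy_reduce_to are hypothetical stand-ins for Strategy.reduce, ReplicaContext.all_reduce, and StrategyExtended.reduce_to; the strings "SUM"/"MEAN" stand in for tf.distribute.ReduceOp, and the device string is made up.

```python
def toy_reduce(reduce_op, value):
    """Collapse a per-replica {replica_id: number} into one number (cf. Strategy.reduce)."""
    total = sum(value.values())
    if reduce_op == "SUM":
        return total
    if reduce_op == "MEAN":
        return total / len(value)
    raise ValueError(f"unsupported reduce op: {reduce_op}")


def toy_all_reduce(reduce_op, value):
    """Every replica gets the same reduced number: a mirrored value (cf. all_reduce)."""
    reduced = toy_reduce(reduce_op, value)
    return {replica_id: reduced for replica_id in value}


def toy_reduce_to(reduce_op, value, destinations):
    """Reduce, then place one copy on each destination device (cf. reduce_to)."""
    reduced = toy_reduce(reduce_op, value)
    return {device: reduced for device in destinations}


grads = {0: 0.25, 1: 0.75}                  # per-replica: differs by replica
mean_grad = toy_reduce("MEAN", grads)       # a single value for cross-replica code
mirrored = toy_all_reduce("SUM", grads)     # same value on every compute replica
# Same reduced value, but located on the (made-up) parameter device of a variable.
on_var = toy_reduce_to("SUM", grads, ("/job:ps/task:0/cpu:0",))
```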

In addition to slot variables which should be colocated with their primary variables, optimizers also define non-slot variables. These can be things like "number of step updates performed" or "beta1^t" and "beta2^t". Each strategy has some policy for which devices those variables should be copied to, called the "non-slot devices" (some subset of the parameter devices). We require that all non-slot variables are allocated on the same device, or mirrored across the same set of devices. You can use tf.distribute.StrategyExtended.non_slot_devices to pick a consistent set of devices to pass to both tf.distribute.StrategyExtended.colocate_vars_with and tf.distribute.StrategyExtended.update_non_slot.

How to update a variable

The standard pattern for updating variables is to:

  1. In your function passed to tf.distribute.Strategy.experimental_run_v2, compute a list of (update, variable) pairs. For example, the update might be the gradient of the loss with respect to the variable.
  2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
  3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates and broadcast the result to the variable's devices.
  4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.

Steps 2 through 4 are done automatically by the tf.keras.optimizers.Optimizer class if you call its tf.keras.optimizers.Optimizer.apply_gradients method in a replica context. They are also done automatically if you call an assign* method on a (non-sync-on-read) variable that was constructed with an aggregation method (which determines the reduction used in step 3).
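The four steps can be traced in a framework-free toy. toy_update_step is hypothetical and stands in for what merge_call, reduce_to, and update accomplish together; "variables" are dictionaries with one copy per device, and the plain gradient-descent rule (value -= learning_rate * summed update) is an assumption chosen for illustration.

```python
def toy_update_step(mirrored_vars, per_replica_updates, learning_rate):
    """Hypothetical model of steps 2-4 for plain-Python "variables".

    mirrored_vars: {var_name: {device: value}}, one copy of each variable per device.
    per_replica_updates: list indexed by replica id of {var_name: update}, i.e. the
        (update, variable) pairs each replica computed in step 1.
    """
    # Step 2: after merge_call, this cross-replica code sees every replica's updates.
    for name, copies in mirrored_vars.items():
        # Step 3: reduce_to(SUM): sum the updates and broadcast to the variable's devices.
        total = sum(updates[name] for updates in per_replica_updates)
        reduced = {device: total for device in copies}
        # Step 4: update(v): apply the reduced update to every copy of the variable.
        for device in copies:
            copies[device] -= learning_rate * reduced[device]


variables = {"w": {"/gpu:0": 1.0, "/gpu:1": 1.0}}
updates = [{"w": 0.5}, {"w": 1.5}]  # step 1: gradients from replicas 0 and 1
toy_update_step(variables, updates, learning_rate=0.1)
```

Because every copy receives the same summed update, the copies stay identical, which is what lets the variable remain mirrored.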

Distribute-aware layers

Layers are generally called in a replica context, except when defining a functional model. tf.distribute.in_cross_replica_context will let you determine which case you are in. If in a replica context, the tf.distribute.get_replica_context function will return a tf.distribute.ReplicaContext object. The ReplicaContext object has an all_reduce method for aggregating across all replicas. Alternatively, you can update variables following steps 2-4 above.
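As an example of why a layer would aggregate across replicas, consider computing a batch mean over uneven per-replica batches. In a real layer the reduction would go through the ReplicaContext returned by tf.distribute.get_replica_context(), for instance its all_reduce method with tf.distribute.ReduceOp.SUM; the sketch below replaces that with plain Python so the arithmetic is visible. toy_cross_replica_mean is hypothetical.

```python
def toy_cross_replica_mean(per_replica_batches):
    """Hypothetical: the global mean a distribute-aware layer would need.

    Each replica contributes a local (sum, count); an all_reduce(SUM) over both
    gives every replica the same global totals, hence the same mean.
    """
    global_sum = sum(sum(batch) for batch in per_replica_batches.values())
    global_count = sum(len(batch) for batch in per_replica_batches.values())
    mean = global_sum / global_count
    return {replica_id: mean for replica_id in per_replica_batches}


batches = {0: [1.0, 2.0, 3.0], 1: [5.0]}  # uneven per-replica batches
means = toy_cross_replica_mean(batches)
```

Averaging the local means instead would give (2.0 + 5.0) / 2 = 3.5 rather than 2.75, which is why the sums and counts, not the means, are what get all-reduced here.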

Attributes

experimental_between_graph Whether the strategy uses between-graph replication or not.

This is expected to return a constant value that will not be changed throughout its life cycle.

experimental_require_static_shapes Returns True if static shape is required; False otherwise.
experimental_should_init Whether initialization is needed.
parameter_devices Returns the tuple of all devices used to place variables.
should_checkpoint Whether checkpointing is needed.
should_save_summary Whether saving summaries is needed.
worker_devices Returns the tuple of all devices used for compute replica execution.



Methods

batch_reduce_to(reduce_op, value_destination_pairs)

Combine multiple reduce_to calls into one for faster execution.

Args:
reduce_op Reduction type, an instance of the tf.distribute.ReduceOp enum.
value_destination_pairs A sequence of (value, destinations) pairs. See reduce_to() for a description.

Returns:
A list of mirrored values, one per pair in value_destination_pairs.
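Semantically, batch_reduce_to returns what one reduce_to per pair would return, in order; the benefit is that a strategy can combine the communication. The toy below models only the semantics: toy_reduce_to and toy_batch_reduce_to are hypothetical, per-replica values are {replica_id: number} dictionaries, "SUM"/"MEAN" stand in for tf.distribute.ReduceOp members, and the device strings are made up.

```python
def toy_reduce_to(reduce_op, value, destinations):
    """Hypothetical reduce_to: value is {replica_id: number}; returns {device: result}."""
    total = sum(value.values())
    if reduce_op == "SUM":
        result = total
    elif reduce_op == "MEAN":
        result = total / len(value)
    else:
        raise ValueError(f"unsupported reduce op: {reduce_op}")
    return {device: result for device in destinations}


def toy_batch_reduce_to(reduce_op, value_destination_pairs):
    """One result per (value, destinations) pair, in order, as reduce_to would give.

    A real strategy can fuse the pairs into fewer collective operations; that
    performance aspect is not modeled here.
    """
    return [
        toy_reduce_to(reduce_op, value, destinations)
        for value, destinations in value_destination_pairs
    ]


pairs = [
    ({0: 1.0, 1: 3.0}, ("/gpu:0", "/gpu:1")),  # e.g. gradient of a mirrored variable
    ({0: 2.0, 1: 4.0}, ("/cpu:0",)),           # e.g. gradient of a single-device variable
]
results = toy_batch_reduce_to("SUM", pairs)
```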


broadcast_to(tensor, destinations)

Mirror a tensor on one device to all worker devices.

Args:
tensor A Tensor value to broadcast.
destinations A mirrored variable or device string specifying the destination devices to copy tensor to.

Returns:
A value mirrored to destinations devices.


call_for_each_replica(fn, args=(), kwargs=None)

Run fn once per replica.

fn may call tf.distribute.get_replica_context() to access members such as replica_id_in_sync_group and merge_call().

merge_call() is used to communicate between the replicas and re-enter the cross-replica context. Each replica pauses its execution when it encounters a merge_call() call. Once all replicas have paused, merge_fn is executed once, in cross-replica context. Its results are then unwrapped and given back to each replica's merge_call(), after which execution resumes until fn completes or encounters another merge_call(). Example:

import tensorflow as tf

# Any strategy works here; MirroredStrategy uses all visible GPUs (or the CPU).
distribution = tf.distribute.MirroredStrategy()

# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.distribute.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  merged_results = distribution.experimental_run_v2(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))

Args:
fn function to run (will be run once per replica).
args Tuple or list with positional arguments for fn.
kwargs Dict with keyword arguments for fn.

Returns:
Merged return value of fn across all replicas.
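The pause-and-resume behavior of merge_call() can be imitated with ordinary threads. In this toy sketch each replica runs in its own thread and blocks on a threading.Barrier; the barrier's action runs merge_fn exactly once after every replica has arrived, and each replica then receives the same result. ToyReplicaContext and toy_call_for_each_replica are hypothetical, and this merge_call takes an explicit replica_id argument that the real ReplicaContext.merge_call does not. With fn mirroring the example above (merge_fn sums the v values), two replicas compute v = 3 and v = 4, merge to 7, and return 10 and 11.

```python
import threading


class ToyReplicaContext:
    """Hypothetical stand-in for a replica context shared by all replica threads."""

    def __init__(self, num_replicas):
        self._inputs = [None] * num_replicas
        self._merge_fn = None
        self._result = None
        self.merge_fn_calls = 0
        # The barrier action runs once, in one thread, after all replicas arrive.
        self._barrier = threading.Barrier(num_replicas, action=self._run_merge_fn)

    def _run_merge_fn(self):
        # "Cross-replica" step: merge_fn sees every replica's value at once.
        self._result = self._merge_fn(list(self._inputs))
        self.merge_fn_calls += 1

    def merge_call(self, replica_id, merge_fn, value):
        self._inputs[replica_id] = value
        self._merge_fn = merge_fn
        self._barrier.wait()  # pause until all replicas reach merge_call
        return self._result   # every replica resumes with the merged result


def toy_call_for_each_replica(num_replicas, fn, *args):
    """Run fn(ctx, replica_id, *args) once per replica, each in its own thread."""
    ctx = ToyReplicaContext(num_replicas)
    results = [None] * num_replicas

    def run(replica_id):
        results[replica_id] = fn(ctx, replica_id, *args)

    threads = [threading.Thread(target=run, args=(i,)) for i in range(num_replicas)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return ctx, results


def fn(ctx, replica_id, three):
    v = three + replica_id
    s = ctx.merge_call(replica_id, sum, v)  # sum of v over all replicas
    return s + v


ctx, results = toy_call_for_each_replica(2, fn, 3)
```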


colocate_vars_with(colocate_with_variable)

Scope that controls which devices variables will be created on.

No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others by using a tf.compat.v1.colocate_with() scope).

This may only be used inside self.scope().

Example usage:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  var1 = tf.Variable(1.0)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(3.0)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    return v1.assign_add(v2 + v3)

  # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))

Args:
colocate_with_variable A variable created in this strategy's scope(). Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable.

Returns:
A context manager.
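A toy model of the placement rule: variables created inside colocate_vars_with(var1) get var1's device set instead of the default placement. ToyStrategy and create_variable are hypothetical, and the round-robin default over parameter devices is an assumption chosen only so that colocation visibly changes the outcome; none of this is a TensorFlow API.

```python
import contextlib


class ToyStrategy:
    """Hypothetical strategy that records which devices each variable is placed on."""

    def __init__(self, parameter_devices):
        self.parameter_devices = list(parameter_devices)
        self.placements = {}          # variable name -> tuple of devices
        self._next_device = 0
        self._colocation_stack = []   # innermost colocate_vars_with target last

    def create_variable(self, name):
        if self._colocation_stack:
            devices = self._colocation_stack[-1]
        else:
            # Assumption: round-robin over the parameter devices by default.
            index = self._next_device % len(self.parameter_devices)
            self._next_device += 1
            devices = (self.parameter_devices[index],)
        self.placements[name] = devices
        return name

    @contextlib.contextmanager
    def colocate_vars_with(self, colocate_with_variable):
        self._colocation_stack.append(self.placements[colocate_with_variable])
        try:
            yield
        finally:
            self._colocation_stack.pop()


strategy = ToyStrategy(["/job:ps/task:0", "/job:ps/task:1"])
var1 = strategy.create_variable("var1")   # round-robin: task 0
with strategy.colocate_vars_with(var1):
    strategy.create_variable("var2")       # follows var1: task 0
    strategy.create_variable("var3")       # follows var1: task 0
strategy.create_variable("var4")           # round-robin resumes: task 1
```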


experimental_make_numpy_dataset(numpy_input, session=None)

Makes a dataset for input provided via a numpy array.

This avoids adding numpy_input as a large constant in the graph, and copies the data to the machine or machines that will be processing the input.

Args:
numpy_input A nest of NumPy input arrays that will be distributed evenly across all replicas. Note that lists of NumPy arrays are stacked, as that is normal tf.data.Dataset behavior.
session (TensorFlow v1.x graph execution only) A session used for initialization.

Returns:
A tf.data.Dataset representing numpy_input.
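The two points in the argument description, list stacking and even distribution, can be illustrated with NumPy alone. toy_split_numpy_input is hypothetical: stacking a list adds one new leading axis, as the text above says, while the equal split per replica is a simplification of "distributed evenly across all replicas", not TensorFlow's implementation.

```python
import numpy as np


def toy_split_numpy_input(numpy_input, num_replicas):
    """Hypothetical: stack a list of arrays, then give each replica an equal share."""
    if isinstance(numpy_input, list):
        numpy_input = np.stack(numpy_input)  # shape: (len(list),) + element shape
    if len(numpy_input) % num_replicas:
        raise ValueError("the leading dimension must divide evenly across replicas")
    return np.split(numpy_input, num_replicas)


rows = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6]), np.array([7, 8])]
shards = toy_split_numpy_input(rows, num_replicas=2)  # two arrays of shape (2, 2)
```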


experimental_run_steps_on_iterator

DEPRECATED: please use experimental_run_v2 instead.

Run fn with input from iterator for iterations times.

This method can be used to run a step function for training a number of times using input from a dataset.

Args:
fn function to run using this distribution strategy. The function must have the following signature: def fn(context, inputs). context is an instance of MultiStepContext that will be passed when fn is run. context can be used to specify the outputs to be returned from fn by calling context.set_last_step_output. It can also be used to capture non-tensor outputs with context.set_non_tensor_output. See the MultiStepContext documentation for more information. inputs will have the same type/structure as iterator.get_next(). Typically, fn will use the call_for_each_replica method of the strategy to distribute the computation over multiple replicas.
iterator Iterator of a dataset that represents the input for