Additional APIs for algorithms that need to be distribution-aware.
tf.distribute.StrategyExtended( container_strategy )
Some common use cases of functions on this page:
tf.distribute.DistributedValues can have the same locality as a
distributed variable, which leads to a mirrored value residing on the same
devices as the variable (as opposed to the compute devices). Such values may
be passed to a call to
tf.distribute.StrategyExtended.update to update the
value of a variable. You may use
tf.distribute.StrategyExtended.colocate_vars_with to give a variable the
same locality as another variable. You may convert a "PerReplica" value to a
variable's locality by using
tf.distribute.StrategyExtended.reduce_to or
tf.distribute.StrategyExtended.batch_reduce_to.
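As a minimal sketch of the locality point, the snippet below creates one variable under a strategy and a second variable colocated with it. The single-device `MirroredStrategy(["CPU:0"])` setup and the names `v1`, `v2`, and `same_device` are assumptions chosen so the example runs without GPUs; they are not part of the API description above.

```python
import tensorflow as tf

# Assumption: a single-CPU MirroredStrategy, so the sketch runs on any machine.
strategy = tf.distribute.MirroredStrategy(["CPU:0"])

with strategy.scope():
    v1 = tf.Variable(1.0)
    # Variables created inside this context get the same locality as v1.
    with strategy.extended.colocate_vars_with(v1):
        v2 = tf.Variable(2.0)

# Compare the primary placement of the two distributed variables.
same_device = v1.device == v2.device
print(same_device)
```

With more than one device, the same pattern should place every component of `v2` next to the matching component of `v1`.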
- How to update a distributed variable
Distributed variables are variables created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.
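The four steps above can be sketched for a single variable as follows. This is an illustration under stated assumptions, not the library's reference implementation: it uses a single-device `MirroredStrategy(["CPU:0"])` so it runs anywhere, a constant `1.0` stands in for a real gradient, and the names `merge_fn`, `step_fn`, and `result` are hypothetical.

```python
import tensorflow as tf

# Assumption: one CPU replica, so the summed update equals the per-replica update.
strategy = tf.distribute.MirroredStrategy(["CPU:0"])

with strategy.scope():
    v = tf.Variable(0.0)

def merge_fn(strategy, update, var):
    # Step 3: sum the per-replica updates onto the variable's devices.
    total = strategy.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, update, destinations=var)
    # Step 4: apply the summed update to each component of the variable.
    strategy.extended.update(
        var, lambda var, t: var.assign_add(t), args=(total,))

@tf.function
def step_fn(var):
    # Step 1: compute the update (a constant stands in for a gradient here).
    update = tf.constant(1.0)
    # Step 2: switch to cross-replica context.
    tf.distribute.get_replica_context().merge_call(
        merge_fn, args=(update, var))

strategy.run(step_fn, args=(v,))
result = float(v.read_value().numpy())
print(result)
```

With N replicas, each contributing `1.0`, the variable would be expected to grow by N per step instead.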
In fact, a higher-level way to update a distributed variable is to call
assign on the variable, as you would on a regular
tf.Variable.
You can call the method in both replica context and cross-replica context.
For a mirrored variable, calling
assign in replica context requires you
to specify the
aggregation type in the variable constructor. In that case,
the context switching and sync described in steps 2 through 4 are handled for
you. If you call
assign on a mirrored variable in cross-replica context,
you can only assign a single value or assign values from another mirrored
variable or a mirrored
tf.distribute.DistributedValues. For a SyncOnRead
variable, in replica context, you can simply call
assign on it and no
aggregation happens under the hood. In cross-replica context, you can only
assign a single value to a SyncOnRead variable. One example case is restoring
from a checkpoint: if the
aggregation type of the variable is
tf.VariableAggregation.SUM, it is assumed that replica values were added
before checkpointing, so at the time of restoring, the value is divided by
the number of replicas and then assigned to each replica; if the
aggregation type is
tf.VariableAggregation.MEAN, the value is assigned to each replica
directly.
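The assign-based approach described above might look like the following sketch. It assumes a single-device `MirroredStrategy(["CPU:0"])` and a mirrored variable created with `aggregation=tf.VariableAggregation.SUM`; the names `step_fn`, `after_run`, and `after_assign` are hypothetical.

```python
import tensorflow as tf

# Assumption: one CPU replica so the example runs without GPUs.
strategy = tf.distribute.MirroredStrategy(["CPU:0"])

with strategy.scope():
    # The aggregation type is required for assigning in replica context.
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

@tf.function
def step_fn():
    # Replica context: the context switch and sync are handled for you.
    v.assign_add(1.0)

strategy.run(step_fn)
after_run = float(v.read_value().numpy())

# Cross-replica context: assign a single value to all components.
v.assign(5.0)
after_assign = float(v.read_value().numpy())
print(after_run, after_assign)
```

Leaving out the `aggregation` argument would be expected to make the replica-context `assign_add` fail for a mirrored variable, which is why the constructor argument matters here.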
|parameter_devices|Returns the tuple of all devices used to place variables.|
|worker_devices|Returns the tuple of all devices used for compute replica execution.|
batch_reduce_to( reduce_op, value_destination_pairs, options=None )
Combine multiple reduce_to calls into one for faster execution.
Similar to reduce_to, but accepts a list of (value, destinations) pairs.
It's more efficient than reducing each value separately.
This API currently can only be called in cross-replica context. Other variants to reduce values across replicas are:
- tf.distribute.StrategyExtended.reduce_to: the non-batch version of this API.
- tf.distribute.ReplicaContext.all_reduce: the counterpart of this API in replica context. It supports both batched and non-batched all-reduce.
- tf.distribute.Strategy.reduce: a more convenient method to reduce to the host in cross-replica context.
See reduce_to for more information.
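To contrast two of the other variants listed above, here is a small sketch that calls `tf.distribute.ReplicaContext.all_reduce` inside the replica function and `tf.distribute.Strategy.reduce` on the result in cross-replica context. The single-device `MirroredStrategy(["CPU:0"])` and the names `step_fn`, `per_replica`, and `total` are assumptions made for illustration.

```python
import tensorflow as tf

# Assumption: one CPU replica, so every sum equals the single replica's value.
strategy = tf.distribute.MirroredStrategy(["CPU:0"])

@tf.function
def step_fn():
    ctx = tf.distribute.get_replica_context()
    # Replica context: every replica receives the sum over all replicas.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(3.0))

per_replica = strategy.run(step_fn)

# Cross-replica context: reduce the per-replica results to the host.
total = float(
    strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None).numpy())
print(total)
```

`batch_reduce_to` remains the option to reach for when several values need to be reduced to specific destinations in one call.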
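>>> @tf.function
... def step_fn(var):
...   def merge_fn(strategy, value, var):
...     # All-reduce the value. Note that `value` here is a `tf.distribute.DistributedValues`.
...     # batch_reduce_to returns a list; take the single reduced value.
...     reduced = strategy.extended.batch_reduce_to(
...         tf.distribute.ReduceOp.SUM, [(value, var)])[0]
...     strategy.extended.update(var, lambda var, value: var.assign(value), args=(reduced,))
...   value = tf.identity(1.)
...   tf.distribute.get_replica_context().merge_call(merge_fn, args=(value, var))
>>> def run(strategy):
...   with strategy.scope():
...     v = tf.Variable(0.)
...     strategy.run(step_fn, args=(v,))
...     return v
>>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
}
>>> run(tf.distribute.experimental.CentralStorageStrategy(compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
>>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>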
value_destination_pairs: a sequence of (value, destinations) pairs. See reduce_to for descriptions.