Additional APIs for algorithms that need to be distribution-aware.
tf.compat.v1.distribute.StrategyExtended( container_strategy )
- Wrapped values: In order to represent values parallel across devices (either replicas or the devices associated with a particular value), we wrap them in a "PerReplica" or "Mirrored" object that contains a map from replica id to values. "PerReplica" is used when the value may be different across replicas, and "Mirrored" when the values are the same.
- Unwrapping and merging: Consider calling a function fn on multiple replicas, like run(fn, args=[w]) with an argument w that is a wrapped value. This means w will have a map taking replica id 0 to w0, replica id 1 to w1, etc. run() unwraps w before calling fn, so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can possibly result in wrapped values. For example, let's say fn() returns a tuple with three components: (x, a, v0) from replica 0, (x, b, v1) on replica 1, etc. If the first component is the same object x from every replica, then the first component of the merged result will also be x. If the second component is different (a, b, ...) from each replica, then the merged value will have a wrapped map from replica device to the different values. If the third component is the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (v).
- Worker devices vs. parameter devices: Most replica computations will happen on worker devices. Since we don't yet support model parallelism, there will be one worker device per replica. When using parameter servers or central storage, the set of devices holding variables may be different; otherwise, the parameter devices might match the worker devices.
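The unwrapping and merging behavior described above can be illustrated with a small pure-Python model. This is only a sketch: ToyPerReplica and toy_run are hypothetical names invented for this illustration, they are not TensorFlow classes or functions, and the mirrored-variable case is left out.

```python
class ToyPerReplica:
    """Toy wrapper holding a map from replica id to value (not TensorFlow's class)."""

    def __init__(self, values):
        self.values = dict(values)


def toy_run(fn, args, num_replicas):
    """Call fn once per replica, then merge the returned tuples component-wise."""
    outputs = []
    for replica_id in range(num_replicas):
        # Unwrap: each replica sees only its own entry of a wrapped argument.
        local_args = [
            a.values[replica_id] if isinstance(a, ToyPerReplica) else a
            for a in args
        ]
        outputs.append(fn(*local_args))

    merged = []
    for parts in zip(*outputs):
        if all(p is parts[0] for p in parts):
            # The same object came back from every replica: keep it unwrapped.
            merged.append(parts[0])
        else:
            # Different values per replica: wrap them again.
            merged.append(ToyPerReplica(enumerate(parts)))
    return tuple(merged)


x = object()                           # shared by all replicas
w = ToyPerReplica({0: "w0", 1: "w1"})  # differs per replica


def fn(w_local):
    return (x, w_local.upper())


merged = toy_run(fn, args=[w], num_replicas=2)
```

Here the first merged component is the object x itself, while the second is a ToyPerReplica mapping replica 0 to "W0" and replica 1 to "W1".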
Replica context vs. Cross-replica context
A replica context applies when we are in some function that is being called
once for each replica. Otherwise we are in cross-replica context, which is
useful for calling tf.distribute.Strategy methods which operate across the
replicas (like reduce_to()). By default you start in a replica context
(the "default single replica context") and then some methods can switch you
back and forth. There is a third mode you can be in called update context
used when updating variables.
tf.distribute.Strategy.scope: enters cross-replica context when no other strategy is in scope.
tf.distribute.Strategy.run: calls a function in replica context.
tf.distribute.ReplicaContext.merge_call: transitions from replica context to cross-replica context.
tf.distribute.StrategyExtended.update: calls a function in an update context from a cross-replica context.
In a replica context, you may freely read the values of variables, but
you may only update their value if they specify a way to aggregate the
update using the
aggregation parameter in the variable's constructor.
In a cross-replica context, you may read or write variables (writes may
need to be broadcast to all copies of the variable if it is mirrored).
Sync on read variables
In some cases, such as a metric, we want to accumulate a bunch of updates on
each replica independently and only aggregate when reading. This can be a big
performance win when the value is read only rarely (maybe the value is only
read at the end of an epoch or when checkpointing). These are variables
created by passing synchronization=ON_READ to the variable's constructor
(and some value for aggregation).
The strategy may choose to put the variable on multiple devices, like mirrored
variables, but unlike mirrored variables we don't synchronize the updates to
them to make sure they have the same value. Instead, the synchronization is
performed when reading in cross-replica context. In a replica context, reads
and writes are performed on the local copy (we allow reads so you can write
v = 0.9*v + 0.1*update). We don't allow operations like
v.assign_add in a cross-replica context for sync on read variables; right
now we don't have a use case for such updates and depending on the aggregation
mode such updates may not be sensible.
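A rough way to picture a sync-on-read variable is as one local copy per replica plus an aggregation that runs only when the value is read across replicas. The sketch below is a pure-Python model under that assumption; ToySyncOnReadVariable and its method names are hypothetical and do not correspond to TensorFlow's implementation.

```python
class ToySyncOnReadVariable:
    """Toy model: one local copy per replica, aggregated only on cross-replica read."""

    def __init__(self, initial, num_replicas, aggregate=sum):
        self._copies = [initial] * num_replicas
        self._aggregate = aggregate  # stands in for the variable's aggregation mode

    def assign_add_on_replica(self, replica_id, delta):
        # Replica context: only the local copy changes, with no synchronization.
        self._copies[replica_id] += delta

    def read_on_replica(self, replica_id):
        # Replica context: reads also see only the local copy.
        return self._copies[replica_id]

    def read_cross_replica(self):
        # Cross-replica context: synchronization happens here, at read time.
        return self._aggregate(self._copies)


metric = ToySyncOnReadVariable(0.0, num_replicas=2)
metric.assign_add_on_replica(0, 1.5)
metric.assign_add_on_replica(0, 0.5)
metric.assign_add_on_replica(1, 3.0)

local_0 = metric.read_on_replica(0)
total = metric.read_cross_replica()
```

With a sum aggregation, replica 0 sees 2.0 locally and the cross-replica read yields 5.0. The three updates cost no communication; only the final read does.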
Depending on how a value is produced, it will have a type that will determine how it may be used.
"Per-replica" values exist on the worker devices, with a different value for
each replica. They are produced by iterating through a distributed dataset, and
are also the typical result returned by
tf.distribute.Strategy.run. You typically can't use a
per-replica value directly in a cross-replica context without first resolving
how to aggregate the values across replicas, for instance with a reduction.
"Mirrored" values are like per-replica values, except we know that the values
on all replicas are the same. We can safely read a mirrored value in a
cross-replica context by using the value on any replica. You can convert
a per-replica value into a mirrored value by using
tf.distribute.ReplicaContext.all_reduce.
Values can also have the same locality as a variable, which is a mirrored
value but residing on the same devices as the variable (as opposed to the
compute devices). Such values may be passed to a call to
tf.distribute.StrategyExtended.update to update the value of a variable.
You may use
tf.distribute.StrategyExtended.colocate_vars_with to give a
variable the same locality as another variable. This is useful, for example,
for "slot" variables used by an optimizer for keeping track of statistics
used to update a primary/model variable. You may convert a per-replica
value to a variable's locality by using
tf.distribute.StrategyExtended.reduce_to or
tf.distribute.StrategyExtended.batch_reduce_to.
In addition to slot variables which should be colocated with their primary
variables, optimizers also define non-slot variables. These can be things like
"number of step updates performed" or "beta1^t" and "beta2^t". Each strategy
has some policy for which devices those variables should be copied to, called
the "non-slot devices" (some subset of the parameter devices). We require that
all non-slot variables are allocated on the same device, or mirrored across
the same set of devices. You can use
tf.distribute.StrategyExtended.non_slot_devices to pick a consistent set of
devices for both placing and updating these variables.
How to update a variable
The standard pattern for updating variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be the gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates and broadcast the result to the variable's devices.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.
Steps 2 through 4 are done automatically by class
tf.keras.optimizers.Optimizer if you call its
tf.keras.optimizers.Optimizer.apply_gradients method in a replica context.
They are also done automatically if you call an
assign* method on a (non
sync-on-read) variable that was constructed with an aggregation method (which
is used to determine the reduction used in step 3).
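The data flow of the four steps can be traced with a small pure-Python model. This is a sketch only: ToyMirroredVariable, toy_reduce_to_sum, and toy_update are hypothetical stand-ins written for this illustration and are not the TensorFlow methods named above. The "gradient" is an arbitrary stand-in value.

```python
class ToyMirroredVariable:
    """Toy mirrored variable: one copy of the value per device."""

    def __init__(self, value, devices):
        self.copies = {d: value for d in devices}


def replica_fn(x):
    # Step 1 (replica context): compute this replica's update.
    return 2.0 * x  # stand-in for a gradient


def toy_reduce_to_sum(per_replica_updates, variable):
    # Step 3: sum across replicas, then broadcast to the variable's devices.
    total = sum(per_replica_updates)
    return {d: total for d in variable.copies}


def toy_update(variable, fn, reduced):
    # Step 4: apply the update function to every copy of the variable.
    for d in variable.copies:
        variable.copies[d] = fn(variable.copies[d], reduced[d])


v = ToyMirroredVariable(1.0, devices=["d0", "d1"])

# Step 1: one update per replica (two replicas with inputs 1.0 and 2.0).
per_replica_updates = [replica_fn(x) for x in (1.0, 2.0)]

# Step 2: in this single-threaded model, gathering the per-replica list
# plays the role of switching to cross-replica mode.
reduced = toy_reduce_to_sum(per_replica_updates, v)

# Step 4: a plain gradient-descent step with a learning rate of 0.5.
toy_update(v, lambda value, g: value - 0.5 * g, reduced)
```

The per-replica updates 2.0 and 4.0 sum to 6.0 on both devices, so every copy of the variable ends at 1.0 - 0.5 * 6.0 = -2.0 and the copies stay identical.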
Layers are generally called in a replica context, except when defining a
functional model. tf.distribute.in_cross_replica_context will let you
determine which case you are in. If in a replica context, the
tf.distribute.get_replica_context function will return a
tf.distribute.ReplicaContext object. The
ReplicaContext object has an
all_reduce method for aggregating across all replicas. Alternatively, you
can update variables following steps 2-4 above.
Whether the strategy uses between-graph replication or not.
This is expected to return a constant value that will not be changed throughout its life cycle.
||Whether initialization is needed.|
||Returns the tuple of all devices used to place variables.|
||Whether checkpointing is needed.|
||Whether saving summaries is needed.|
||Returns the tuple of all devices used for compute replica execution.|
batch_reduce_to( reduce_op, value_destination_pairs, experimental_hints=None )
Combine multiple reduce_to calls into one for faster execution.
Reduction type, an instance of
A sequence of (value, destinations) pairs. See reduce_to for a description.
A list of mirrored values, one per pair in value_destination_pairs.
broadcast_to( tensor, destinations )
Mirror a tensor on one device to all worker devices.
||A Tensor value to broadcast.|
A mirrored variable or device string specifying the
destination devices to copy tensor to.
A value mirrored to the destinations devices.
call_for_each_replica( fn, args=(), kwargs=None )
Run fn once per replica.
fn may call tf.distribute.get_replica_context() to access members such as
replica_id_in_sync_group and merge_call().
merge_call() is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
on encountering a merge_call() call. After that, the
merge_fn function is executed. Its results are then unwrapped and
given back to each replica call. After that, execution resumes until
fn is complete or encounters another merge_call().
    import tensorflow as tf

    # `distribution` is assumed to be an existing strategy object.

    # Called once in "cross-replica" context.
    def merge_fn(distribution, three_plus_replica_id):
      # Sum the values across replicas.
      return sum(distribution.experimental_local_results(three_plus_replica_id))

    # Called once per replica in `distribution`, in a "replica" context.
    def fn(three):
      replica_ctx = tf.distribute.get_replica_context()
      v = three + replica_ctx.replica_id_in_sync_group
      # Computes the sum of the `v` values across all replicas.
      s = replica_ctx.merge_call(merge_fn, args=(v,))
      return s + v

    with distribution.scope():
      # In "cross-replica" context.
      ...
      merged_results = distribution.run(fn, args=(3,))
      # merged_results has the values from every replica execution of `fn`.
      # This statement prints a list:
      print(distribution.experimental_local_results(merged_results))
||function to run (will be run once per replica).|
Tuple or list with positional arguments for fn.
Dict with keyword arguments for fn.
Merged return value of fn across all replicas.
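The pause-and-resume behavior of merge_call() can be modeled in plain Python with one thread per replica and a barrier. The sketch below is an assumption-laden toy, not how TensorFlow implements it: ToySync, ToyReplicaContext, and toy_call_for_each_replica are hypothetical names, and unlike the real API, the toy merge_fn receives only the per-replica values (no strategy argument) and the replica context is passed to fn explicitly.

```python
import threading


class ToySync:
    """Shared state for a group of replica threads (hypothetical helper)."""

    def __init__(self, num_replicas):
        self.args = [None] * num_replicas
        self.merge_fn = None
        self.result = None
        # The barrier action runs exactly once, after every replica has arrived.
        self.barrier = threading.Barrier(num_replicas, action=self._merge)

    def _merge(self):
        # Regroup so merge_fn gets one tuple of per-replica values per argument.
        self.result = self.merge_fn(*zip(*self.args))


class ToyReplicaContext:
    def __init__(self, replica_id, sync):
        self.replica_id_in_sync_group = replica_id
        self._sync = sync

    def merge_call(self, merge_fn, args=()):
        self._sync.merge_fn = merge_fn
        self._sync.args[self.replica_id_in_sync_group] = args
        self._sync.barrier.wait()  # pause until all replicas reach merge_call
        return self._sync.result   # the merged value is handed to each replica


def toy_call_for_each_replica(fn, num_replicas, args=()):
    sync = ToySync(num_replicas)
    results = [None] * num_replicas

    def worker(replica_id):
        results[replica_id] = fn(ToyReplicaContext(replica_id, sync), *args)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_replicas)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def merge_fn(values):
    # Runs once, in the toy "cross-replica" context.
    return sum(values)


def fn(replica_ctx, three):
    v = three + replica_ctx.replica_id_in_sync_group
    s = replica_ctx.merge_call(merge_fn, args=(v,))
    return s + v


toy_results = toy_call_for_each_replica(fn, num_replicas=2, args=(3,))
```

With two replicas, v is 3 and 4, the merged sum is 7, and the toy returns [10, 11], which mirrors the arithmetic of the example above for a two-replica strategy.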