Additional APIs for algorithms that need to be distribution-aware.
tf.compat.v1.distribute.StrategyExtended(container_strategy)
Some common use cases of functions on this page:
- Locality

tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.
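As a minimal sketch of colocation, the following creates one variable with the same locality as another. It assumes a tf.distribute.MirroredStrategy only as a concrete example; any strategy exposes the same `extended` object.

```python
import tensorflow as tf

# Hedged sketch: MirroredStrategy is just one concrete Strategy;
# on a machine with no GPU this creates a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    v1 = tf.Variable(1.0, name="v1")
    # Give v2 the same locality (device placement) as v1.
    with strategy.extended.colocate_vars_with(v1):
        v2 = tf.Variable(2.0, name="v2")
```

Here v2 is created on the same devices as v1, so a value with v2's locality can later be passed to strategy.extended.update for v2.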
- How to update a distributed variable
A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.
Steps 2 through 4 are done automatically by class tf.keras.optimizers.Optimizer if you call its apply_gradients method in a replica context.