A state & compute distribution policy on a list of devices.
tf.distribute.Strategy(
    extended
)
- To use it with Keras fit, please read the guide.
- You may pass a descendant of tf.estimator.RunConfig to specify how a tf.estimator.Estimator should distribute its computation. See the guide.
- Otherwise, use tf.distribute.Strategy.scope to specify that a strategy should be used when building and executing your model. (This puts you in the "cross-replica context" for this strategy, which means the strategy is put in control of things like variable placement.)
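As a minimal sketch of the scope mechanism, here is a variable created under a strategy's scope. The default (single-replica) strategy is used so the snippet runs on one CPU; tf.distribute.MirroredStrategy and friends are used the same way:

```python
import tensorflow as tf

# The default strategy is a no-op placement policy, convenient for
# illustrating the API on a single device.
strategy = tf.distribute.get_strategy()

with strategy.scope():
    # Inside the scope, the strategy controls where this variable
    # is created (mirrored, sharded, etc., depending on the strategy).
    v = tf.Variable(1.0)

print(v.numpy())  # 1.0
```

Swapping in a multi-device strategy changes only the `strategy` construction line; the model-building code under the scope stays the same.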
If you are writing a custom training loop, you will need to call a few more methods, see the guide:
- Start by creating a tf.distribute.Strategy.experimental_distribute_dataset to convert a tf.data.Dataset to something that produces "per-replica" values. If you want to manually specify how the dataset should be partitioned across replicas, use tf.distribute.Strategy.distribute_datasets_from_function instead.
- Use tf.distribute.Strategy.run to run a function once per replica, taking values that may be "per-replica" (e.g. from a tf.distribute.DistributedDataset object) and returning "per-replica" values. This function is executed in "replica context", which means each operation is performed separately on each replica.
- Finally, use a method (such as tf.distribute.Strategy.reduce) to convert the resulting "per-replica" values into ordinary Tensors.
A custom training loop can be as simple as:

```python
with my_strategy.scope():
  @tf.function
  def distribute_train_epoch(dataset):
    def replica_fn(input):
      # process input and return result
      return result

    total_result = 0
    for x in dataset:
      per_replica_result = my_strategy.run(replica_fn, args=(x,))
      total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                         per_replica_result, axis=None)
    return total_result

  dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
  for _ in range(EPOCHS):
    train_result = distribute_train_epoch(dist_dataset)
```
This takes an ordinary dataset and replica_fn and runs it distributed using a particular tf.distribute.Strategy (my_strategy above). Any variables created in replica_fn are created using my_strategy's policy, and library functions called by replica_fn can use the get_replica_context() API to implement support for being run inside a strategy.
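To make the replica context concrete, here is a small sketch (again using the default single-replica strategy so it runs on one CPU): inside a function passed to run, tf.distribute.get_replica_context() gives the function access to replica-level information.

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()

def replica_fn(x):
    # We are in "replica context" here: this body runs once per replica.
    ctx = tf.distribute.get_replica_context()
    # For example, a library function can scale a value by the number
    # of replicas that aggregate gradients together.
    return x * tf.cast(ctx.num_replicas_in_sync, tf.float32)

result = strategy.run(replica_fn, args=(tf.constant(2.0),))
print(float(result))  # 2.0 with the single-replica default strategy
```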
You can use the reduce API to aggregate results across replicas and use this as a return value from one iteration over a tf.distribute.DistributedDataset. Or you can use tf.keras.metrics (such as loss, accuracy, etc.) to accumulate metrics across steps in a given epoch.
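The metrics approach can be sketched as follows (default single-replica strategy for illustration). The key point is that the metric variables are created under the strategy's scope, while update_state is called from replica context inside run:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()

with strategy.scope():
    # Metric variables created under the scope follow the strategy's
    # variable policy, so their state aggregates correctly across replicas.
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

def replica_fn(labels, logits):
    # Accumulate per-step results into the metric from replica context.
    accuracy.update_state(labels, logits)

labels = tf.constant([0, 1])
logits = tf.constant([[0.9, 0.1], [0.2, 0.8]])
strategy.run(replica_fn, args=(labels, logits))
print(float(accuracy.result()))  # 1.0: both predictions match the labels
```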
See the custom training loop tutorial for a more detailed example.
cluster_resolver

Returns the cluster resolver associated with this strategy.

In general, when using a multi-worker strategy, there is a tf.distribute.cluster_resolver.ClusterResolver associated with the strategy, and such an instance is returned by this property.

Strategies that intend to have an associated cluster resolver must set it; otherwise, this property returns None by default.

Single-worker strategies usually do not have a cluster resolver, and in those cases this property returns None.

For more information, please see the tf.distribute.cluster_resolver.ClusterResolver API docstring.
num_replicas_in_sync

Returns number of replicas over which gradients are aggregated.
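A common use of this attribute is deriving the per-replica batch size from a global batch size, as in this sketch (the default strategy has one replica, so the two are equal here):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()

# The global batch is split evenly across the replicas that
# aggregate gradients together.
global_batch_size = 64
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
print(per_replica_batch_size)  # 64 with the single-replica default strategy
```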
distribute_datasets_from_function(
    dataset_fn, options=None
)

Distributes tf.data.Dataset instances created by calls to dataset_fn.

The argument dataset_fn that users pass in is an input function that has a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. It is expected that the dataset returned from dataset_fn is already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded; distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function.

dataset_fn will be called on the CPU device of each of the workers, and each call generates a dataset where every replica on that worker will dequeue one batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued from the dataset every step).

This method can be used for several purposes.
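The contract above can be sketched as follows (default single-replica strategy, so it runs on one CPU): the input function batches by the per-replica batch size via the tf.distribute.InputContext it receives, and shards the dataset itself.

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()
GLOBAL_BATCH_SIZE = 8

def dataset_fn(input_context):
    # The input function is responsible for batching by the
    # *per-replica* batch size and for sharding across workers itself.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(32)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
for batch in dist_dataset:
    print(batch.shape)  # (8,): one per-replica batch per step
    break
```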