A state & compute distribution policy on a list of devices.
```
tf.distribute.Strategy(
    extended
)
```
- To use it with Keras `fit`, please read the guide.
- You may pass a descendant of `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` should distribute its computation. See the guide.
- Otherwise, use `tf.distribute.Strategy.scope` to specify that a strategy should be used when building and executing your model. (This puts you in the "cross-replica context" for this strategy, which means the strategy is put in control of things like variable placement.)
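As a sketch of entering the scope (using `tf.distribute.MirroredStrategy` as one concrete strategy; any `tf.distribute.Strategy` subclass is entered the same way):

```python
import tensorflow as tf

# MirroredStrategy is used here only as a concrete example; the scope
# mechanics are the same for any tf.distribute.Strategy subclass.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Variables created inside the scope (including those in Keras layers)
    # are placed according to the strategy's policy.
    v = tf.Variable(1.0)
```
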
If you are writing a custom training loop, you will need to call a few more methods, see the guide:
- Start by creating a `tf.distribute.Strategy.experimental_distribute_dataset` to convert a `tf.data.Dataset` to something that produces "per-replica" values. If you want to manually specify how the dataset should be partitioned across replicas, use `tf.distribute.Strategy.distribute_datasets_from_function` instead.
- Use `tf.distribute.Strategy.run` to run a function once per replica, taking values that may be "per-replica" (e.g. from a `tf.distribute.DistributedDataset` object) and returning "per-replica" values. This function is executed in "replica context", which means each operation is performed separately on each replica.
- Finally, use a method (such as `tf.distribute.Strategy.reduce`) to convert the resulting "per-replica" values into ordinary `Tensor`s.
A custom training loop can be as simple as:
```python
with my_strategy.scope():
  @tf.function
  def distribute_train_epoch(dataset):
    def replica_fn(input):
      # process input and return result
      return result

    total_result = 0
    for x in dataset:
      per_replica_result = my_strategy.run(replica_fn, args=(x,))
      total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                         per_replica_result, axis=None)
    return total_result

  dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
  for _ in range(EPOCHS):
    train_result = distribute_train_epoch(dist_dataset)
```
This takes an ordinary `dataset` and `replica_fn` and runs it distributed using a particular `tf.distribute.Strategy` named `my_strategy` above. Any variables created in `replica_fn` are created using `my_strategy`'s policy, and library functions called by `replica_fn` can use the `get_replica_context()` API to implement support for replicated training.
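A minimal sketch of `get_replica_context()` inside a replica function (assuming a `MirroredStrategy`; the function and values here are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
    # Inside strategy.run we are in "replica context"; the context object
    # exposes per-replica info and cross-replica collectives.
    ctx = tf.distribute.get_replica_context()
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)

result = strategy.run(replica_fn, args=(tf.constant(2.0),))
```

With a single replica the all-reduce is a no-op; with N replicas each one receives the sum across all N.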
You can use the `reduce` API to aggregate results across replicas and use this as a return value from one iteration over a `tf.distribute.DistributedDataset`. Or you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to accumulate metrics across steps in a given epoch.
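For example, a `tf.keras.metrics.Mean` created under the strategy's scope can be updated from the replica function on each step (a sketch; the dataset and "loss" here are illustrative, and `reset_state` assumes a recent TF version):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Metric variables are created under the scope so they follow the
    # strategy's variable policy.
    loss_metric = tf.keras.metrics.Mean(name="train_loss")

def replica_fn(x):
    loss = tf.reduce_mean(x)        # stand-in for a real loss computation
    loss_metric.update_state(loss)  # accumulates across steps and replicas

dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 2])).batch(4)
for batch in strategy.experimental_distribute_dataset(dataset):
    strategy.run(replica_fn, args=(batch,))

epoch_loss = loss_metric.result()  # average over the whole epoch
loss_metric.reset_state()          # clear before the next epoch
```
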
See the custom training loop tutorial for a more detailed example.
`cluster_resolver`: Returns the cluster resolver associated with this strategy.

In general, when using a multi-worker strategy, there is a `tf.distribute.cluster_resolver.ClusterResolver` associated with it, and such an instance is returned by this property.

Strategies that intend to have an associated `tf.distribute.cluster_resolver.ClusterResolver` must set the relevant attribute, or override this property; otherwise, `None` is returned by default.

Single-worker strategies usually do not have a `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this property will return `None`.

For more information, please see the `tf.distribute.cluster_resolver.ClusterResolver` API docstring.
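A quick sketch of the single-worker case (assuming a default `MirroredStrategy`):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Single-worker strategies typically have no cluster to resolve,
# so this property returns None.
resolver = strategy.cluster_resolver
```
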
`num_replicas_in_sync`: Returns the number of replicas over which gradients are aggregated.
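One common use is deriving the per-replica batch size from a global batch size (the batch size here is illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

GLOBAL_BATCH_SIZE = 64  # hypothetical value
# Gradients are aggregated over num_replicas_in_sync replicas, so each
# replica should see global_batch_size / num_replicas_in_sync examples.
per_replica_batch_size = GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync
```
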
```
distribute_datasets_from_function(
    dataset_fn, options=None
)
```
Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

The argument `dataset_fn` that users pass in is an input function that has a `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` instance. It is expected that the returned dataset from `dataset_fn` is already batched by per-replica batch size (i.e. global batch size divided by the number of replicas in sync) and sharded. `tf.distribute.Strategy.distribute_datasets_from_function` does not batch or shard the `tf.data.Dataset` instance returned from the input function. `dataset_fn` will be called on the CPU device of each of the workers and each generates a dataset where every replica on that worker will dequeue one batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued from the `Dataset` every step).
This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, `tf.distribute.Strategy.experimental_distribute_dataset` does batching and sharding for you.) For example, where `experimental_distribute_dataset` is unable to shard the input files, this method might be used to manually shard the dataset (avoiding the slow fallback behavior in `experimental_distribute_dataset`). In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.
The `dataset_fn` should take a `tf.distribute.InputContext` instance where information about batching and input replication can be accessed.
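A sketch of an input function that batches and shards via the `tf.distribute.InputContext` (the batch size and data are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8  # hypothetical value

def dataset_fn(input_context):
    # Batch by the per-replica batch size: this API does not batch for you.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    d = tf.data.Dataset.from_tensor_slices(tf.range(32)).batch(batch_size)
    # Shard manually using the pipeline info in the InputContext.
    return d.shard(input_context.num_input_pipelines,
                   input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```
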
You can use the `element_spec` property of the `tf.distribute.DistributedDataset` returned by this API to query the `tf.TypeSpec` of the elements returned by the iterator. This can be used to set the `input_signature` property of a `tf.function`.
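For instance (a single-worker sketch; the dataset and step function are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
    # Already batched by per-replica batch size, as this API expects.
    return tf.data.Dataset.from_tensor_slices(tf.ones([8, 2])).batch(2)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

# element_spec describes the iterator's elements; using it as the
# input_signature avoids retracing the tf.function per input shape.
@tf.function(input_signature=[dist_dataset.element_spec])
def step(batch):
    per_replica = strategy.run(tf.reduce_sum, args=(batch,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

total = 0.0
for batch in dist_dataset:
    total += step(batch)
```
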