Synchronous training on TPUs and TPU Pods.
Inherits From: Strategy
tf.distribute.TPUStrategy(
    tpu_cluster_resolver=None, experimental_device_assignment=None
)
To construct a TPUStrategy object, you need to run the initialization code as below:
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
While using distribution strategies, the variables created within the strategy's scope will be replicated across all the replicas and can be kept in sync using all-reduce algorithms.
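As a rough illustration of what keeping replicated variables "in sync using all-reduce" means, the plain-Python sketch below simulates two replicas that each compute a local gradient, sum those gradients across replicas, and apply the same update to their own copy of a variable. The helper names (`all_reduce_sum`, `sgd_step`) and the numbers are hypothetical; no TensorFlow API is used, and the sketch makes no claim about the algorithm TPUs actually run.

```python
def all_reduce_sum(per_replica_values):
    """Return the cross-replica sum, delivered to every replica."""
    total = sum(per_replica_values)
    return [total for _ in per_replica_values]


def sgd_step(variables, gradients, learning_rate):
    """Apply one SGD update to each replica's copy of the variable."""
    return [v - learning_rate * g for v, g in zip(variables, gradients)]


# Each replica starts with an identical copy of the variable.
replica_variables = [1.0, 1.0]
# Each replica computes a different gradient from its own slice of the batch.
local_gradients = [0.5, 1.5]

# After the all-reduce, every replica holds the same summed gradient.
synced_gradients = all_reduce_sum(local_gradients)
replica_variables = sgd_step(replica_variables, synced_gradients, learning_rate=0.25)
print(replica_variables)  # [0.5, 0.5]
```

Because every replica applies the identical reduced gradient, the copies stay equal after the update.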
To run TF2 programs on TPUs, you can either use the .compile and .fit APIs in tf.keras with TPUStrategy, or write your own customized training loop by calling strategy.run directly. Note that TPUStrategy doesn't support pure eager execution, so make sure that the function passed into strategy.run is a tf.function, or that strategy.run is called inside a tf.function, if eager behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
The distribute_datasets_from_function and experimental_distribute_dataset APIs can be used to distribute the dataset across the TPU workers when writing your own training loop. If you are using the fit and compile methods available in tf.keras.Model, then Keras will handle the distribution for you.
An example of writing a customized training loop on TPUs:
import numpy as np

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)

dist_dataset = strategy.distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)

@tf.function()
def train_step(iterator):

    def step_fn(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
            logits = model(features, training=True)
            # The Dense layer has no activation, so the model outputs logits.
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

train_step(iterator)
For advanced use cases like model parallelism, you can set the experimental_device_assignment argument when creating a TPUStrategy to specify the number of replicas and the number of logical devices. Below is an example that initializes the TPU system with 2 logical devices and 1 replica.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
Then you can run a tf.add operation only on logical device 0.
@tf.function()
def step_fn(inputs):
    features, _ = inputs
    output = tf.add(features, features)
    # Add operation will be executed on logical device 0.
    output = strategy.experimental_assign_to_logical_device(output, 0)
    return output

dist_dataset = strategy.distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)
strategy.run(step_fn, args=(next(iterator),))
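The device assignment in this example uses computation_shape=[1, 1, 1, 2] to obtain 2 logical devices per replica. The arithmetic can be sketched in plain Python, assuming that the number of logical devices per replica is the product of the entries in computation_shape and that each logical device occupies one TPU core. The helper name `cores_required` is hypothetical and not part of TensorFlow.

```python
import math


def cores_required(computation_shape, num_replicas):
    """Return (logical devices per replica, total cores used).

    Assumes one core per logical device, with the per-replica count given by
    the product of the computation_shape entries.
    """
    logical_devices_per_replica = math.prod(computation_shape)
    return logical_devices_per_replica, logical_devices_per_replica * num_replicas


# 2 logical devices and 1 replica.
print(cores_required([1, 1, 1, 2], num_replicas=1))  # (2, 2)
# The same shape with 4 replicas would need 8 cores under these assumptions.
print(cores_required([1, 1, 1, 2], num_replicas=4))  # (2, 8)
```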
Args | |
---|---|
tpu_cluster_resolver | A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster. If None, it will assume running on a local TPU worker. |
experimental_device_assignment | Optional tf.tpu.experimental.DeviceAssignment to specify the placement of replicas on the TPU cluster. |
Attributes | |
---|---|
cluster_resolver | Returns the cluster resolver associated with this strategy. Multi-worker strategies generally have an associated cluster resolver; single-worker strategies usually do not. |
extended | tf.distribute.StrategyExtended with additional methods. |
num_replicas_in_sync | Returns number of replicas over which gradients are aggregated. |
Methods
distribute_datasets_from_function
distribute_datasets_from_function(
    dataset_fn, options=None
)
Distributes tf.data.Dataset instances created by calls to dataset_fn.
The argument dataset_fn that users pass in is an input function that takes a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. The dataset returned from dataset_fn is expected to be already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded. tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function. dataset_fn will be called on the CPU device of each of the workers, and each call generates a dataset from which every replica on that worker will dequeue one batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued from the Dataset every step).
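The batching arithmetic described above can be sketched in plain Python. The helpers `per_replica_batch_size` and `dequeue_step` are hypothetical stand-ins written for this illustration; in TensorFlow the per-replica size is available through tf.distribute.InputContext.get_per_replica_batch_size. The worker and replica counts are assumed values.

```python
from itertools import islice


def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Global batch size divided by the number of replicas in sync."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError("global batch size must be divisible by the replica count")
    return global_batch_size // num_replicas_in_sync


def dequeue_step(batch_iterator, replicas_on_worker):
    """Pull one batch per replica on this worker for a single step."""
    return list(islice(batch_iterator, replicas_on_worker))


global_batch_size = 8
num_replicas_in_sync = 4  # assumed: 2 workers with 2 replicas each
batch_size = per_replica_batch_size(global_batch_size, num_replicas_in_sync)

# The per-worker dataset is already batched by the per-replica batch size.
worker_elements = list(range(8))
worker_batches = iter(
    [worker_elements[i:i + batch_size]
     for i in range(0, len(worker_elements), batch_size)]
)

# A worker with two replicas dequeues two batches in one step.
step_inputs = dequeue_step(worker_batches, replicas_on_worker=2)
print(batch_size)   # 2
print(step_inputs)  # [[0, 1], [2, 3]]
```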
This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method might be used to manually shard the dataset (avoiding the slow fallback behavior in experimental_distribute_dataset). In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.
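Manual sharding of input files can be pictured with the following plain-Python sketch. It mimics an index-based shard (every num_shards-th element starting at a given index), which is how tf.data.Dataset.shard(num_shards, index) selects elements; the file names are made up, and `shard` here is a hypothetical list helper rather than the TensorFlow method. In a real dataset_fn the two numbers would typically come from the InputContext attributes num_input_pipelines and input_pipeline_id.

```python
def shard(elements, num_shards, index):
    """Keep every num_shards-th element, starting at position index."""
    return elements[index::num_shards]


# Hypothetical input files, for illustration only.
file_names = [f"data-{i}.tfrecord" for i in range(6)]
num_input_pipelines = 2

shards = [
    shard(file_names, num_input_pipelines, pipeline_id)
    for pipeline_id in range(num_input_pipelines)
]
for pipeline_id, files in enumerate(shards):
    print(pipeline_id, files)
# 0 ['data-0.tfrecord', 'data-2.tfrecord', 'data-4.tfrecord']
# 1 ['data-1.tfrecord', 'data-3.tfrecord', 'data-5.tfrecord']
```

Each input pipeline ends up with a disjoint subset, and together the subsets cover every file.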
The dataset_fn should take a tf.distribute.InputContext instance, from which information about batching and input replication can be accessed.
You can use the element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature property of a tf.function. See tf.distribute.DistributedDataset.element_spec for an example.
For more on the usage and properties of this method, refer to the tutorial on distributed input. If you are interested in last partial batch handling, read this section.
Args | |
---|---|
dataset_fn | A function taking a tf.distribute.InputContext instance and returning a tf.data.Dataset. |
options | tf.distribute.InputOptions used to control options on how this dataset is distributed. |
Returns | |
---|---|
A tf.distribute.DistributedDataset. |
experimental_assign_to_logical_device
experimental_assign_to_logical_device(
    tensor, logical_device_id
)
Adds an annotation that tensor will be assigned to a logical device.
This adds an annotation to tensor specifying that operations on tensor will be invoked on the logical core device with id logical_device_id. When model parallelism is used, the default behavior is that all ops are placed on the zero-th logical device.
# Initializing TPU system with 2 logical devices and 4 replicas.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=4)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
# `inputs` is not defined in this snippet; it is assumed to be a
# distributed dataset created elsewhere.
iterator = iter(inputs)

@tf.function()
def step_fn(inputs):
    output = tf.add(inputs, inputs)
    # Add operation will be executed on logical device 0.
    output = strategy.experimental_assign_to_logical_device(output, 0)
    return output

strategy.run(step_fn, args=(next(iterator),))
Args | |
---|---|
tensor | Input tensor to annotate. |
logical_device_id | Id of the logical core to which the tensor will be assigned. |
Raises | |
---|---|
ValueError | The logical device id presented is not consistent with total number of partitions specified by the device assignment. |
Returns | |
---|---|
Annotated tensor with identical value as tensor. |
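The ValueError condition listed above can be pictured with a small plain-Python check. This is only a sketch: it assumes that valid logical device ids run from 0 up to one less than the number of logical devices per replica, and the helper name `check_logical_device_id` is hypothetical. The exact validation performed inside TensorFlow is not reproduced here.

```python
def check_logical_device_id(logical_device_id, num_logical_devices):
    """Reject ids outside the assumed valid range [0, num_logical_devices)."""
    if not 0 <= logical_device_id < num_logical_devices:
        raise ValueError(
            f"logical_device_id {logical_device_id} is not consistent with "
            f"{num_logical_devices} logical devices"
        )
    return logical_device_id


# With 2 logical devices, ids 0 and 1 are accepted under this assumption.
print(check_logical_device_id(0, 2))  # 0
print(check_logical_device_id(1, 2))  # 1

# Id 2 is out of range and is rejected.
try:
    check_logical_device_id(2, 2)
except ValueError as error:
    print("rejected:", error)
```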
experimental_distribute_dataset
experimental_distribute_dataset(
    dataset, options=None
)
Creates tf.distribute.DistributedDataset from tf.data.Dataset.
The returned tf.distribute.DistributedDataset can be iterated over similarly to regular datasets.
NOTE: The user cannot add any more transformations to a tf.distribute.DistributedDataset. You can only create an iterator or examine the tf.TypeSpec of the data generated by it. See the API docs of tf.distribute.DistributedDataset to learn more.
The following is an example:
global_batch_size = 2
# Passing the devices is optional.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
# Create a dataset
dataset = tf.data.Dataset.range(4).batch(global_batch_size)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def replica_fn(input):
    return input*2

result = []
# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
    # process dataset elements
    result.append(strategy.run(replica_fn, args=(x,)))
print(result)

[PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
}]
Three key actions happening under the hood of this method are batching, sharding, and prefetching.
In the code snippet above, dataset is batched by global_batch_size, and calling experimental_distribute_dataset on it rebatches dataset to a new batch size that is equal to the global batch size divided by the number of replicas in sync. We iterate through it using a Pythonic for loop. x is a tf.distribute.DistributedValues containing data for all replicas, and each replica gets data of the new batch size. tf.distribute.Strategy.run will take care of feeding the right per-replica data in x to the right replica_fn executed on each replica.
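The rebatching step can be traced by hand with a plain-Python simulation of the snippet above: a dataset of four elements, a global batch size of 2, and two replicas. The helpers `batch` and `split_across_replicas` are hypothetical list-based stand-ins, not TensorFlow APIs; they only mirror the arithmetic.

```python
def batch(elements, batch_size):
    """Group elements into consecutive batches of batch_size."""
    return [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]


def split_across_replicas(global_batch, num_replicas):
    """Rebatch one global batch into equal per-replica batches."""
    per_replica = len(global_batch) // num_replicas
    return [
        global_batch[r * per_replica:(r + 1) * per_replica]
        for r in range(num_replicas)
    ]


def replica_fn(values):
    """Same computation as the replica_fn in the snippet: double the input."""
    return [v * 2 for v in values]


global_batch_size = 2
num_replicas = 2
dataset = batch(list(range(4)), global_batch_size)  # [[0, 1], [2, 3]]

result = []
for global_batch in dataset:
    per_replica_inputs = split_across_replicas(global_batch, num_replicas)
    result.append([replica_fn(x) for x in per_replica_inputs])

print(result)  # [[[0], [2]], [[4], [6]]]
```

The nested lists line up with the PerReplica values printed by the TensorFlow example: each replica receives a batch of size 1 at every step.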
Sharding includes autosharding across multiple workers and sharding within every worker. First, in multi-worker distributed training (i.e. when you use tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy), autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right tf.data.experimental.AutoShardPolicy is set). This is to ensure that at each step, a global batch size of non-overlapping dataset elements will be processed by each worker. Autosharding has a couple of different options that can be specified using tf.data.experimental.DistributeOptions. Then, sharding within each worker means the method will split the data among all the worker devices (if more than one is present). This will happen regardless of multi-worker autosharding.
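Two of the tf.data.experimental.AutoShardPolicy options, FILE and DATA, differ in what gets divided among workers: whole input files, or individual elements. The plain-Python sketch below illustrates that difference only. The file names, their contents, and the round-robin assignment are assumptions made for the illustration, and the helper functions are hypothetical rather than TensorFlow APIs.

```python
def shard_by_file(files, num_workers, worker_index):
    """FILE-style: a worker reads only its own subset of files."""
    own_files = sorted(files)[worker_index::num_workers]
    return [element for name in own_files for element in files[name]]


def shard_by_data(files, num_workers, worker_index):
    """DATA-style: a worker scans every element and keeps only its share."""
    all_elements = [element for name in sorted(files) for element in files[name]]
    return [e for i, e in enumerate(all_elements) if i % num_workers == worker_index]


# Hypothetical dataset stored in two files.
files = {"a.txt": [0, 1, 2], "b.txt": [3, 4, 5]}
num_workers = 2

by_file = [shard_by_file(files, num_workers, w) for w in range(num_workers)]
by_data = [shard_by_data(files, num_workers, w) for w in range(num_workers)]

print(by_file)  # [[0, 1, 2], [3, 4, 5]]
print(by_data)  # [[0, 2, 4], [1, 3, 5]]
```

In both cases the workers receive non-overlapping elements; what differs is how much of the input each worker has to read.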