tf.distribute.TPUStrategy

Synchronous training on TPUs and TPU Pods.

Inherits From: Strategy

To construct a TPUStrategy object, you need to run the following initialization code:

# Connect to the TPU cluster and initialize the TPU system before
# constructing the strategy.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

When using distribution strategies, variables created within the strategy's scope are replicated across all replicas and can be kept in sync using all-reduce algorithms.
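
For example, a variable created inside the scope is mirrored on every replica, and updates made inside strategy.run are combined according to the variable's aggregation setting. The following is a minimal sketch, assuming strategy was created as above and that an aggregation mode is set explicitly:

with strategy.scope():
  # Each replica holds a copy of this variable; MEAN aggregation keeps the
  # copies identical after replica-local updates.
  counter = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

@tf.function
def increment():
  def step_fn():
    counter.assign_add(1.0)
  strategy.run(step_fn)

increment()
print(counter.numpy())  # 1.0, regardless of the number of replicas.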

To run TF2 programs on TPUs, you can either use the .compile and .fit APIs in tf.keras with TPUStrategy, or write your own custom training loop by calling strategy.run directly. Note that TPUStrategy doesn't support pure eager execution, so make sure the function passed into strategy.run is a tf.function, or that strategy.run is called inside a tf.function, if eager behavior is enabled. See https://www.tensorflow.org/guide/tpu for more details.
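
For example, the standard Keras workflow works unchanged once the model is built under the strategy's scope. This is a minimal sketch; the model, data, and hyperparameters are placeholders:

import numpy as np

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  model.compile(
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Keras distributes the data across the TPU replicas automatically.
x = np.random.random((16, 5)).astype(np.float32)
y = np.random.randint(2, size=(16, 1))
model.fit(x, y, batch_size=8, epochs=1)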

The experimental_distribute_datasets_from_function and experimental_distribute_dataset APIs can be used to distribute the dataset across the TPU workers when writing your own training loop. If you are using the fit and compile methods available in tf.keras.Model, then Keras will handle the distribution for you.
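
For example, experimental_distribute_dataset takes a tf.data.Dataset batched with the global batch size and splits each batch across the replicas. This is a minimal sketch; the data and batch size are placeholders:

global_batch_size = 8  # Each replica receives global_batch_size / num_replicas examples.
x = tf.random.uniform((8, 5))
y = tf.random.uniform((8, 1), maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(
    global_batch_size, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dist_dataset)
# next(iterator) yields per-replica inputs that can be passed to strategy.run.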

An example of a custom training loop on TPUs:

import numpy as np

with strategy.scope():
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)

@tf.function
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      # The Dense layer has no activation, so the outputs are logits.
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  # Runs step_fn on every replica with one per-replica batch from the iterator.
  strategy.run(step_fn, args=(next(iterator),))

train_step(iterator)

For advanced use cases like model parallelism, you can set the experimental_device_assignment argument when creating TPUStrategy to specify the number of replicas and the number of logical devices. Below is an example that initializes the TPU system with 2 logical devices and 1 replica.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
# computation_shape=[1, 1, 1, 2] assigns 2 TensorCores (2 logical devices)
# to each replica; num_replicas=1 creates a single replica.
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

Then you can run a tf.add operation only on logical device 0.

@tf.function()
def step_fn(inputs):
  features, _ = inputs
  output = tf.add(features, features)

  # Add operation will be executed on logical device 0.
  output = strategy.experimental_assign_to_logical_device(output, 0)
  return output
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)
strategy.run(step_fn, args=(next(iterator),))

Args

tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster. If None, it will assume running on a local TPU worker.

experimental_device_assignment: Optional tf.tpu.experimental.DeviceAssignment to specify the placement of replicas on the TPU cluster.

Attributes

cluster_resolver: Returns the cluster resolver associated with this strategy.

In general, when using a multi-worker tf.distribute strategy such as tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy, there is a tf.distribute.cluster_resolver.ClusterResolver associated with the strategy, and such an instance is returned by this property.

Strategies that intend to have an associated tf.distribute.cluster_resolver.ClusterResolver must set the relevant attribute, or override this property; otherwise, None is returned by default. Those strategies should also provide information regarding what is returned by this property.

Single-worker strategies usually do not have a tf.distribute.cluster_resolver.ClusterResolver, and in those cases this property will return None.
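
For tf.distribute.TPUStrategy, this property returns the resolver passed to the constructor. A minimal sketch, reusing the resolver and strategy from the construction example at the top of this page:

# `resolver` and `strategy` come from the first example above.
print(strategy.cluster_resolver is resolver)  # Expected: True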

The