Synchronous training on TPUs and TPU Pods.

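To construct a TPUStrategy object, first connect to the TPU cluster and initialize the TPU system, as shown below:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)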

While using distribution strategies, the variables created within the strategy's scope will be replicated across all the replicas and can be kept in sync using all-reduce algorithms.
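
To make the synchronization idea concrete, here is a minimal pure-Python sketch of one synchronous SGD step. It is not TensorFlow code and not how TPUStrategy is implemented; the helper names `all_reduce_sum` and `sync_sgd_step` are hypothetical and exist only for illustration. It shows why replicas that start with identical variables and apply the same all-reduced gradient end up with identical variables.

```python
# Pure-Python stand-in for an all-reduce step; hypothetical helpers, not TensorFlow APIs.

def all_reduce_sum(per_replica_values):
    """Sum the values element-wise and hand the same total to every replica."""
    total = [sum(vals) for vals in zip(*per_replica_values)]
    return [list(total) for _ in per_replica_values]


def sync_sgd_step(replica_weights, replica_grads, learning_rate):
    """Apply one SGD update on every replica using the all-reduced gradients."""
    reduced = all_reduce_sum(replica_grads)
    return [
        [w - learning_rate * g for w, g in zip(weights, grads)]
        for weights, grads in zip(replica_weights, reduced)
    ]


# Two replicas hold identical copies of a two-element variable.
weights = [[1.0, 2.0], [1.0, 2.0]]
# Each replica computes different gradients from its own slice of the batch.
grads = [[0.5, 1.0], [1.5, 3.0]]

new_weights = sync_sgd_step(weights, grads, learning_rate=0.5)
# Every replica applied the same summed gradient, so the copies still match.
```

Because each replica receives the same reduced gradient, no replica's copy of the variable can drift from the others.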

To run TF2 programs on TPUs, you can either use the .compile and .fit APIs in tf.keras with TPUStrategy, or write your own customized training loop by calling strategy.run directly. Note that TPUStrategy doesn't support pure eager execution, so make sure the function passed into strategy.run is a tf.function, or that strategy.run is called inside a tf.function if eager behavior is enabled. See more details in

The experimental_distribute_datasets_from_function and experimental_distribute_dataset APIs can be used to distribute the dataset across the TPU workers when writing your own training loop. If you are using the fit and compile methods available in tf.keras.Model, then Keras will handle the distribution for you.
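
When writing the dataset function yourself, two pieces of arithmetic usually come up: splitting a global batch size across replicas and giving each input pipeline its own shard of the data. The sketch below mimics that arithmetic in plain Python. The function names are hypothetical stand-ins; in TensorFlow the corresponding information comes from the tf.distribute.InputContext passed to the dataset function (its get_per_replica_batch_size method and its num_input_pipelines and input_pipeline_id attributes) and from Dataset.shard.

```python
# Plain-Python stand-ins for input-pipeline arithmetic; not TensorFlow APIs.

def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Split a global batch evenly across replicas, rejecting uneven splits."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            "global batch size %d is not divisible by %d replicas"
            % (global_batch_size, num_replicas_in_sync))
    return global_batch_size // num_replicas_in_sync


def shard(elements, num_input_pipelines, input_pipeline_id):
    """Keep every num_input_pipelines-th element, starting at this pipeline's id."""
    return elements[input_pipeline_id::num_input_pipelines]


# A global batch of 64 on 8 replicas gives each replica 8 examples per step.
batch = per_replica_batch_size(64, 8)

# With 2 input pipelines, pipeline 1 reads the odd-indexed elements.
my_shard = shard(list(range(8)), num_input_pipelines=2, input_pipeline_id=1)
```

Sharding by pipeline id keeps the workers from reading the same examples, and dividing the global batch keeps the effective batch size independent of the number of replicas.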

An example of writing a customized training loop on TPUs:

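import numpy as np
import tensorflow as tf

# Create the model and optimizer under the strategy scope so that their
# variables are replicated across the TPU replicas.
with strategy.scope():
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Builds the input pipeline; called once per input worker.
def dataset_fn(ctx):
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)

# strategy.run must be called from inside a tf.function on TPUs.
@tf.function
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  strategy.run(step_fn, args=(next(iterator),))

train_step(iterator)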

For advanced use cases like model parallelism, you can set the experimental_device_assignment argument when creating TPUStrategy to specify the number of replicas and the number of logical devices. Below is an example that initializes the TPU system with 2 logical devices and 1 replica.

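resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=1)
strategy = tf.distribute.TPUStrategy(resolver, experimental_device_assignment=device_assignment)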
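
As a rough guide to how logical devices and replicas trade off, the sketch below estimates how many replicas of a given computation shape fit into a TPU mesh. It is plain Python, not a TensorFlow API: the function name `max_replicas` is hypothetical, the mesh shape `[2, 2, 1, 2]` is an assumed example rather than a value read from a real topology, and the calculation assumes the default computation stride of 1 in every dimension.

```python
from math import prod


def max_replicas(mesh_shape, computation_shape):
    """Estimate how many non-overlapping replicas of computation_shape fit in the mesh."""
    if len(mesh_shape) != len(computation_shape):
        raise ValueError("mesh_shape and computation_shape must have the same rank")
    if any(c > m for m, c in zip(mesh_shape, computation_shape)):
        raise ValueError("computation_shape does not fit inside mesh_shape")
    # Count how many blocks fit along each mesh dimension, then multiply.
    return prod(m // c for m, c in zip(mesh_shape, computation_shape))


# Assumed mesh of 2 x 2 x 1 chips with 2 cores each (8 cores in total).
mesh = [2, 2, 1, 2]

# Two logical devices per replica (both cores of one chip) leaves room for 4 replicas.
model_parallel = max_replicas(mesh, [1, 1, 1, 2])

# One logical device per replica (pure data parallelism) allows 8 replicas.
data_parallel = max_replicas(mesh, [1, 1, 1, 1])
```

Each additional logical device per replica reduces the number of replicas that fit, so model parallelism trades data-parallel width for per-replica capacity.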