
Use a TPU


Experimental support for Cloud TPUs is currently available for Keras and Google Colab. Before you run this Colab notebook, ensure that your hardware accelerator is a TPU by checking your notebook settings: Runtime > Change runtime type > Hardware accelerator > TPU.

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import os
import tensorflow_datasets as tfds

Distribution strategies

This guide demonstrates how to use the distribution strategy tf.distribute.experimental.TPUStrategy to drive a Cloud TPU and train a Keras model. A distribution strategy is an abstraction that can be used to drive models on CPUs, GPUs, or TPUs. Simply swap out the distribution strategy and the model will run on the given device. See the distribution strategy guide for more information.
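As a minimal sketch of that swap (using MirroredStrategy as a stand-in for a multi-GPU machine; the tiny model here is purely illustrative), only the strategy object changes and the Keras code inside strategy.scope() stays the same:

strategy = tf.distribute.MirroredStrategy()   # e.g. one or more local GPUs
# strategy = tf.distribute.experimental.TPUStrategy(resolver)  # Cloud TPU instead

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='sgd', loss='mse')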

Below is the code that connects to a TPU and creates the TPUStrategy object. Note that the tpu argument to TPUClusterResolver is a special address just for Colab. If you are running on Google Compute Engine (GCE), pass in the name of your Cloud TPU instead; a sketch of that variant appears after the initialization output below.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
INFO:tensorflow:Initializing the TPU system: 10.240.1.2:8470

INFO:tensorflow:Clearing out eager caches

INFO:tensorflow:Finished initializing TPU system.

<tensorflow.python.tpu.topology.Topology at 0x7f74e49f9d30>
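If you are running on GCE rather than Colab, the resolver call looks like the following sketch; 'your-tpu-name' is a placeholder for the name you gave the Cloud TPU when you created it:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='your-tpu-name')  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)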

Below is a simple MNIST model, unchanged from what you would use on CPU or GPU.

def create_model():
  return tf.keras.Sequential(
      [tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
       tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)])

Input datasets

Efficient use of the tf.data.Dataset API is critical when using a Cloud TPU: the TPU cores sit idle unless you can feed them data quickly enough. See the Input Pipeline Performance Guide for details on dataset performance.

For all but the simplest experimentation (using tf.data.Dataset.from_tensor_slices or other in-graph data) you will need to store all data files read by the Dataset in Google Cloud Storage (GCS) buckets.

For most use cases, it is recommended to convert your data into TFRecord format and use a tf.data.TFRecordDataset to read it. See the TFRecord and tf.Example tutorial for details on how to do this. This is not a hard requirement, however, and you can use other dataset readers (FixedLengthRecordDataset or TextLineDataset) if you prefer.
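As a brief sketch of that pattern (the bucket and file pattern below are placeholders, not real resources):

# Sketch: list and read TFRecord files stored in a GCS bucket.
filenames = tf.io.gfile.glob('gs://your-bucket/train-*.tfrecord')  # hypothetical path
dataset = tf.data.TFRecordDataset(filenames)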

Small datasets can be loaded entirely into memory using tf.data.Dataset.cache.
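For example, a minimal sketch of that caching pattern (reusing the hypothetical filenames from the sketch above and omitting any parsing step):

# Sketch: cache a small dataset in memory so the files are read only once,
# then shuffle, batch, and prefetch as usual.
small_dataset = (tf.data.TFRecordDataset(filenames)
                 .cache()
                 .shuffle(10000)
                 .batch(128)
                 .prefetch(tf.data.experimental.AUTOTUNE))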

Regardless of the data format used, it is strongly recommended that you use large files, on the order of 100MB. This is especially important in this networked setting as the overhead of opening a file is significantly higher.

Here, use the tensorflow_datasets module to get a copy of the MNIST training and test data. Note that try_gcs is specified to use a copy that is available in a public GCS bucket. If you don't specify this, the TPU will not be able to access the downloaded data.

def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0

    return image, label

  train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
  test_dataset = mnist_test.map(scale).batch(batch_size)

  return train_dataset, test_dataset

Create and train a model

Nothing here is TPU-specific; you would write the same code below if you had multiple GPUs and were using a MirroredStrategy rather than a TPUStrategy.

strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])
  
train_dataset, test_dataset = get_dataset()

model.fit(train_dataset,
          epochs=5,
          validation_data=test_dataset)
INFO:tensorflow:Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 8

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 8

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Epoch 1/5
300/300 [==============================] - 15s 50ms/step - loss: 0.1986 - sparse_categorical_accuracy: 0.9435 - val_loss: 0.0798 - val_sparse_categorical_accuracy: 0.9759
Epoch 2/5
300/300 [==============================] - 5s 17ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9836 - val_loss: 0.0564 - val_sparse_categorical_accuracy: 0.9820
Epoch 3/5
300/300 [==============================] - 5s 16ms/step - loss: 0.0341 - sparse_categorical_accuracy: 0.9896 - val_loss: 0.0465 - val_sparse_categorical_accuracy: 0.9842
Epoch 4/5
300/300 [==============================] - 5s 18ms/step - loss: 0.0202 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.0442 - val_sparse_categorical_accuracy: 0.9855
Epoch 5/5
300/300 [==============================] - 5s 17ms/step - loss: 0.0117 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.0493 - val_sparse_categorical_accuracy: 0.9842

<tensorflow.python.keras.callbacks.History at 0x7f74e006ef98>

Next steps