TPUを使用する

TensorFlow.orgで表示GoogleColabで実行GitHubでソースを表示 ノートブックをダウンロード

このColabノートブックを実行する前に、ノートブックの設定を確認して、ハードウェアアクセラレータがTPUであることを確認してください:ランタイム>ランタイムタイプの変更>ハードウェアアクセラレータ> TPU

設定

import tensorflow as tf

import os
import tensorflow_datasets as tfds
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (2.3.0)/charset_normalizer (2.0.11) doesn't match a supported version!
  RequestsDependencyWarning)

TPUの初期化

TPUは通常CloudTPUワーカーであり、ユーザーのPythonプログラムを実行するローカルプロセスとは異なります。したがって、リモートクラスタに接続してTPUを初期化するには、初期化作業を行う必要があります。 tf.distribute.cluster_resolver.TPUClusterResolverへのtpu引数は、 tf.distribute.cluster_resolver.TPUClusterResolver専用の特別なアドレスであることに注意してください。 Google Compute Engine(GCE)でコードを実行している場合は、代わりにCloudTPUの名前を渡す必要があります。

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]

デバイスの手動配置

TPUが初期化された後、手動のデバイス配置を使用して、単一のTPUデバイスに計算を配置できます。

a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

with tf.device('/TPU:0'):
  c = tf.matmul(a, b)

print("c device: ", c.device)
print(c)
c device:  /job:worker/replica:0/task:0/device:TPU:0
tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)

流通戦略

通常、データ並列の方法で複数のTPUでモデルを実行します。モデルを複数のTPU(または他のアクセラレータ)に分散するために、TensorFlowはいくつかの分散戦略を提供します。配布戦略を置き換えることができ、モデルは任意の(TPU)デバイスで実行されます。詳細については、配布戦略ガイドを確認してください。

これを示すために、 tf.distribute.TPUStrategyオブジェクトを作成します。

strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

すべてのTPUコアで実行できるように計算を複製するには、それをstrategy.runに渡すことができます。以下は、すべてのコアが同じ入力(a, b)を受け取り、各コアで独立して行列乗算を実行することを示す例です。出力は、すべてのレプリカからの値になります。

@tf.function
def matmul_fn(x, y):
  z = tf.matmul(x, y)
  return z

z = strategy.run(matmul_fn, args=(a, b))
print(z)
PerReplica:{
  0: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  1: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  2: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  3: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  4: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  5: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  6: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  7: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)
}

TPUの分類

基本的な概念をカバーしたので、より具体的な例を考えてみましょう。このセクションでは、配布戦略tf.distribute.TPUStrategyを使用してクラウドTPUでKerasモデルをトレーニングする方法を示します。

Kerasモデルを定義する

Kerasを使用したMNISTデータセットでの画像分類のためのSequentialモデルの定義から始めます。 CPUまたはGPUでトレーニングしている場合に使用するものと同じです。 Kerasモデルの作成はstrategy.scope内にある必要があるため、各TPUデバイスで変数を作成できることに注意してください。コードの他の部分は、戦略スコープ内にある必要はありません。

def create_model():
  return tf.keras.Sequential(
      [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),
       tf.keras.layers.Conv2D(256, 3, activation='relu'),
       tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(256, activation='relu'),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)])

データセットをロードする

tf.data.Dataset APIを効率的に使用することは、Cloud TPUを使用する場合に重要です。これは、十分な速度でデータをフィードできない限り、CloudTPUを使用できないためです。データセットのパフォーマンスについて詳しくは、入力パイプラインパフォーマンスガイドをご覧ください。

最も単純な実験( tf.data.Dataset.from_tensor_slicesまたはその他のグラフ内データを使用)を除くすべての場合、データセットによって読み取られたすべてのデータファイルをGoogle Cloud Storage(GCS)バケットに保存する必要があります。

ほとんどのユースケースでは、データをTFRecord形式に変換し、 tf.data.TFRecordDatasetを使用して読み取ることをお勧めします。これを行う方法の詳細については、 TFRecordとtf.Exampleチュートリアルを確認してください。これは難しい要件ではなく、 tf.data.FixedLengthRecordDatasettf.data.TextLineDatasetなどの他のデータセットリーダーを使用できます。

tf.data.Dataset.cacheを使用して、小さなデータセット全体をメモリにロードできます。

使用するデータ形式に関係なく、100MB程度の大きなファイルを使用することを強くお勧めします。これは、ファイルを開くオーバーヘッドが大幅に高くなるため、このネットワーク設定では特に重要です。

以下のコードに示すように、 tensorflow_datasetsモジュールを使用して、MNISTトレーニングおよびテストデータのコピーを取得する必要があります。 try_gcsは、パブリックGCSバケットで利用可能なコピーを使用するように指定されていることに注意してください。これを指定しないと、TPUはダウンロードされたデータにアクセスできなくなります。

def get_dataset(batch_size, is_training=True):
  split = 'train' if is_training else 'test'
  dataset, info = tfds.load(name='mnist', split=split, with_info=True,
                            as_supervised=True, try_gcs=True)

  # Normalize the input data.
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image, label

  dataset = dataset.map(scale)

  # Only shuffle and repeat the dataset in training. The advantage of having an
  # infinite dataset for training is to avoid the potential last partial batch
  # in each epoch, so that you don't need to think about scaling the gradients
  # based on the actual batch size.
  if is_training:
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size)

  return dataset

Kerasの高レベルAPIを使用してモデルをトレーニングする

Keras fitを使用してモデルをトレーニングし、APIをcompileできます。このステップにはTPU固有のものはありませんTPUStrategyの代わりに複数のGPUとMirroredStrategyを使用しているかのようにコードを記述します。詳細については、 Kerasチュートリアルを使用した分散トレーニングをご覧ください。

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

batch_size = 200
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
test_dataset = get_dataset(batch_size, is_training=False)

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset, 
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 18s 32ms/step - loss: 0.1433 - sparse_categorical_accuracy: 0.9564 - val_loss: 0.0452 - val_sparse_categorical_accuracy: 0.9859
Epoch 2/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0335 - sparse_categorical_accuracy: 0.9898 - val_loss: 0.0318 - val_sparse_categorical_accuracy: 0.9899
Epoch 3/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0199 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.0397 - val_sparse_categorical_accuracy: 0.9866
Epoch 4/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0109 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.0436 - val_sparse_categorical_accuracy: 0.9892
Epoch 5/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.0481 - val_sparse_categorical_accuracy: 0.9881
<keras.callbacks.History at 0x7f0d485602e8>

Pythonのオーバーヘッドを減らし、TPUのパフォーマンスを最大化するには、引数steps_per_executionModel.compileに渡します。この例では、スループットが約50%向上します。

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                # Anything between 2 and `steps_per_epoch` could help here.
                steps_per_execution = 50,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset,
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 12s 41ms/step - loss: 0.1515 - sparse_categorical_accuracy: 0.9537 - val_loss: 0.0416 - val_sparse_categorical_accuracy: 0.9863
Epoch 2/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0366 - sparse_categorical_accuracy: 0.9891 - val_loss: 0.0410 - val_sparse_categorical_accuracy: 0.9875
Epoch 3/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.0432 - val_sparse_categorical_accuracy: 0.9865
Epoch 4/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.0447 - val_sparse_categorical_accuracy: 0.9875
Epoch 5/5
300/300 [==============================] - 3s 11ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.0426 - val_sparse_categorical_accuracy: 0.9884
<keras.callbacks.History at 0x7f0d0463cd68>
プレースホルダー16

カスタムトレーニングループを使用してモデルをトレーニングします

tf.functionおよびtf.distributeを直接使用して、モデルを作成およびトレーニングすることもできます。 strategy.experimental_distribute_datasets_from_function APIを使用して、データセット関数を指定してデータセットを配布できます。以下の例では、データセットに渡されるバッチサイズは、グローバルバッチサイズではなく、レプリカごとのバッチサイズであることに注意してください。詳細については、 tf.distribute.Strategyチュートリアルを使用したカスタムトレーニングをご覧ください。

まず、モデル、データセット、tf.functionsを作成します。

# Create the model, optimizer and metrics inside the strategy scope, so that the
# variables can be mirrored on each device.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the datasets on each TPU
# worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function
def train_step(iterator):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  strategy.run(step_fn, args=(next(iterator),))
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function
プレースホルダー18

次に、トレーニングループを実行します。

steps_per_eval = 10000 // batch_size

train_iterator = iter(train_dataset)
for epoch in range(5):
  print('Epoch: {}/5'.format(epoch))

  for step in range(steps_per_epoch):
    train_step(train_iterator)
  print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  training_loss.reset_states()
  training_accuracy.reset_states()
Epoch: 0/5
Current step: 300, training loss: 0.1339, accuracy: 95.79%
Epoch: 1/5
Current step: 600, training loss: 0.0333, accuracy: 98.91%
Epoch: 2/5
Current step: 900, training loss: 0.0176, accuracy: 99.43%
Epoch: 3/5
Current step: 1200, training loss: 0.0126, accuracy: 99.61%
Epoch: 4/5
Current step: 1500, training loss: 0.0122, accuracy: 99.61%

tf.function内の複数のステップによるパフォーマンスの向上

tf.function内で複数のステップを実行することにより、パフォーマンスを向上させることができます。これは、 strategy.run呼び出しをtf.range内のtf.functionでラップすることによって実現され、AutoGraphはそれをTPUワーカーでtf.while_loopに変換します。

パフォーマンスは向上していますが、 tf.function内で単一のステップを実行する場合と比較して、この方法にはトレードオフがあります。 tf.functionで複数のステップを実行することは柔軟性が低く、ステップ内で熱心に実行したり、任意のPythonコードを実行したりすることはできません。

@tf.function
def train_multiple_steps(iterator, steps):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get 
# retraced if the value changes.
train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
Current step: 1800, training loss: 0.0081, accuracy: 99.74%

次のステップ