TPU 사용

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

이 Colab 노트북을 실행하기 전에 노트북 설정을 확인하여 하드웨어 가속기가 TPU인지 확인하십시오. 런타임 > 런타임 유형 변경 > 하드웨어 가속기 > TPU .

설정

import tensorflow as tf

import os
import tensorflow_datasets as tfds
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (2.3.0)/charset_normalizer (2.0.11) doesn't match a supported version!
  RequestsDependencyWarning)

TPU 초기화

TPU는 일반적으로 사용자의 Python 프로그램을 실행하는 로컬 프로세스와 다른 Cloud TPU 작업자입니다. 따라서 원격 클러스터에 연결하고 TPU를 초기화하려면 일부 초기화 작업을 수행해야 합니다. tf.distribute.cluster_resolver.TPUClusterResolver에 대한 tpu 인수는 tf.distribute.cluster_resolver.TPUClusterResolver 전용 특수 주소입니다. Google Compute Engine(GCE)에서 코드를 실행하는 경우 대신 Cloud TPU의 이름을 전달해야 합니다.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]

수동 장치 배치

TPU가 초기화된 후 수동 장치 배치를 사용하여 단일 TPU 장치에 계산을 배치할 수 있습니다.

a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

with tf.device('/TPU:0'):
  c = tf.matmul(a, b)

print("c device: ", c.device)
print(c)
c device:  /job:worker/replica:0/task:0/device:TPU:0
tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)

유통 전략

일반적으로 데이터 병렬 방식으로 여러 TPU에서 모델을 실행합니다. 모델을 여러 TPU(또는 다른 가속기)에 배포하기 위해 TensorFlow는 여러 배포 전략을 제공합니다. 배포 전략을 교체할 수 있으며 모델은 지정된(TPU) 장치에서 실행됩니다. 자세한 내용은 배포 전략 가이드 를 확인하세요.

이를 보여주기 위해 tf.distribute.TPUStrategy 객체를 생성합니다:

strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

모든 TPU 코어에서 실행할 수 있도록 계산을 복제하려면 이를 strategy.run API에 전달할 수 있습니다. 다음은 모든 코어가 동일한 입력 (a, b) 을 수신하고 각 코어에 대해 독립적으로 행렬 곱셈을 수행하는 예입니다. 출력은 모든 복제본의 값이 됩니다.

@tf.function
def matmul_fn(x, y):
  z = tf.matmul(x, y)
  return z

z = strategy.run(matmul_fn, args=(a, b))
print(z)
PerReplica:{
  0: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  1: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  2: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  3: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  4: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  5: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  6: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  7: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)
}

TPU에 대한 분류

기본 개념을 다뤘으니 좀 더 구체적인 예를 살펴보겠습니다. 이 섹션에서는 배포 전략 tf.distribute.TPUStrategy )을 사용하여 Cloud TPU에서 Keras 모델을 학습시키는 방법을 보여줍니다.

Keras 모델 정의

Keras를 사용하여 MNIST 데이터 세트에서 이미지 분류를 위한 Sequential Keras 모델의 정의로 시작합니다. CPU나 GPU에 대해 훈련할 때 사용하는 것과 다르지 않습니다. Keras 모델 생성은 strategy.scope 내부에 있어야 하므로 각 TPU 기기에서 변수를 생성할 수 있습니다. 코드의 다른 부분은 전략 범위 내에 있을 필요가 없습니다.

def create_model():
  return tf.keras.Sequential(
      [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),
       tf.keras.layers.Conv2D(256, 3, activation='relu'),
       tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(256, activation='relu'),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)])

데이터세트 로드

tf.data.Dataset API의 효율적인 사용은 Cloud TPU에 데이터를 충분히 빠르게 공급할 수 없으면 Cloud TPU를 사용할 수 없기 때문에 Cloud TPU를 사용할 때 매우 중요합니다. 입력 파이프라인 성능 가이드 에서 데이터세트 성능에 대해 자세히 알아볼 수 있습니다.

가장 간단한 실험( tf.data.Dataset.from_tensor_slices 또는 기타 그래프 내 데이터 사용)을 제외한 모든 실험의 경우 데이터 세트에서 읽은 모든 데이터 파일을 Google Cloud Storage(GCS) 버킷에 저장해야 합니다.

대부분의 사용 사례에서 데이터를 TFRecord 형식으로 변환하고 tf.data.TFRecordDataset 을 사용하여 읽는 것이 좋습니다. 이 작업을 수행하는 방법에 대한 자세한 내용은 TFRecord 및 tf.Example 자습서 를 확인하세요. 이것은 어려운 요구 사항이 아니며 tf.data.FixedLengthRecordDataset 또는 tf.data.TextLineDataset 과 같은 다른 데이터 세트 판독기를 사용할 수 있습니다.

tf.data.Dataset.cache 를 사용하여 전체 작은 데이터세트를 메모리에 로드할 수 있습니다.

사용된 데이터 형식에 관계없이 100MB 정도의 대용량 파일을 사용하는 것이 좋습니다. 이것은 파일을 여는 오버헤드가 훨씬 높기 때문에 이 네트워크 설정에서 특히 중요합니다.

아래 코드에서 볼 수 tensorflow_datasets 모듈을 사용하여 MNIST 교육 및 테스트 데이터의 복사본을 가져와야 합니다. try_gcs 는 공개 GCS 버킷에서 사용 가능한 사본을 사용하도록 지정됩니다. 이를 지정하지 않으면 TPU가 다운로드한 데이터에 액세스할 수 없습니다.

def get_dataset(batch_size, is_training=True):
  split = 'train' if is_training else 'test'
  dataset, info = tfds.load(name='mnist', split=split, with_info=True,
                            as_supervised=True, try_gcs=True)

  # Normalize the input data.
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image, label

  dataset = dataset.map(scale)

  # Only shuffle and repeat the dataset in training. The advantage of having an
  # infinite dataset for training is to avoid the potential last partial batch
  # in each epoch, so that you don't need to think about scaling the gradients
  # based on the actual batch size.
  if is_training:
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size)

  return dataset

Keras 고급 API를 사용하여 모델 학습

Keras fitcompile API로 모델을 훈련할 수 있습니다. 이 단계에서는 TPU 관련 사항이 없습니다. TPUStrategy 대신 여러 GPU와 MirroredStrategy 를 사용하는 것처럼 코드를 작성합니다. Keras를 사용한 분산 교육 자습서에서 자세히 알아볼 수 있습니다.

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

batch_size = 200
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
test_dataset = get_dataset(batch_size, is_training=False)

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset, 
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 18s 32ms/step - loss: 0.1433 - sparse_categorical_accuracy: 0.9564 - val_loss: 0.0452 - val_sparse_categorical_accuracy: 0.9859
Epoch 2/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0335 - sparse_categorical_accuracy: 0.9898 - val_loss: 0.0318 - val_sparse_categorical_accuracy: 0.9899
Epoch 3/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0199 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.0397 - val_sparse_categorical_accuracy: 0.9866
Epoch 4/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0109 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.0436 - val_sparse_categorical_accuracy: 0.9892
Epoch 5/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.0481 - val_sparse_categorical_accuracy: 0.9881
<keras.callbacks.History at 0x7f0d485602e8>

Python 오버헤드를 줄이고 TPU의 성능을 최대화하려면 인수 steps_per_execution )를 Model.compile 에 전달하십시오. 이 예에서는 처리량이 약 50% 증가합니다.

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                # Anything between 2 and `steps_per_epoch` could help here.
                steps_per_execution = 50,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset,
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 12s 41ms/step - loss: 0.1515 - sparse_categorical_accuracy: 0.9537 - val_loss: 0.0416 - val_sparse_categorical_accuracy: 0.9863
Epoch 2/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0366 - sparse_categorical_accuracy: 0.9891 - val_loss: 0.0410 - val_sparse_categorical_accuracy: 0.9875
Epoch 3/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.0432 - val_sparse_categorical_accuracy: 0.9865
Epoch 4/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.0447 - val_sparse_categorical_accuracy: 0.9875
Epoch 5/5
300/300 [==============================] - 3s 11ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.0426 - val_sparse_categorical_accuracy: 0.9884
<keras.callbacks.History at 0x7f0d0463cd68>

사용자 지정 훈련 루프를 사용하여 모델 훈련

tf.functiontf.distribute API를 직접 사용하여 모델을 생성하고 훈련할 수도 있습니다. strategy.experimental_distribute_datasets_from_function API를 사용하여 데이터세트 함수가 지정된 데이터세트를 배포할 수 있습니다. 아래 예에서 데이터세트에 전달된 배치 크기는 전역 배치 크기가 아닌 복제본당 배치 크기입니다. 자세히 알아보려면 tf.distribute.Strategy를 사용한 맞춤형 교육 튜토리얼을 확인하세요.

먼저 모델, 데이터세트 및 tf.functions를 생성합니다.

# Create the model, optimizer and metrics inside the strategy scope, so that the
# variables can be mirrored on each device.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the datasets on each TPU
# worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function
def train_step(iterator):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  strategy.run(step_fn, args=(next(iterator),))
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function

그런 다음 훈련 루프를 실행합니다.

steps_per_eval = 10000 // batch_size

train_iterator = iter(train_dataset)
for epoch in range(5):
  print('Epoch: {}/5'.format(epoch))

  for step in range(steps_per_epoch):
    train_step(train_iterator)
  print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  training_loss.reset_states()
  training_accuracy.reset_states()
Epoch: 0/5
Current step: 300, training loss: 0.1339, accuracy: 95.79%
Epoch: 1/5
Current step: 600, training loss: 0.0333, accuracy: 98.91%
Epoch: 2/5
Current step: 900, training loss: 0.0176, accuracy: 99.43%
Epoch: 3/5
Current step: 1200, training loss: 0.0126, accuracy: 99.61%
Epoch: 4/5
Current step: 1500, training loss: 0.0122, accuracy: 99.61%

tf.function 내부의 여러 단계로 성능 향상

tf.function 내에서 여러 단계를 실행하여 성능을 향상시킬 수 있습니다. 이것은 tf.range 내부의 tf.functionstrategy.run 호출을 래핑하여 달성되며 AutoGraph는 이를 TPU 작업자의 tf.while_loop 으로 변환합니다.

향상된 성능에도 불구하고 tf.function 내에서 단일 단계를 실행하는 것과 비교하여 이 방법에는 절충점이 있습니다. tf.function 에서 여러 단계를 실행하는 것은 덜 유연합니다. 단계 내에서 열성적으로 또는 임의의 Python 코드를 실행할 수 없습니다.

@tf.function
def train_multiple_steps(iterator, steps):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get 
# retraced if the value changes.
train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
Current step: 1800, training loss: 0.0081, accuracy: 99.74%

다음 단계