![]() | ![]() | ![]() | ![]() |
Cloud TPU에 대한 실험적 지원은 현재 Keras 및 Google Colab에서 사용할 수 있습니다. 이 Colab 노트북을 실행하기 전에 노트북 설정 (런타임> 런타임 유형 변경> 하드웨어 가속기> TPU)을 확인하여 하드웨어 가속기가 TPU인지 확인하세요.
설정
import tensorflow as tf
import os
import tensorflow_datasets as tfds
TPU 초기화
TPU는 일반적으로 사용자 Python 프로그램을 실행하는 로컬 프로세스와 다른 Cloud TPU 작업자에 있습니다. 따라서 원격 클러스터에 연결하고 TPU를 초기화하려면 몇 가지 초기화 작업을 수행해야합니다. TPUClusterResolver
대한 tpu
인수는 TPUClusterResolver
전용 특수 주소입니다. Google Compute Engine (GCE)에서 실행중인 경우 대신 CloudTPU의 이름을 전달해야합니다.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.74:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.74:8470 INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Finished initializing TPU system. All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')]
수동 장치 배치
TPU가 초기화 된 후 수동 기기 배치를 사용하여 단일 TPU 기기에 계산을 배치 할 수 있습니다.
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
with tf.device('/TPU:0'):
c = tf.matmul(a, b)
print("c device: ", c.device)
print(c)
c device: /job:worker/replica:0/task:0/device:TPU:0 tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32)
유통 전략
대부분의 경우 사용자는 데이터 병렬 방식으로 여러 TPU에서 모델을 실행하려고합니다. 배포 전략은 CPU, GPU 또는 TPU에서 모델을 구동하는 데 사용할 수있는 추상화입니다. 배포 전략을 바꾸면 모델이 주어진 장치에서 실행됩니다. 자세한 내용은 배포 전략 가이드 를 참조하세요.
먼저 TPUStrategy
객체를 만듭니다.
strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system: INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
모든 TPU 코어에서 실행될 수 있도록 계산을 복제하려면 간단히 strategy.run
API에 전달하면됩니다. 아래는 모든 코어가 동일한 입력 (a, b)
얻고 각 코어에서 독립적으로 matmul을 수행하는 예입니다. 출력은 모든 복제본의 값입니다.
@tf.function
def matmul_fn(x, y):
z = tf.matmul(x, y)
return z
z = strategy.run(matmul_fn, args=(a, b))
print(z)
PerReplica:{ 0: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 1: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 2: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 3: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 4: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 5: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 6: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32), 7: tf.Tensor( [[22. 28.] [49. 64.]], shape=(2, 2), dtype=float32) }
TPU 분류
기본 개념을 배웠으므로 이제 더 구체적인 예를 살펴볼 차례입니다. 이 가이드에서는 배포 전략 tf.distribute.TPUStrategy
를 사용하여 Cloud TPU를 구동하고 tf.distribute.TPUStrategy
모델을 학습시키는 방법을 보여줍니다.
Keras 모델 정의
다음은 Keras를 사용하는 MNIST 모델의 정의이며 CPU 또는 GPU에서 사용하는 것과 변경되지 않았습니다. Keras 모델 생성은 strategy.scope
내에 있어야하므로 각 TPU 기기에서 변수를 생성 할 수 있습니다. 코드의 다른 부분은 전략 범위 내에있을 필요가 없습니다.
def create_model():
return tf.keras.Sequential(
[tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(256, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)])
입력 데이터 세트
Cloud TPU를 사용할 때tf.data.Dataset
API를 효율적으로 사용하는 것이 중요합니다. 데이터를 충분히 빠르게 공급할 수 없으면 Cloud TPU를 사용할 수 없기 때문입니다. 데이터 세트 성능에 대한 자세한 내용은 입력 파이프 라인 성능 가이드 를 참조하세요.
가장 간단한 실험 ( tf.data.Dataset.from_tensor_slices
또는 기타 그래프 내 데이터 사용)을 제외한 모든 실험의 경우 데이터 세트에서 읽은 모든 데이터 파일을 Google Cloud Storage (GCS) 버킷에 저장해야합니다.
대부분의 사용 사례에서는 데이터를 TFRecord
형식으로 변환하고 tf.data.TFRecordDataset
을 사용하여 읽는 것이 좋습니다. 이를 수행하는 방법에 대한 자세한 내용은 TFRecord 및 tf.Example 자습서 를 참조하십시오. 그러나 이것은 어려운 요구 사항이 아니며 원하는 경우 다른 데이터 세트 판독기 ( FixedLengthRecordDataset
또는 TextLineDataset
)를 사용할 수 있습니다.
작은 데이터 세트는 tf.data.Dataset.cache
사용하여 완전히 메모리에로드 할 수 있습니다.
사용 된 데이터 형식에 관계없이 100MB 정도의 대용량 파일을 사용하는 것이 좋습니다. 이것은 파일을 여는 오버 헤드가 훨씬 더 높기 때문에이 네트워크 설정에서 특히 중요합니다.
여기에서 tensorflow_datasets
모듈을 사용하여 MNIST 학습 데이터의 사본을 tensorflow_datasets
합니다. try_gcs
는 공개 GCS 버킷에서 사용할 수있는 사본을 사용하도록 지정됩니다. 이를 지정하지 않으면 TPU가 다운로드 된 데이터에 액세스 할 수 없습니다.
def get_dataset(batch_size, is_training=True):
split = 'train' if is_training else 'test'
dataset, info = tfds.load(name='mnist', split=split, with_info=True,
as_supervised=True, try_gcs=True)
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255.0
return image, label
dataset = dataset.map(scale)
# Only shuffle and repeat the dataset in training. The advantage to have a
# infinite dataset for training is to avoid the potential last partial batch
# in each epoch, so users don't need to think about scaling the gradients
# based on the actual batch size.
if is_training:
dataset = dataset.shuffle(10000)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
return dataset
Keras 상위 수준 API를 사용하여 모델 학습
Keras fit / compile API를 사용하여 모델을 간단하게 학습 할 수 있습니다. 여기에 TPU와 관련된 것은 없습니다. 여러 GPU가 있고 TPUStrategy
대신 MirroredStrategy
사용하는 경우 아래에 동일한 코드를 작성합니다. 자세한 내용 은 Keras를 사용한 분산 교육 자습서를 확인하십시오.
with strategy.scope():
model = create_model()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
batch_size = 200
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size
train_dataset = get_dataset(batch_size, is_training=True)
test_dataset = get_dataset(batch_size, is_training=False)
model.fit(train_dataset,
epochs=5,
steps_per_epoch=steps_per_epoch,
validation_data=test_dataset,
validation_steps=validation_steps)
Epoch 1/5 300/300 [==============================] - 19s 38ms/step - loss: 0.3196 - sparse_categorical_accuracy: 0.8996 - val_loss: 0.0467 - val_sparse_categorical_accuracy: 0.9855 Epoch 2/5 300/300 [==============================] - 7s 25ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9878 - val_loss: 0.0379 - val_sparse_categorical_accuracy: 0.9885 Epoch 3/5 300/300 [==============================] - 8s 26ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9933 - val_loss: 0.0366 - val_sparse_categorical_accuracy: 0.9892 Epoch 4/5 300/300 [==============================] - 8s 25ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.0327 - val_sparse_categorical_accuracy: 0.9902 Epoch 5/5 300/300 [==============================] - 8s 25ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.0465 - val_sparse_categorical_accuracy: 0.9885 <tensorflow.python.keras.callbacks.History at 0x7fc0906c90b8>
Python 오버 헤드를 줄이고 TPU 성능을 최대화하려면 Model.compile
대한 실험적 experimental_steps_per_execution
인수를 사용 Model.compile
. 여기에서 처리량이 약 50 % 증가합니다.
with strategy.scope():
model = create_model()
model.compile(optimizer='adam',
# Anything between 2 and `steps_per_epoch` could help here.
experimental_steps_per_execution = 50,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
model.fit(train_dataset,
epochs=5,
steps_per_epoch=steps_per_epoch,
validation_data=test_dataset,
validation_steps=validation_steps)
WARNING:tensorflow:The argument `steps_per_execution` is no longer experimental. Pass `steps_per_execution` instead of `experimental_steps_per_execution`. WARNING:tensorflow:The argument `steps_per_execution` is no longer experimental. Pass `steps_per_execution` instead of `experimental_steps_per_execution`. Epoch 1/5 300/300 [==============================] - 14s 46ms/step - loss: 0.2399 - sparse_categorical_accuracy: 0.9279 - val_loss: 0.0510 - val_sparse_categorical_accuracy: 0.9838 Epoch 2/5 300/300 [==============================] - 5s 15ms/step - loss: 0.0406 - sparse_categorical_accuracy: 0.9876 - val_loss: 0.0409 - val_sparse_categorical_accuracy: 0.9882 Epoch 3/5 300/300 [==============================] - 5s 15ms/step - loss: 0.0203 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.0394 - val_sparse_categorical_accuracy: 0.9879 Epoch 4/5 300/300 [==============================] - 5s 15ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.0410 - val_sparse_categorical_accuracy: 0.9879 Epoch 5/5 300/300 [==============================] - 5s 15ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.0508 - val_sparse_categorical_accuracy: 0.9863 <tensorflow.python.keras.callbacks.History at 0x7fbc743646d8>
사용자 지정 학습 루프를 사용하여 모델을 학습합니다.
tf.function
및 tf.distribute
API를 직접 사용하여 모델을 만들고 학습시킬 수도 있습니다. strategy.experimental_distribute_datasets_from_function
API는 주어진 데이터 세트 함수를 통해 데이터 세트를 배포하는 데 사용됩니다. 이 경우 데이터 세트에 전달되는 일괄 처리 크기는 전역 일괄 처리 크기가 아니라 복제본 일괄 처리 크기입니다. 자세한 내용 은 tf.distribute.Strategy를 사용한 사용자 지정 교육 자습서를 확인하십시오.
먼저 모델, 데이터 세트 및 tf.functions를 만듭니다.
# Create the model, optimizer and metrics inside strategy scope, so that the
# variables can be mirrored on each device.
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
# Calculate per replica batch size, and distribute the datasets on each TPU
# worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
train_dataset = strategy.experimental_distribute_datasets_from_function(
lambda _: get_dataset(per_replica_batch_size, is_training=True))
@tf.function
def train_step(iterator):
"""The step function for one training step"""
def step_fn(inputs):
"""The computation to run on each TPU device."""
images, labels = inputs
with tf.GradientTape() as tape:
logits = model(images, training=True)
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits, from_logits=True)
loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
training_loss.update_state(loss * strategy.num_replicas_in_sync)
training_accuracy.update_state(labels, logits)
strategy.run(step_fn, args=(next(iterator),))
WARNING:tensorflow:From <ipython-input-1-2f075cabff81>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version. Instructions for updating: rename to distribute_datasets_from_function WARNING:tensorflow:From <ipython-input-1-2f075cabff81>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version. Instructions for updating: rename to distribute_datasets_from_function
그런 다음 훈련 루프를 실행하십시오.
steps_per_eval = 10000 // batch_size
train_iterator = iter(train_dataset)
for epoch in range(5):
print('Epoch: {}/5'.format(epoch))
for step in range(steps_per_epoch):
train_step(train_iterator)
print('Current step: {}, training loss: {}, accuracy: {}%'.format(
optimizer.iterations.numpy(),
round(float(training_loss.result()), 4),
round(float(training_accuracy.result()) * 100, 2)))
training_loss.reset_states()
training_accuracy.reset_states()
Epoch: 0/5 Current step: 300, training loss: 0.1237, accuracy: 96.17% Epoch: 1/5 Current step: 600, training loss: 0.0333, accuracy: 98.98% Epoch: 2/5 Current step: 900, training loss: 0.0185, accuracy: 99.37% Epoch: 3/5 Current step: 1200, training loss: 0.0135, accuracy: 99.54% Epoch: 4/5 Current step: 1500, training loss: 0.0092, accuracy: 99.7%
tf.function
내에서 여러 단계로 성능 향상
tf.function
내에서 여러 단계를 실행하여 성능을 향상시킬 수 있습니다. 이는 tf.range
내부의 tf.function
strategy.run
호출을 래핑하여 수행되며, AutoGraph는이를 TPU 작업자의 tf.while_loop
로 변환합니다.
더 나은 성능이지만 tf.function
내부의 단일 단계와 비교할 때 tf.function
있습니다. tf.function
에서 여러 단계를 실행하는 것은 유연성이 떨어 tf.function
단계 내에서 열심히 또는 임의의 파이썬 코드를 실행할 수 없습니다.
@tf.function
def train_multiple_steps(iterator, steps):
"""The step function for one training step"""
def step_fn(inputs):
"""The computation to run on each TPU device."""
images, labels = inputs
with tf.GradientTape() as tape:
logits = model(images, training=True)
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits, from_logits=True)
loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
training_loss.update_state(loss * strategy.num_replicas_in_sync)
training_accuracy.update_state(labels, logits)
for _ in tf.range(steps):
strategy.run(step_fn, args=(next(iterator),))
# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))
print('Current step: {}, training loss: {}, accuracy: {}%'.format(
optimizer.iterations.numpy(),
round(float(training_loss.result()), 4),
round(float(training_accuracy.result()) * 100, 2)))
Current step: 1800, training loss: 0.009, accuracy: 99.71%
다음 단계
- Google Cloud TPU 문서-Google Cloud TPU를 설정하고 실행합니다.
- TensorFlow 를 사용한 분산 교육 -배포 전략을 사용하는 방법과 모범 사례를 보여주는 여러 예제에 대한 링크입니다.
- TensorFlow로 모델 저장 /로드 -배포 전략을 사용하여 모델을 저장하고로드하는 방법입니다.
- TensorFlow 공식 모델 -Cloud TPU와 호환되는 최첨단 TensorFlow 2.x 모델의 예입니다.
- Google Cloud TPU 성능 가이드 -애플리케이션의 Cloud TPU 구성 매개 변수를 조정하여 Cloud TPU 성능을 더욱 향상시킵니다.