Эта страница переведена с помощью Cloud Translation API.
Switch to English

Распределенное обучение с Керасом

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Обзор

API tf.distribute.Strategy предоставляет абстракцию для распределения вашего обучения по нескольким процессорам. Цель состоит в том, чтобы позволить пользователям включить распределенное обучение с использованием существующих моделей и обучающего кода с минимальными изменениями.

В этом руководстве используется tf.distribute.MirroredStrategy , который выполняет репликацию в графике с синхронным обучением на многих графических процессорах на одном компьютере. По сути, он копирует все переменные модели в каждый процессор. Затем он использует all-reduce для объединения градиентов от всех процессоров и применяет объединенное значение ко всем копиям модели.

MirroredStrategy - одна из нескольких стратегий распространения, доступных в ядре TensorFlow. Вы можете прочитать о других стратегиях в руководстве по стратегии распространения .

Keras API

В этом примере используется API tf.keras для построения модели и цикла обучения. Информацию о пользовательских циклах обучения см. В руководстве tf.distribute.Strategy с циклами обучения.

Импортировать зависимости

# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf

import os
print(tf.__version__)
2.3.0

Скачать набор данных

Загрузите набор данных MNIST и загрузите его из наборов данных TensorFlow . Это возвращает набор данных в формате tf.data .

Установка with_info на True включает метаданные для всего набора данных, которые сохраняются здесь в info . Помимо прочего, этот объект метаданных включает количество обучающих и тестовых примеров.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Определить стратегию распространения

Создайте объект MirroredStrategy . Это будет обрабатывать распространение и предоставляет диспетчер контекста ( tf.distribute.MirroredStrategy.scope ) для построения вашей модели внутри.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Настройка входного конвейера

При обучении модели с несколькими графическими процессорами вы можете эффективно использовать дополнительную вычислительную мощность, увеличив размер пакета. Как правило, используйте самый большой размер пакета, который соответствует памяти графического процессора, и соответствующим образом настраивайте скорость обучения.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Значения пикселей 0–255 должны быть нормализованы до диапазона 0–1 . Определите этот масштаб в функции.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Примените эту функцию к обучающим и тестовым данным, перемешайте обучающие данные и запишите их для обучения . Обратите внимание, что для повышения производительности мы также сохраняем кэш данных обучения в памяти.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Создайте модель

Создайте и скомпилируйте модель Keras в контексте strategy.scope .

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

Определите обратные вызовы

Здесь используются следующие обратные вызовы:

  • TensorBoard : этот обратный вызов записывает журнал для TensorBoard, который позволяет вам визуализировать графики.
  • Контрольная точка модели : этот обратный вызов сохраняет модель после каждой эпохи.
  • Планировщик скорости обучения : используя этот обратный вызов, вы можете запланировать изменение скорости обучения после каждой эпохи / пакета.

Для наглядности добавьте обратный вызов печати для отображения скорости обучения в записной книжке.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

Тренируйте и оценивайте

Теперь обучите модель обычным способом, вызывая fit для модели и передавая набор данных, созданный в начале руководства. Этот шаг одинаков, независимо от того, распространяете вы тренинг или нет.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

  1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Epoch 2/12
935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864
Epoch 4/12
937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936
Epoch 5/12
932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948
Epoch 6/12
919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951
Epoch 7/12
921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960
Epoch 8/12
931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970
Epoch 9/12
938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970
Epoch 10/12
924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971
Epoch 11/12
937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972
Epoch 12/12
923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973

<tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>

Как вы можете видеть ниже, контрольные точки сохраняются.

# check the checkpoint directory
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

Чтобы увидеть, как работает модель, загрузите последнюю контрольную точку и вызовите evaluate для тестовых данных.

Вызовите evaluate как и раньше, используя соответствующие наборы данных.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Чтобы увидеть результат, вы можете загрузить и просмотреть логи TensorBoard в терминале.

$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K
4.0K train

Экспорт в SavedModel

Экспортируйте график и переменные в формат SavedModel, не зависящий от платформы. После сохранения модели вы можете загрузить ее с прицелом или без него.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: saved_model/assets

INFO:tensorflow:Assets written to: saved_model/assets

Загрузите модель без strategy.scope .

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Загрузите модель с помощью strategy.scope .

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Примеры и руководства

Вот несколько примеров использования стратегии распространения с keras fit / compile:

  1. Пример трансформатора, обученного с использованием tf.distribute.MirroredStrategy
  2. Пример NCF, обученный с использованием tf.distribute.MirroredStrategy .

Дополнительные примеры приведены в руководстве по стратегии распространения.

Следующие шаги