Szkolenie rozproszone z Keras

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Przegląd

tf.distribute.Strategy API tf.distribute.Strategy zapewnia abstrakcję do dystrybucji szkolenia na wiele jednostek przetwarzania. Celem jest umożliwienie użytkownikom włączenia szkolenia rozproszonego przy użyciu istniejących modeli i kodu szkoleniowego przy minimalnych zmianach.

W tym samouczku zastosowano tf.distribute.MirroredStrategy , który wykonuje replikację na wykresie z synchronicznym trenowaniem na wielu procesorach GPU na jednym komputerze. Zasadniczo kopiuje wszystkie zmienne modelu do każdego procesora. Następnie używa all-reduce do połączenia gradientów ze wszystkich procesorów i stosuje połączoną wartość do wszystkich kopii modelu.

MirroredStrategy to jedna z kilku strategii dystrybucji dostępnych w rdzeniu TensorFlow. Więcej informacji o strategiach można znaleźć w przewodniku po strategiach dystrybucji .

Keras API

Ten przykład używa interfejsu API tf.keras do tf.keras modelu i pętli szkoleniowej. Aby zapoznać się z niestandardowymi pętlami treningowymi, zobacz samouczek tf.distribute.Strategy with training loops .

Importuj zależności

# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf

import os
print(tf.__version__)
2.5.0

Pobierz zbiór danych

Pobierz zbiór danych MNIST i załaduj go z TensorFlow Datasets . Zwraca zestaw danych w formacie tf.data .

Ustawienie with_info na True obejmuje metadane dla całego zbioru danych, który jest zapisywany tutaj w info . Ten obiekt metadanych zawiera między innymi liczbę pociągów i przykładów testowych.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

Zdefiniuj strategię dystrybucji

Utwórz obiekt MirroredStrategy . tf.distribute.MirroredStrategy.scope to dystrybucję i zapewni menedżera kontekstu ( tf.distribute.MirroredStrategy.scope ) do zbudowania modelu wewnątrz.

strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Konfiguracja potoku wejściowego

Podczas trenowania modelu z wieloma procesorami GPU można efektywnie wykorzystać dodatkową moc obliczeniową, zwiększając rozmiar partii. Ogólnie rzecz biorąc, używaj największego rozmiaru partii, który pasuje do pamięci GPU i odpowiednio dostosuj szybkość uczenia się.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Wartości pikseli, które są od 0 do 255, muszą być znormalizowane do zakresu 0-1 . Zdefiniuj tę skalę w funkcji.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Zastosuj tę funkcję do danych treningowych i testowych, przetasuj dane treningowe i wsadź je do treningu . Zwróć uwagę, że przechowujemy również pamięć podręczną danych treningowych, aby poprawić wydajność.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Stwórz model

Utwórz i skompiluj model Keras w kontekście strategy.scope .

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Zdefiniuj wywołania zwrotne

Używane tutaj wywołania zwrotne to:

  • TensorBoard : Ta funkcja zwrotna zapisuje dziennik dla TensorBoard, który umożliwia wizualizację wykresów.
  • Punkt kontrolny modelu : to wywołanie zwrotne zapisuje model po każdej epoce.
  • Harmonogram szybkości uczenia się : Korzystając z tego wywołania zwrotnego, możesz zaplanować zmianę szybkości uczenia się po każdej epoce/partii.

W celach ilustracyjnych dodaj wywołanie zwrotne drukowania, aby wyświetlić wskaźnik uczenia się w notatniku.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

Trenuj i oceniaj

Teraz przeszkol model w zwykły sposób, wywołując fit do modelu i przekazując zestaw danych utworzony na początku samouczka. Ten krok jest taki sam, niezależnie od tego, czy prowadzisz szkolenie, czy nie.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
3/938 [..............................] - ETA: 3:57 - loss: 2.2014 - accuracy: 0.2292WARNING:tensorflow:Callback method `on_train_batch_begin` is slow compared to the batch time (batch time: 0.0043s vs `on_train_batch_begin` time: 0.0693s). Check your callbacks.
WARNING:tensorflow:Callback method `on_train_batch_begin` is slow compared to the batch time (batch time: 0.0043s vs `on_train_batch_begin` time: 0.0693s). Check your callbacks.
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_train_batch_end` time: 0.0141s). Check your callbacks.
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_train_batch_end` time: 0.0141s). Check your callbacks.
938/938 [==============================] - 8s 4ms/step - loss: 0.1970 - accuracy: 0.9419

Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0641 - accuracy: 0.9809

Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0432 - accuracy: 0.9868

Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0228 - accuracy: 0.9937

Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0194 - accuracy: 0.9948

Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0175 - accuracy: 0.9956

Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0160 - accuracy: 0.9962

Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0136 - accuracy: 0.9971

Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0133 - accuracy: 0.9972

Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0131 - accuracy: 0.9973

Learning rate for epoch 10 is 9.999999747378752e-06
Epoch 11/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0130 - accuracy: 0.9973

Learning rate for epoch 11 is 9.999999747378752e-06
Epoch 12/12
938/938 [==============================] - 2s 3ms/step - loss: 0.0128 - accuracy: 0.9974

Learning rate for epoch 12 is 9.999999747378752e-06
<tensorflow.python.keras.callbacks.History at 0x7f3d78283790>

Jak widać poniżej, punkty kontrolne są zapisywane.

# check the checkpoint directory
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

Aby zobaczyć, jak działa model, załaduj najnowszy punkt kontrolny i wywołaj evaluate danych testowych.

Zadzwoń do evaluate jak przed użyciem odpowiednich zestawów danych.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 2s 3ms/step - loss: 0.0383 - accuracy: 0.9870
Eval loss: 0.0383150540292263, Eval Accuracy: 0.9869999885559082

Aby zobaczyć wyniki, możesz pobrać i wyświetlić logi TensorBoard na terminalu.

$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K
4.0K train

Eksportuj do zapisanego modelu

Wyeksportuj wykres i zmienne do formatu SavedModel niezależnego od platformy. Po zapisaniu modelu możesz go wczytać z zakresem lub bez niego.

path = 'saved_model/'
model.save(path, save_format='tf')
INFO:tensorflow:Assets written to: saved_model/assets
INFO:tensorflow:Assets written to: saved_model/assets

Załaduj model bez strategy.scope .

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 0s 2ms/step - loss: 0.0383 - accuracy: 0.9870
Eval loss: 0.0383150540292263, Eval Accuracy: 0.9869999885559082

Załaduj model za pomocą strategy.scope .

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 3s 2ms/step - loss: 0.0383 - accuracy: 0.9870
Eval loss: 0.0383150540292263, Eval Accuracy: 0.9869999885559082

Przykłady i samouczki

Oto kilka przykładów użycia strategii dystrybucji z keras fit/compile:

  1. Przykład transformatora wyszkolony przy użyciu tf.distribute.MirroredStrategy
  2. Przykład NCF przeszkolony przy użyciu tf.distribute.MirroredStrategy .

Więcej przykładów wymienionych w przewodniku po strategii dystrybucji

Następne kroki