Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Formazione distribuita con Keras

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica notebook

Panoramica

L'API tf.distribute.Strategy fornisce un'astrazione per la distribuzione della formazione su più unità di elaborazione. L'obiettivo è consentire agli utenti di abilitare la formazione distribuita utilizzando i modelli e il codice di formazione esistenti, con modifiche minime.

Questo tutorial utilizza tf.distribute.MirroredStrategy , che esegue la replica nel grafico con addestramento sincrono su molte GPU su una macchina. In sostanza, copia tutte le variabili del modello su ciascun processore. Quindi, utilizza la riduzione totale per combinare i gradienti di tutti i processori e applica il valore combinato a tutte le copie del modello.

MirroredStrategy è una delle numerose strategie di distribuzione disponibili in TensorFlow core. Puoi leggere altre strategie nella guida alla strategia di distribuzione .

API Keras

Questo esempio utilizza l'API tf.keras per creare il modello e il ciclo di addestramento. Per i cicli di addestramento personalizzati, vedere il tutorial tf.distribute.Strategy con cicli di addestramento .

Importa dipendenze

# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf

import os
print(tf.__version__)
2.3.0

Scarica il set di dati

Scarica il set di dati MNIST e caricalo dai set di dati TensorFlow . Ciò restituisce un set di dati in formato tf.data .

L'impostazione di with_info su True include i metadati per l'intero set di dati, che viene salvato qui in info . Tra le altre cose, questo oggetto metadati include il numero di esempi di addestramento e test.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Definisci la strategia di distribuzione

Crea un oggetto MirroredStrategy . Questo gestirà la distribuzione e fornisce un gestore di contesto ( tf.distribute.MirroredStrategy.scope ) per costruire il tuo modello all'interno.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Configurazione della pipeline di input

Quando si addestra un modello con più GPU, è possibile utilizzare la potenza di calcolo aggiuntiva in modo efficace aumentando la dimensione del batch. In generale, utilizza la dimensione batch più grande che si adatta alla memoria della GPU e regola la velocità di apprendimento di conseguenza.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

I valori dei pixel, che sono 0-255, devono essere normalizzati nell'intervallo 0-1 . Definisci questa scala in una funzione.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Applicare questa funzione ai dati di addestramento e di prova, mescolare i dati di addestramento e inviarli in batch per l'addestramento . Si noti che stiamo anche mantenendo una cache in memoria dei dati di allenamento per migliorare le prestazioni.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Crea il modello

Crea e compila il modello Keras nel contesto di strategy.scope .

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

Definisci i callback

I callback usati qui sono:

  • TensorBoard : questo callback scrive un log per TensorBoard che consente di visualizzare i grafici.
  • Punto di controllo del modello : questa richiamata salva il modello dopo ogni epoca.
  • Pianificazione della velocità di apprendimento : utilizzando questa richiamata, è possibile programmare la variazione della velocità di apprendimento dopo ogni epoca / batch.

A scopo illustrativo, aggiungere una richiamata di stampa per visualizzare la velocità di apprendimento nel blocco appunti.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

Allenati e valuta

Ora, addestra il modello nel solito modo, chiamando fit sul modello e passando il set di dati creato all'inizio del tutorial. Questo passaggio è lo stesso sia che tu stia distribuendo la formazione o meno.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

  1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Epoch 2/12
935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864
Epoch 4/12
937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936
Epoch 5/12
932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948
Epoch 6/12
919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951
Epoch 7/12
921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960
Epoch 8/12
931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970
Epoch 9/12
938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970
Epoch 10/12
924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971
Epoch 11/12
937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972
Epoch 12/12
923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973

<tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>

Come puoi vedere di seguito, i checkpoint vengono salvati.

# check the checkpoint directory
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

Per vedere come si comporta il modello, carica l'ultimo checkpoint e chiama evaluate sui dati di test.

Chiama evaluate come prima di utilizzare set di dati appropriati.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Per vedere l'output, è possibile scaricare e visualizzare i log di TensorBoard sul terminale.

$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K
4.0K train

Esporta in SavedModel

Esporta il grafico e le variabili nel formato SavedModel indipendente dalla piattaforma. Dopo aver salvato il modello, è possibile caricarlo con o senza l'ambito.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: saved_model/assets

INFO:tensorflow:Assets written to: saved_model/assets

Carica il modello senza strategy.scope .

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Carica il modello con strategy.scope .

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

Esempi e tutorial

Di seguito sono riportati alcuni esempi di utilizzo della strategia di distribuzione con keras fit / compile:

  1. Esempio di Transformer addestrato utilizzando tf.distribute.MirroredStrategy
  2. Esempio NCF addestrato utilizzando tf.distribute.MirroredStrategy .

Altri esempi elencati nella Guida alla strategia di distribuzione

Prossimi passi