Esta página foi traduzida pela API Cloud Translation.
Switch to English

Treinamento distribuído com Keras

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Visão geral

A API tf.distribute.Strategy fornece uma abstração para distribuir seu treinamento em várias unidades de processamento. O objetivo é permitir que os usuários habilitem o treinamento distribuído usando modelos existentes e código de treinamento, com mudanças mínimas.

Este tutorial usa o tf.distribute.MirroredStrategy , que faz replicação no gráfico com treinamento síncrono em muitas GPUs em uma máquina. Essencialmente, ele copia todas as variáveis ​​do modelo para cada processador. Em seguida, ele usa all-reduzir para combinar os gradientes de todos os processadores e aplica o valor combinado a todas as cópias do modelo.

MirroredStrategy é uma das várias estratégias de distribuição disponíveis no núcleo do TensorFlow. Você pode ler sobre mais estratégias no guia de estratégia de distribuição .

API Keras

Este exemplo usa a API tf.keras para construir o modelo e o loop de treinamento. Para loops de treinamento personalizados, consulte o tutorial tf.distribute.Strategy com loops de treinamento

Importar dependências

# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os
print(tf.__version__)
2.2.0

Baixe o conjunto de dados

Faça o download do conjunto de dados MNIST e carregue-o dos conjuntos de dados TensorFlow . Isso retorna um conjunto de dados no formato tf.data .

Definir with_info como True inclui os metadados para todo o conjunto de dados, que está sendo salvo aqui para info . Entre outras coisas, esse objeto de metadados inclui o número de exemplos de treinamento e teste.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Definir estratégia de distribuição

Crie um objeto MirroredStrategy . Isso tratará da distribuição e fornecerá um gerenciador de contexto ( tf.distribute.MirroredStrategy.scope ) para construir seu modelo interno.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Configurar pipeline de entrada

Ao treinar um modelo com várias GPUs, você pode usar o poder de computação extra de forma eficaz, aumentando o tamanho do lote. Em geral, use o maior tamanho de lote que cabe na memória da GPU e ajuste a taxa de aprendizagem de acordo.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Os valores de pixel, que são 0-255, devem ser normalizados para o intervalo 0-1 . Defina esta escala em uma função.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Aplique esta função aos dados de treinamento e teste, embaralhe os dados de treinamento e agrupe-os para treinamento . Observe que também mantemos um cache na memória dos dados de treinamento para melhorar o desempenho.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Crie o modelo

Crie e compile o modelo Keras no contexto de strategy.scope .

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

Defina os callbacks

Os callbacks usados ​​aqui são:

  • TensorBoard : esse retorno de chamada grava um registro para o TensorBoard que permite visualizar os gráficos.
  • Ponto de verificação do modelo : este retorno de chamada salva o modelo após cada época.
  • Programador de taxa de aprendizagem : usando este retorno de chamada, você pode programar a taxa de aprendizagem para mudar após cada época / lote.

Para fins ilustrativos, adicione um retorno de chamada de impressão para exibir a taxa de aprendizagem no bloco de notas.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

Treinar e avaliar

Agora, treine o modelo da maneira usual, chamando o fit no modelo e passando o conjunto de dados criado no início do tutorial. Esta etapa é a mesma, quer você esteja distribuindo o treinamento ou não.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

936/938 [============================>.] - ETA: 0s - accuracy: 0.9422 - loss: 0.2016
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 5s 5ms/step - accuracy: 0.9422 - loss: 0.2015 - lr: 0.0010
Epoch 2/12
936/938 [============================>.] - ETA: 0s - accuracy: 0.9807 - loss: 0.0662
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 4ms/step - accuracy: 0.9807 - loss: 0.0662 - lr: 0.0010
Epoch 3/12
933/938 [============================>.] - ETA: 0s - accuracy: 0.9863 - loss: 0.0464
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9863 - loss: 0.0464 - lr: 0.0010
Epoch 4/12
933/938 [============================>.] - ETA: 0s - accuracy: 0.9933 - loss: 0.0252
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 4ms/step - accuracy: 0.9933 - loss: 0.0252 - lr: 1.0000e-04
Epoch 5/12
932/938 [============================>.] - ETA: 0s - accuracy: 0.9946 - loss: 0.0220
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9945 - loss: 0.0220 - lr: 1.0000e-04
Epoch 6/12
929/938 [============================>.] - ETA: 0s - accuracy: 0.9951 - loss: 0.0200
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9951 - loss: 0.0201 - lr: 1.0000e-04
Epoch 7/12
928/938 [============================>.] - ETA: 0s - accuracy: 0.9955 - loss: 0.0186
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9955 - loss: 0.0186 - lr: 1.0000e-04
Epoch 8/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9965 - loss: 0.0161
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9965 - loss: 0.0161 - lr: 1.0000e-05
Epoch 9/12
932/938 [============================>.] - ETA: 0s - accuracy: 0.9965 - loss: 0.0157
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9965 - loss: 0.0156 - lr: 1.0000e-05
Epoch 10/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9966 - loss: 0.0155
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9966 - loss: 0.0154 - lr: 1.0000e-05
Epoch 11/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9967 - loss: 0.0153
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9967 - loss: 0.0153 - lr: 1.0000e-05
Epoch 12/12
924/938 [============================>.] - ETA: 0s - accuracy: 0.9967 - loss: 0.0152
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9967 - loss: 0.0151 - lr: 1.0000e-05

<tensorflow.python.keras.callbacks.History at 0x7fc7cc1ce5f8>

Como você pode ver abaixo, os pontos de verificação estão sendo salvos.

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_4.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_7.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index

Para ver o desempenho do modelo, carregue o ponto de verificação mais recente e chame evaluate nos dados de teste.

Chame a evaluate como antes, usando conjuntos de dados apropriados.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 1s 7ms/step - accuracy: 0.9861 - loss: 0.0393
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

Para ver o resultado, você pode fazer o download e visualizar os registros do TensorBoard no terminal.

$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K
4.0K train

Exportar para SavedModel

Exporte o gráfico e as variáveis ​​para o formato SavedModel independente de plataforma. Depois que seu modelo for salvo, você pode carregá-lo com ou sem o osciloscópio.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: saved_model/assets

INFO:tensorflow:Assets written to: saved_model/assets

Carregue o modelo sem strategy.scope .

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 5ms/step - loss: 0.0393 - accuracy: 0.9861
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

Carregue o modelo com strategy.scope .

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - accuracy: 0.9861 - loss: 0.0393
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

Exemplos e tutoriais

Aqui estão alguns exemplos de uso de estratégia de distribuição com keras fit / compile:

  1. Exemplo de transformador treinado usando tf.distribute.MirroredStrategy
  2. Exemplo de NCF treinado usando tf.distribute.MirroredStrategy .

Mais exemplos listados no guia de estratégia de distribuição

Próximos passos