![]() | ![]() | ![]() | ![]() |
visão global
A API tf.distribute.Strategy
fornece uma abstração para distribuir seu treinamento em várias unidades de processamento. O objetivo é permitir que os usuários habilitem o treinamento distribuído usando modelos existentes e código de treinamento, com mudanças mínimas.
Este tutorial usa o tf.distribute.MirroredStrategy
, que faz replicação no gráfico com treinamento síncrono em muitas GPUs em uma máquina. Essencialmente, ele copia todas as variáveis do modelo para cada processador. Em seguida, ele usa all-reduzir para combinar os gradientes de todos os processadores e aplica o valor combinado a todas as cópias do modelo.
MirroredStrategy
é uma das várias estratégias de distribuição disponíveis no núcleo do TensorFlow. Você pode ler sobre mais estratégias no guia de estratégia de distribuição .
API Keras
Este exemplo usa a API tf.keras
para construir o modelo e o loop de treinamento. Para obter loops de treinamento personalizados, consulte o tutorial tf.distribute.Strategy com loops de treinamento .
Importar dependências
# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
import os
print(tf.__version__)
2.3.0
Baixe o conjunto de dados
Faça o download do conjunto de dados MNIST e carregue-o dos conjuntos de dados TensorFlow . Isso retorna um conjunto de dados no formato tf.data
.
Definir with_info
como True
inclui os metadados para todo o conjunto de dados, que está sendo salvo aqui para info
. Entre outras coisas, esse objeto de metadados inclui o número de exemplos de treinamento e teste.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`. Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Definir estratégia de distribuição
Crie um objeto MirroredStrategy
. Isso tratará da distribuição e fornecerá um gerenciador de contexto ( tf.distribute.MirroredStrategy.scope
) para construir seu modelo interno.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
Configurar pipeline de entrada
Ao treinar um modelo com várias GPUs, você pode usar o poder de computação extra de forma eficaz, aumentando o tamanho do lote. Em geral, use o maior tamanho de lote que cabe na memória da GPU e ajuste a taxa de aprendizagem de acordo.
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
Os valores de pixel, que são 0-255, devem ser normalizados para o intervalo 0-1 . Defina esta escala em uma função.
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Aplique esta função aos dados de treinamento e teste, embaralhe os dados de treinamento e agrupe-os para treinamento . Observe que também mantemos um cache na memória dos dados de treinamento para melhorar o desempenho.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
Crie o modelo
Crie e compile o modelo Keras no contexto de strategy.scope
.
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
Defina os callbacks
Os callbacks usados aqui são:
- TensorBoard : esse retorno de chamada grava um registro para o TensorBoard que permite visualizar os gráficos.
- Ponto de verificação do modelo : este retorno de chamada salva o modelo após cada época.
- Programador de taxa de aprendizagem : usando este retorno de chamada, você pode programar a taxa de aprendizagem para mudar após cada época / lote.
Para fins ilustrativos, adicione um retorno de chamada de impressão para exibir a taxa de aprendizagem no bloco de notas.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
if epoch < 3:
return 1e-3
elif epoch >= 3 and epoch < 7:
return 1e-4
else:
return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
model.optimizer.lr.numpy()))
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]
Treinar e avaliar
Agora, treine o modelo da maneira usual, chamando o fit
no modelo e passando o conjunto de dados criado no início do tutorial. Esta etapa é a mesma, quer você esteja distribuindo o treinamento ou não.
model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. 932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441 Learning rate for epoch 1 is 0.0010000000474974513 938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 2/12 935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811 Learning rate for epoch 2 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812 Epoch 3/12 936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864 Learning rate for epoch 3 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864 Epoch 4/12 937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936 Learning rate for epoch 4 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936 Epoch 5/12 932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948 Learning rate for epoch 5 is 9.999999747378752e-05 938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948 Epoch 6/12 919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951 Learning rate for epoch 6 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951 Epoch 7/12 921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960 Learning rate for epoch 7 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960 Epoch 8/12 931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970 Learning rate for epoch 8 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970 Epoch 9/12 938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970 Learning rate for epoch 9 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970 Epoch 10/12 924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971 Learning rate for epoch 10 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971 Epoch 11/12 937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972 Learning rate for epoch 11 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972 Epoch 12/12 923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973 Learning rate for epoch 12 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973 <tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>
Como você pode ver abaixo, os pontos de verificação estão sendo salvos.
# check the checkpoint directory
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
Para ver o desempenho do modelo, carregue o ponto de verificação mais recente e chame evaluate
nos dados de teste.
Chame a evaluate
como antes, usando conjuntos de dados apropriados.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
Para ver o resultado, você pode fazer o download e visualizar os registros do TensorBoard no terminal.
$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K 4.0K train
Exportar para SavedModel
Exporte o gráfico e as variáveis para o formato SavedModel independente de plataforma. Depois que seu modelo for salvo, você pode carregá-lo com ou sem o osciloscópio.
path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: saved_model/assets INFO:tensorflow:Assets written to: saved_model/assets
Carregue o modelo sem strategy.scope
.
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
Carregue o modelo com strategy.scope
.
with strategy.scope():
replicated_model = tf.keras.models.load_model(path)
replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
Exemplos e tutoriais
Aqui estão alguns exemplos para usar a estratégia de distribuição com keras fit / compile:
- Exemplo de transformador treinado usando
tf.distribute.MirroredStrategy
- Exemplo de NCF treinado usando
tf.distribute.MirroredStrategy
.
Mais exemplos listados no guia de estratégia de distribuição
Próximos passos
- Leia o guia de estratégia de distribuição .
- Leia o tutorial Treinamento distribuído com loops de treinamento personalizados .
- Visite a seção Desempenho do guia para saber mais sobre outras estratégias e ferramentas que você pode usar para otimizar o desempenho de seus modelos do TensorFlow.