Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training with existing models and training code, with minimal changes.
This tutorial uses tf.distribute.MirroredStrategy, which performs in-graph replication with synchronous training on multiple GPUs on one machine. Essentially, it copies all of the model's variables to each processor. It then uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.
MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.
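To make the synchronous update concrete, here is a minimal pure-Python sketch of the idea: each replica computes gradients on its own share of the batch, an all-reduce combines them, and every replica applies the same update. The function names (all_reduce_mean, sgd_step), the choice to average rather than sum, and the plain SGD update are all assumptions made for illustration; TensorFlow's actual implementation runs on device collectives and is not shown here.

```python
# Illustrative sketch only: a toy model of synchronous data-parallel training.
# The function names and the plain-SGD update are assumptions for this example.

def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise across replicas.

    Every replica receives the same combined result, which is what keeps
    the mirrored copies of the variables in sync.
    """
    num_replicas = len(per_replica_grads)
    return [sum(values) / num_replicas for values in zip(*per_replica_grads)]


def sgd_step(weights, grads, learning_rate):
    """Apply one plain SGD update and return the new weights."""
    return [w - learning_rate * g for w, g in zip(weights, grads)]


# Two replicas start from identical copies of the model variables.
replica_weights = [[1.0, -2.0], [1.0, -2.0]]

# Each replica computes different gradients from its own share of the batch.
per_replica_grads = [[0.5, 1.0], [1.5, 3.0]]

combined = all_reduce_mean(per_replica_grads)
replica_weights = [sgd_step(w, combined, learning_rate=0.5) for w in replica_weights]

print(combined)         # [1.0, 2.0]
print(replica_weights)  # [[0.5, -3.0], [0.5, -3.0]]
```

Because every replica applies the same combined gradient to the same starting values, the copies stay identical after the step.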
Keras API
This example uses the tf.keras API to build the model and the training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.
Import dependencies
# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
import os
print(tf.__version__)
2.3.0
Download the dataset
Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.
Setting with_info to True includes the metadata for the entire dataset, which is saved here to info. Among other things, this metadata object includes the number of train and test examples.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Define the distribution strategy
Create a MirroredStrategy object. This will handle distribution and provide a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
Set up the input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
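As a rough illustration of how these constants scale, the sketch below computes the global batch size for a few replica counts, together with a linearly scaled learning rate. The helper name scaled_settings and the linear scaling rule are assumptions for this example: linear scaling is a common heuristic, not something this tutorial prescribes, so treat the scaled rate as a starting point for tuning.

```python
BATCH_SIZE_PER_REPLICA = 64
BASE_LEARNING_RATE = 1e-3  # assumed single-replica learning rate


def scaled_settings(num_replicas):
    """Return (global batch size, linearly scaled learning rate)."""
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    # Heuristic assumption: grow the learning rate with the replica count.
    scaled_lr = BASE_LEARNING_RATE * num_replicas
    return global_batch_size, scaled_lr


for replicas in (1, 2, 4):
    global_batch_size, scaled_lr = scaled_settings(replicas)
    print('replicas={} global_batch={} lr={:g}'.format(
        replicas, global_batch_size, scaled_lr))
```

Each replica still processes 64 examples per step; only the global batch, and therefore the amount of data consumed per step, grows with the number of replicas.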
Pixel values, which range from 0 to 255, must be normalized to the 0-1 range. Define this scaling in a function.
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label
Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice that the training data is also cached in memory to improve performance.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
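The batching step can be pictured with a small pure-Python stand-in. This is only a sketch of the semantics (the hypothetical batch_sizes helper is not part of tf.data): when the remainder is kept, the final batch holds whatever examples are left over. The example counts used are the sizes of the MNIST training and test splits.

```python
def batch_sizes(num_examples, batch_size):
    """Return the size of each batch when the remainder is kept."""
    full_batches, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        sizes.append(remainder)
    return sizes


train_batches = batch_sizes(60000, 64)  # MNIST training split
eval_batches = batch_sizes(10000, 64)   # MNIST test split

print(len(train_batches), train_batches[-1])  # 938 32
print(len(eval_batches), eval_batches[-1])    # 157 16
```

These counts line up with the 938 training steps and 157 evaluation steps shown in the progress bars when a single replica is used.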
Create the model
Create and compile the Keras model in the context of strategy.scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
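Because MirroredStrategy keeps a full copy of the model's variables on every replica, it can help to know how large the model is. The arithmetic below is a back-of-the-envelope sketch that assumes the Keras defaults for these layers ('valid' padding and stride 1 for Conv2D, a 2x2 window for MaxPooling2D); the helper names are made up for this example.

```python
def conv2d_params(filters, kernel_size, in_channels):
    """Kernel weights plus one bias per filter."""
    return filters * kernel_size * kernel_size * in_channels + filters


def dense_params(units, in_features):
    """Weights plus one bias per unit."""
    return units * in_features + units


# Input: 28x28x1 images.
conv_out = 28 - 3 + 1                     # 'valid' padding, stride 1 -> 26
pool_out = conv_out // 2                  # 2x2 max pooling -> 13
flat_features = pool_out * pool_out * 32  # 13 * 13 * 32 = 5408

total = (
    conv2d_params(32, 3, 1)            # 320
    + dense_params(64, flat_features)  # 346,176
    + dense_params(10, 64)             # 650
)
print(total)  # 347146
```

Under these assumptions each replica holds about 347 thousand trainable parameters, almost all of them in the first Dense layer.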
Define the callbacks
The callbacks used here are:
- TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
- Model Checkpoint: This callback saves the model after every epoch.
- Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch or batch.
For illustrative purposes, add a print callback to display the learning rate in the notebook.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch >= 3 and epoch < 7:
        return 1e-4
    else:
        return 1e-5

# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                          model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
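To check the schedule before training, you can evaluate the decay function on its own. The sketch below repeats the function so that it runs standalone, and assumes that the scheduler callback passes a zero-based epoch index while the progress output counts epochs from 1. It also rounds each value to float32 with the standard struct module (the as_float32 helper is made up for this example), which suggests why a printed rate looks like 0.0010000000474974513 rather than exactly 0.001.

```python
import struct


def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


def as_float32(value):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack('f', struct.pack('f', value))[0]


# Epoch indices are zero-based here; add 1 to match the "Epoch N/12" lines.
schedule = {epoch + 1: decay(epoch) for epoch in range(12)}

for display_epoch, lr in schedule.items():
    print('Epoch {:2d}: lr={:g} (float32: {!r})'.format(
        display_epoch, lr, as_float32(lr)))
```

Under that assumption, epochs 1 to 3 use 1e-3, epochs 4 to 7 use 1e-4, and epochs 8 to 12 use 1e-5.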
Train and evaluate
Now, train the model in the usual way by calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.
model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead.
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.
932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/12
935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864
Epoch 4/12
937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936
Epoch 5/12
932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948
Epoch 6/12
919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951
Epoch 7/12
921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960
Epoch 8/12
931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970
Epoch 9/12
938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970
Epoch 10/12
924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971
Epoch 11/12
937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972
Epoch 12/12
923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973
<tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>
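One sanity check on the training output: the very first reported loss is roughly 2.31. A minimal pure-Python sketch of sparse categorical cross-entropy computed from logits (an illustration with a made-up function name, not the Keras implementation) shows that a 10-class model whose logits are all equal scores ln(10), about 2.30, so a starting loss near that value is what you would expect from an untrained model.

```python
import math


def sparse_categorical_crossentropy_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    # Subtract the max logit for numerical stability before exponentiating.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


uniform_logits = [0.0] * 10  # a model with no preference among 10 classes
loss = sparse_categorical_crossentropy_from_logits(uniform_logits, label=3)
print(round(loss, 4))  # 2.3026
```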
As you can see below, the checkpoints are being saved.
# check the checkpoint directory
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
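File names like these sort lexicographically, so ckpt_10 comes before ckpt_2. If you ever need to pick the most recent file from such names yourself, compare the epoch numbers numerically. The sketch below uses a hypothetical helper (latest_epoch_prefix) for that purpose; it is not how tf.train.latest_checkpoint works, which relies on the checkpoint state file in the directory rather than on file names.

```python
import re


def latest_epoch_prefix(filenames):
    """Return the 'ckpt_<epoch>' prefix with the highest epoch number."""
    epochs = set()
    for name in filenames:
        match = re.match(r'ckpt_(\d+)\.index$', name)
        if match:
            epochs.add(int(match.group(1)))
    if not epochs:
        return None
    return 'ckpt_{}'.format(max(epochs))


files = ['checkpoint', 'ckpt_1.index', 'ckpt_10.index', 'ckpt_12.index',
         'ckpt_2.index', 'ckpt_9.index']

print(sorted(files)[-1])           # ckpt_9.index (lexicographic order)
print(latest_epoch_prefix(files))  # ckpt_12 (numeric order)
```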
To see how the model performs, load the latest checkpoint and call evaluate on the test data.
Call evaluate as before, using the appropriate dataset.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
To see the output, you can download the TensorBoard logs and view them from the terminal.
$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K 4.0K train
Export to SavedModel
Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.
path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/assets
Load the model without strategy.scope.
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
Load the model with strategy.scope.
with strategy.scope():
    replicated_model = tf.keras.models.load_model(path)
    replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                             optimizer=tf.keras.optimizers.Adam(),
                             metrics=['accuracy'])
    eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
    print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
Examples and Tutorials
Here are some examples of using a distribution strategy with Keras fit/compile:
- Transformer example trained using tf.distribute.MirroredStrategy.
- NCF example trained using tf.distribute.MirroredStrategy.
More examples are listed in the distribution strategy guide.
Next steps
- Read the distribution strategy guide.
- Read the Distributed Training with Custom Training Loops tutorial.
- Visit the Performance section in the guide to learn more about other strategies and tools you can use to optimize the performance of your TensorFlow models.