Punya pertanyaan? Terhubung dengan komunitas di Forum Kunjungan TensorFlow Forum

Pelatihan khusus dengan tf.distribute.Strategy

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Tutorial ini menunjukkan cara menggunakan tf.distribute.Strategy dengan loop pelatihan khusus. Kami akan melatih model CNN sederhana pada dataset fashion MNIST. Kumpulan data mode MNIST berisi 60.000 gambar kereta berukuran 28 x 28 dan 10.000 gambar uji berukuran 28 x 28.

Kami menggunakan loop pelatihan khusus untuk melatih model kami karena mereka memberi kami fleksibilitas dan kontrol yang lebih besar pada pelatihan. Selain itu, lebih mudah untuk men-debug model dan loop pelatihan.

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.5.0

Unduh kumpulan data mode MNIST

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Buat strategi untuk mendistribusikan variabel dan grafik

Bagaimana cara kerja strategi tf.distribute.MirroredStrategy ?

  • Semua variabel dan model grafik direplikasi pada replika.
  • Masukan didistribusikan secara merata di seluruh replika.
  • Setiap replika menghitung kerugian dan gradien untuk input yang diterimanya.
  • Gradien disinkronkan di semua replika dengan menjumlahkannya.
  • Setelah sinkronisasi, pembaruan yang sama dilakukan pada salinan variabel di setiap replika.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Siapkan saluran input

Ekspor grafik dan variabel ke format SavedModel platform-agnostik. Setelah model Anda disimpan, Anda dapat memuatnya dengan atau tanpa cakupan.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

Buat kumpulan data dan distribusikan:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

Buat modelnya

Buat model menggunakan tf.keras.Sequential . Anda juga dapat menggunakan Model Subclassing API untuk melakukan ini.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Tentukan fungsi kerugian

Biasanya, pada satu mesin dengan 1 GPU/CPU, kerugian dibagi dengan jumlah contoh dalam kumpulan input.

Jadi, bagaimana cara menghitung kerugian saat menggunakan tf.distribute.Strategy ?

  • Sebagai contoh, katakanlah Anda memiliki 4 GPU dan ukuran batch 64. Satu batch input didistribusikan ke seluruh replika (4 GPU), setiap replika mendapatkan input berukuran 16.

  • Model pada setiap replika melakukan forward pass dengan inputnya masing-masing dan menghitung kerugiannya. Sekarang, alih-alih membagi kerugian dengan jumlah contoh dalam input masing-masing (BATCH_SIZE_PER_REPLICA = 16), kerugian harus dibagi dengan GLOBAL_BATCH_SIZE (64).

Kenapa melakukan ini?

  • Ini perlu dilakukan karena setelah gradien dihitung pada setiap replika, gradien disinkronkan di seluruh replika dengan menjumlahkannya .

Bagaimana cara melakukannya di TensorFlow?

  • Jika Anda menulis loop pelatihan khusus, seperti dalam tutorial ini, Anda harus menjumlahkan kerugian per contoh dan membagi jumlahnya dengan GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) atau Anda dapat menggunakan tf.nn.compute_average_loss yang mengambil kerugian per contoh, bobot sampel opsional, dan GLOBAL_BATCH_SIZE sebagai argumen dan mengembalikan skala kerugian.

  • Jika Anda menggunakan kerugian regularisasi dalam model Anda, maka Anda perlu menskalakan nilai kerugian dengan jumlah replika. Anda dapat melakukannya dengan menggunakan fungsi tf.nn.scale_regularization_loss .

  • Menggunakan tf.reduce_mean tidak disarankan. Melakukannya membagi kerugian dengan ukuran batch aktual per replika yang dapat bervariasi dari langkah ke langkah.

  • Pengurangan dan penskalaan ini dilakukan secara otomatis dalam keras model.compile dan model.fit

  • Jika menggunakan kelastf.keras.losses (seperti pada contoh di bawah), pengurangan kerugian harus secara eksplisit ditentukan menjadi salah satu dari NONE atau SUM . AUTO dan SUM_OVER_BATCH_SIZE tidak diizinkan saat digunakan dengan tf.distribute.Strategy . AUTO tidak diizinkan karena pengguna harus secara eksplisit memikirkan pengurangan apa yang mereka inginkan untuk memastikannya benar dalam kasus terdistribusi. SUM_OVER_BATCH_SIZE tidak diizinkan karena saat ini hanya akan membagi dengan ukuran batch per replika, dan membiarkan pembagian dengan jumlah replika kepada pengguna, yang mungkin mudah dilewatkan. Jadi alih-alih kami meminta pengguna melakukan pengurangan sendiri secara eksplisit.

  • Jika labels multidimensi, rata-rata per_example_loss di seluruh jumlah elemen di setiap sampel. Misalnya, jika bentuk predictions adalah (batch_size, H, W, n_classes) dan labels adalah (batch_size, H, W) , Anda perlu memperbarui per_example_loss seperti: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

Tentukan metrik untuk melacak kehilangan dan akurasi

Metrik ini melacak kehilangan pengujian dan pelatihan serta akurasi pengujian. Anda dapat menggunakan .result() untuk mendapatkan akumulasi statistik kapan saja.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Lingkaran pelatihan

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5044084787368774, Accuracy: 81.87333679199219, Test Loss: 0.3816865086555481, Test Accuracy: 86.5999984741211
Epoch 2, Loss: 0.3375805616378784, Accuracy: 87.8566665649414, Test Loss: 0.3369813859462738, Test Accuracy: 87.76000213623047
Epoch 3, Loss: 0.2896445095539093, Accuracy: 89.50499725341797, Test Loss: 0.299490362405777, Test Accuracy: 89.22000122070312
Epoch 4, Loss: 0.259074866771698, Accuracy: 90.58833312988281, Test Loss: 0.2881558835506439, Test Accuracy: 89.33000183105469
Epoch 5, Loss: 0.2341146171092987, Accuracy: 91.38999938964844, Test Loss: 0.2916182577610016, Test Accuracy: 89.61000061035156
Epoch 6, Loss: 0.21513047814369202, Accuracy: 92.02333068847656, Test Loss: 0.2755740284919739, Test Accuracy: 89.85000610351562
Epoch 7, Loss: 0.1952667236328125, Accuracy: 92.88333129882812, Test Loss: 0.27464523911476135, Test Accuracy: 90.36000061035156
Epoch 8, Loss: 0.17831537127494812, Accuracy: 93.3566665649414, Test Loss: 0.26432710886001587, Test Accuracy: 90.19000244140625
Epoch 9, Loss: 0.16429665684700012, Accuracy: 93.85333251953125, Test Loss: 0.2659859359264374, Test Accuracy: 91.0999984741211
Epoch 10, Loss: 0.1503313183784485, Accuracy: 94.42166900634766, Test Loss: 0.2602477967739105, Test Accuracy: 91.06999969482422

Hal-hal yang perlu diperhatikan dalam contoh di atas:

Kembalikan pos pemeriksaan dan tes terbaru

Sebuah model checkpoint dengan tf.distribute.Strategy dapat dipulihkan dengan atau tanpa strategi.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.0999984741211

Cara alternatif untuk mengulangi kumpulan data

Menggunakan iterator

Jika Anda ingin mengulangi sejumlah langkah tertentu dan tidak melalui seluruh kumpulan data, Anda dapat membuat iterator menggunakan panggilan iter dan panggilan eksplisit next pada iterator. Anda dapat memilih untuk mengulangi dataset baik di dalam maupun di luar fungsi tf. Berikut adalah cuplikan kecil yang menunjukkan iterasi kumpulan data di luar tf.function menggunakan iterator.

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.14126229286193848, Accuracy: 95.0
Epoch 10, Loss: 0.1343936026096344, Accuracy: 95.0
Epoch 10, Loss: 0.12443388998508453, Accuracy: 94.84375
Epoch 10, Loss: 0.1607474684715271, Accuracy: 94.21875
Epoch 10, Loss: 0.10524413734674454, Accuracy: 96.71875
Epoch 10, Loss: 0.11492376029491425, Accuracy: 96.71875
Epoch 10, Loss: 0.16041627526283264, Accuracy: 94.21875
Epoch 10, Loss: 0.13022005558013916, Accuracy: 94.6875
Epoch 10, Loss: 0.17113295197486877, Accuracy: 93.28125
Epoch 10, Loss: 0.12315043061971664, Accuracy: 95.625

Iterasi di dalam tf.function

Anda juga dapat mengulangi seluruh input train_dist_dataset di dalam tf.function menggunakan konstruksi for x in ... atau dengan membuat iterator seperti yang kita lakukan di atas. Contoh di bawah ini menunjukkan membungkus satu epoch pelatihan dalam tf.function dan mengulangi train_dist_dataset di dalam fungsi.

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
Epoch 1, Loss: 0.13766956329345703, Accuracy: 94.89666748046875
Epoch 2, Loss: 0.12510614097118378, Accuracy: 95.35166931152344
Epoch 3, Loss: 0.11464647948741913, Accuracy: 95.70333099365234
Epoch 4, Loss: 0.10295023769140244, Accuracy: 96.12000274658203
Epoch 5, Loss: 0.09352775663137436, Accuracy: 96.49666595458984
Epoch 6, Loss: 0.08494547754526138, Accuracy: 96.87166595458984
Epoch 7, Loss: 0.07917638123035431, Accuracy: 97.09166717529297
Epoch 8, Loss: 0.07128290832042694, Accuracy: 97.37833404541016
Epoch 9, Loss: 0.06662175804376602, Accuracy: 97.47999572753906
Epoch 10, Loss: 0.06016768515110016, Accuracy: 97.82833099365234

Melacak kehilangan pelatihan di seluruh replika

Kami tidak menyarankan penggunaan tf.metrics.Mean untuk melacak kerugian pelatihan di berbagai replika, karena perhitungan penskalaan kerugian yang dilakukan.

Misalnya, jika Anda menjalankan pekerjaan pelatihan dengan karakteristik berikut:

  • Dua replika
  • Dua sampel diproses pada setiap replika
  • Nilai kerugian yang dihasilkan: [2, 3] dan [4, 5] pada setiap replika
  • Ukuran batch global = 4

Dengan penskalaan kerugian, Anda menghitung nilai kerugian per sampel pada setiap replika dengan menambahkan nilai kerugian, lalu membaginya dengan ukuran kumpulan global. Dalam hal ini: (2 + 3) / 4 = 1.25 dan (4 + 5) / 4 = 2.25 .

Jika Anda menggunakan tf.metrics.Mean untuk melacak kerugian di kedua replika, hasilnya akan berbeda. Dalam contoh ini, Anda mendapatkan total 3,50 dan count 2, yang menghasilkan total / count = 1,75 saat result() dipanggil pada metrik. Kerugian yang dihitung dengan tf.keras.Metrics diskalakan dengan faktor tambahan yang sama dengan jumlah replika yang disinkronkan.

Panduan dan contoh

Berikut adalah beberapa contoh untuk menggunakan strategi distribusi dengan loop pelatihan khusus:

  1. Panduan pelatihan terdistribusi
  2. Contoh DenseNet menggunakan MirroredStrategy .
  3. Bert contoh dilatih menggunakan MirroredStrategy dan TPUStrategy . Contoh ini sangat membantu untuk memahami cara memuat dari pos pemeriksaan dan menghasilkan pos pemeriksaan berkala selama pelatihan terdistribusi, dll.
  4. Contoh NCF dilatih menggunakan MirroredStrategy yang dapat diaktifkan menggunakan flag keras_use_ctl .
  5. Contoh NMT dilatih menggunakan MirroredStrategy .

Contoh lainnya tercantum dalam panduan strategi Distribusi .

Langkah selanjutnya