Niestandardowe szkolenie z tf.distribute.Strategy

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ten samouczek pokazuje, jak używać tf.distribute.Strategy z niestandardowymi pętlami treningowymi. Wytrenujemy prosty model CNN na zbiorze danych mody MNIST. Zbiór danych MNIST mody zawiera 60000 obrazów pociągów o rozmiarze 28 x 28 i 10000 obrazów testowych o rozmiarze 28 x 28.

Używamy niestandardowych pętli treningowych do trenowania naszego modelu, ponieważ zapewniają nam elastyczność i większą kontrolę nad treningiem. Co więcej, łatwiej jest debugować model i pętlę treningową.

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.8.0-rc1

Pobierz zestaw danych mody MNIST

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Stwórz strategię dystrybucji zmiennych i wykresu

Jak działa strategia tf.distribute.MirroredStrategy ?

  • Wszystkie zmienne i wykres modelu są replikowane na replikach.
  • Dane wejściowe są równomiernie rozłożone w replikach.
  • Każda replika oblicza straty i gradienty otrzymanego sygnału wejściowego.
  • Gradienty są synchronizowane we wszystkich replikach poprzez ich sumowanie.
  • Po synchronizacji ta sama aktualizacja jest wykonywana na kopiach zmiennych w każdej replice.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Konfiguracja potoku wejściowego

Wyeksportuj wykres i zmienne do formatu SavedModel niezależnego od platformy. Po zapisaniu modelu możesz go wczytać z zakresem lub bez niego.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

Twórz zbiory danych i rozpowszechniaj je:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-01-26 05:45:53.991501: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 60000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

2022-01-26 05:45:54.034762: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:3"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

Stwórz model

Utwórz model za pomocą tf.keras.Sequential . W tym celu można również użyć interfejsu API Model Subclassing.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Zdefiniuj funkcję straty

Zwykle na pojedynczej maszynie z 1 GPU/CPU straty są dzielone przez liczbę przykładów w partii danych wejściowych.

Jak więc obliczyć stratę przy użyciu tf.distribute.Strategy ?

  • Załóżmy na przykład, że masz 4 GPU i wielkość partii 64. Jedna partia danych wejściowych jest rozdzielana na repliki (4 GPU), każda replika otrzymuje dane wejściowe o rozmiarze 16.

  • Model na każdej replice wykonuje przejście do przodu z odpowiednimi danymi wejściowymi i oblicza stratę. Teraz zamiast dzielić stratę przez liczbę przykładów w odpowiednich danych wejściowych (BATCH_SIZE_PER_REPLICA = 16), stratę należy podzielić przez GLOBAL_BATCH_SIZE (64).

Czemu to robić?

  • Należy to zrobić, ponieważ po obliczeniu gradientów dla każdej repliki są one synchronizowane w replikach poprzez ich zsumowanie .

Jak to zrobić w TensorFlow?

  • Jeśli piszesz niestandardową pętlę treningową, tak jak w tym samouczku, powinieneś zsumować straty na przykład i podzielić sumę przez GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) lub możesz użyć tf.nn.compute_average_loss , która przyjmuje jako argumenty straty na przykład, opcjonalne wagi próbek i GLOBAL_BATCH_SIZE i zwraca skalowaną stratę.

  • Jeśli używasz w swoim modelu strat regularyzacji, musisz przeskalować wartość strat według liczby replik. Możesz to zrobić za pomocą funkcji tf.nn.scale_regularization_loss .

  • Używanie tf.reduce_mean nie jest zalecane. W ten sposób dzieli się stratę przez rzeczywistą wielkość partii na replikę, która może się różnić w zależności od kroku.

  • Ta redukcja i skalowanie odbywa się automatycznie w model.compile i model.fit

  • Jeśli używasz klas tf.keras.losses (jak w poniższym przykładzie), redukcja strat musi być wyraźnie określona jako jedna z NONE lub SUM . AUTO i SUM_OVER_BATCH_SIZE są niedozwolone, gdy są używane z tf.distribute.Strategy . AUTO jest niedozwolone, ponieważ użytkownik powinien wyraźnie zastanowić się, jaką redukcję chce się upewnić, że jest ona poprawna w przypadku rozproszonym. SUM_OVER_BATCH_SIZE jest niedozwolona, ​​ponieważ obecnie dzieliłaby tylko według rozmiaru partii replik, a dzielenie według liczby replik pozostawiłoby użytkownikowi, co może być łatwe do przeoczenia. Więc zamiast tego prosimy użytkownika, aby sam dokonał redukcji.

  • Jeśli labels są wielowymiarowe, per_example_loss dla liczby elementów w każdej próbce. Na przykład, jeśli kształt predictions to (batch_size, H, W, n_classes) a labels to (batch_size, H, W) , będziesz musiał zaktualizować per_example_loss , np. per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

Zdefiniuj metryki, aby śledzić straty i dokładność

Te metryki śledzą utratę testów, szkolenia i dokładność testów. Możesz użyć .result() , aby uzyskać skumulowane statystyki w dowolnym momencie.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Pętla treningowa

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5106383562088013, Accuracy: 81.77999877929688, Test Loss: 0.39399346709251404, Test Accuracy: 85.79000091552734
Epoch 2, Loss: 0.3362727463245392, Accuracy: 87.91333770751953, Test Loss: 0.35871225595474243, Test Accuracy: 86.7699966430664
Epoch 3, Loss: 0.2928692400455475, Accuracy: 89.2683334350586, Test Loss: 0.2999486029148102, Test Accuracy: 89.04000091552734
Epoch 4, Loss: 0.2605818510055542, Accuracy: 90.41999816894531, Test Loss: 0.28474125266075134, Test Accuracy: 89.47000122070312
Epoch 5, Loss: 0.23641237616539001, Accuracy: 91.32166290283203, Test Loss: 0.26421546936035156, Test Accuracy: 90.41000366210938
Epoch 6, Loss: 0.2192477434873581, Accuracy: 91.90499877929688, Test Loss: 0.2650589942932129, Test Accuracy: 90.4800033569336
Epoch 7, Loss: 0.20016911625862122, Accuracy: 92.66999816894531, Test Loss: 0.25025954842567444, Test Accuracy: 90.9000015258789
Epoch 8, Loss: 0.18381091952323914, Accuracy: 93.26499938964844, Test Loss: 0.2585820257663727, Test Accuracy: 90.95999908447266
Epoch 9, Loss: 0.1699329912662506, Accuracy: 93.67500305175781, Test Loss: 0.26234227418899536, Test Accuracy: 91.0199966430664
Epoch 10, Loss: 0.15756534039974213, Accuracy: 94.16333770751953, Test Loss: 0.25516414642333984, Test Accuracy: 90.93000030517578

Rzeczy do odnotowania w powyższym przykładzie:

Przywróć najnowszy punkt kontrolny i przetestuj

Model z punktem kontrolnym tf.distribute.Strategy można przywrócić ze strategią lub bez niej.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.0199966430664

Alternatywne sposoby iteracji po zbiorze danych

Korzystanie z iteratorów

Jeśli chcesz iterować przez określoną liczbę kroków, a nie przez cały zestaw danych, możesz utworzyć iterator za pomocą wywołania iter i jawnego wywołania next iteratora. Możesz wybrać iterację po zbiorze danych zarówno wewnątrz, jak i na zewnątrz funkcji tf.function. Oto mały fragment pokazujący iterację zbioru danych poza funkcją tf.function za pomocą iteratora.

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.17486707866191864, Accuracy: 93.4375
Epoch 10, Loss: 0.12386945635080338, Accuracy: 95.3125
Epoch 10, Loss: 0.16411852836608887, Accuracy: 93.90625
Epoch 10, Loss: 0.10728752613067627, Accuracy: 96.40625
Epoch 10, Loss: 0.11865834891796112, Accuracy: 95.625
Epoch 10, Loss: 0.12875251471996307, Accuracy: 95.15625
Epoch 10, Loss: 0.1189488023519516, Accuracy: 95.625
Epoch 10, Loss: 0.1456708014011383, Accuracy: 95.15625
Epoch 10, Loss: 0.12446556240320206, Accuracy: 95.3125
Epoch 10, Loss: 0.1380888819694519, Accuracy: 95.46875

Iteracja wewnątrz funkcji tf.

Możesz także iterować po całym wejściowym zestawie danych train_dist_dataset wewnątrz funkcji tf., używając konstrukcji for x in ... lub tworząc iteratory, tak jak to zrobiliśmy powyżej. Poniższy przykład ilustruje zawijanie jednej epoki szkolenia w tf.function i iterację po train_dist_dataset wewnątrz funkcji.

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:449: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
Epoch 1, Loss: 0.14398494362831116, Accuracy: 94.63999938964844
Epoch 2, Loss: 0.13246288895606995, Accuracy: 94.97333526611328
Epoch 3, Loss: 0.11922841519117355, Accuracy: 95.63833618164062
Epoch 4, Loss: 0.11084160208702087, Accuracy: 95.99333190917969
Epoch 5, Loss: 0.10420522093772888, Accuracy: 96.0816650390625
Epoch 6, Loss: 0.09215126931667328, Accuracy: 96.63500213623047
Epoch 7, Loss: 0.0878651961684227, Accuracy: 96.67666625976562
Epoch 8, Loss: 0.07854588329792023, Accuracy: 97.09333038330078
Epoch 9, Loss: 0.07217177003622055, Accuracy: 97.34833526611328
Epoch 10, Loss: 0.06753655523061752, Accuracy: 97.48999786376953

Śledzenie utraty treningu w replikach

Nie zalecamy używania tf.metrics.Mean do śledzenia utraty treningu w różnych replikach ze względu na przeprowadzane obliczenia skalowania strat.

Na przykład, jeśli prowadzisz zadanie szkoleniowe o następujących cechach:

  • Dwie repliki
  • Na każdej replice przetwarzane są dwie próbki
  • Wynikowe wartości strat: [2, 3] i [4, 5] na każdej replice
  • Globalna wielkość partii = 4

Dzięki skalowaniu strat można obliczyć wartość strat na próbkę w każdej replice, dodając wartości strat, a następnie dzieląc je przez globalny rozmiar partii. W tym przypadku: (2 + 3) / 4 = 1.25 i (4 + 5) / 4 = 2.25 .

Jeśli użyjesz tf.metrics.Mean do śledzenia strat w dwóch replikach, wynik będzie inny. W tym przykładzie otrzymujesz total 3,50 i count 2, co daje total / count = 1,75, gdy result() jest wywoływany w metryce. Strata obliczona za pomocą tf.keras.Metrics jest skalowana przez dodatkowy czynnik, który jest równy liczbie zsynchronizowanych replik.

Przewodnik i przykłady

Oto kilka przykładów wykorzystania strategii dystrybucji z niestandardowymi pętlami treningowymi:

  1. Rozproszony przewodnik szkoleniowy
  2. Przykład DenseNet przy użyciu MirroredStrategy .
  3. Przykład BERT przeszkolony przy użyciu MirroredStrategy i TPUStrategy . Ten przykład jest szczególnie pomocny w zrozumieniu, jak ładować z punktu kontrolnego i generować okresowe punkty kontrolne podczas szkolenia rozproszonego itp.
  4. Przykład NCF wyszkolony przy użyciu MirroredStrategy , który można włączyć za pomocą flagi keras_use_ctl .
  5. Przykład NMT przeszkolony przy użyciu MirroredStrategy .

Więcej przykładów można znaleźć w przewodniku po strategii dystrybucji .

Następne kroki