Rejoignez TensorFlow à Google I/O, du 11 au 12 mai. Inscrivez-vous maintenant

Formation personnalisée avec tf.distribute.Strategy

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Ce didacticiel montre comment utiliser tf.distribute.Strategy avec des boucles de formation personnalisées. Nous formerons un modèle CNN simple sur le jeu de données fashion MNIST. L'ensemble de données de la mode MNIST contient 60000 images de train de taille 28 x 28 et 10000 images de test de taille 28 x 28.

Nous utilisons des boucles d'entraînement personnalisées pour entraîner notre modèle car elles nous donnent de la flexibilité et un plus grand contrôle sur l'entraînement. De plus, il est plus facile de déboguer le modèle et la boucle d'apprentissage.

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.8.0-rc1

Téléchargez le jeu de données MNIST sur la mode

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Créer une stratégie pour distribuer les variables et le graphique

Comment fonctionne la stratégie tf.distribute.MirroredStrategy ?

  • Toutes les variables et le graphe du modèle sont répliqués sur les répliques.
  • L'entrée est répartie uniformément entre les répliques.
  • Chaque réplique calcule la perte et les gradients pour l'entrée qu'elle a reçue.
  • Les dégradés sont synchronisés sur toutes les répliques en les additionnant.
  • Après la synchronisation, la même mise à jour est effectuée sur les copies des variables sur chaque réplique.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Configurer le pipeline d'entrée

Exportez le graphique et les variables au format SavedModel indépendant de la plate-forme. Une fois votre modèle enregistré, vous pouvez le charger avec ou sans la portée.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

Créez les jeux de données et distribuez-les :

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-01-26 05:45:53.991501: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 60000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

2022-01-26 05:45:54.034762: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:3"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

Créer le modèle

Créez un modèle à l'aide tf.keras.Sequential . Vous pouvez également utiliser l'API Model Subclassing pour ce faire.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Définir la fonction de perte

Normalement, sur une seule machine avec 1 GPU/CPU, la perte est divisée par le nombre d'exemples dans le lot d'entrée.

Alors, comment la perte doit-elle être calculée lors de l'utilisation d'un tf.distribute.Strategy ?

  • Par exemple, supposons que vous ayez 4 GPU et une taille de lot de 64. Un lot d'entrée est distribué sur les répliques (4 GPU), chaque réplique recevant une entrée de taille 16.

  • Le modèle de chaque réplique effectue un passage vers l'avant avec son entrée respective et calcule la perte. Maintenant, au lieu de diviser la perte par le nombre d'exemples dans son entrée respective (BATCH_SIZE_PER_REPLICA = 16), la perte doit être divisée par GLOBAL_BATCH_SIZE (64).

Pourquoi faire ceci?

  • Cela doit être fait car une fois les gradients calculés sur chaque réplique, ils sont synchronisés sur les répliques en les additionnant .

Comment faire cela dans TensorFlow ?

  • Si vous écrivez une boucle d'entraînement personnalisée, comme dans ce didacticiel, vous devez additionner les pertes par exemple et diviser la somme par GLOBAL_BATCH_SIZE : scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) ou vous pouvez utiliser tf.nn.compute_average_loss qui prend la perte par exemple, les poids d'échantillon facultatifs et GLOBAL_BATCH_SIZE comme arguments et renvoie la perte mise à l'échelle.

  • Si vous utilisez des pertes de régularisation dans votre modèle, vous devez mettre à l'échelle la valeur de perte en fonction du nombre de répliques. Vous pouvez le faire en utilisant la fonction tf.nn.scale_regularization_loss .

  • L'utilisation tf.reduce_mean n'est pas recommandée. Cela divise la perte par la taille réelle du lot de répliques, qui peut varier d'une étape à l'autre.

  • Cette réduction et cette mise à l'échelle se font automatiquement dans keras model.compile et model.fit

  • Si vous utilisez les classes tf.keras.losses (comme dans l'exemple ci-dessous), la réduction de perte doit être explicitement spécifiée comme étant NONE ou SUM . AUTO et SUM_OVER_BATCH_SIZE sont interdits lorsqu'ils sont utilisés avec tf.distribute.Strategy . AUTO n'est pas autorisé car l'utilisateur doit réfléchir explicitement à la réduction qu'il souhaite pour s'assurer qu'elle est correcte dans le cas distribué. SUM_OVER_BATCH_SIZE n'est pas autorisé car actuellement, il ne diviserait que par taille de lot de réplicas et laisserait la division par le nombre de réplicas à l'utilisateur, ce qui pourrait être facile à manquer. Donc, à la place, nous demandons à l'utilisateur de faire lui-même explicitement la réduction.

  • Si les labels sont multidimensionnelles, faites la moyenne de la per_example_loss par exemple sur le nombre d'éléments dans chaque échantillon. Par exemple, si la forme des predictions est (batch_size, H, W, n_classes) et labels sont (batch_size, H, W) , vous devrez mettre à jour per_example_loss comme : per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

Définir les métriques pour suivre la perte et la précision

Ces métriques suivent la perte de test et la précision de la formation et du test. Vous pouvez utiliser .result() pour obtenir les statistiques accumulées à tout moment.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Boucle d'entraînement

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5106383562088013, Accuracy: 81.77999877929688, Test Loss: 0.39399346709251404, Test Accuracy: 85.79000091552734
Epoch 2, Loss: 0.3362727463245392, Accuracy: 87.91333770751953, Test Loss: 0.35871225595474243, Test Accuracy: 86.7699966430664
Epoch 3, Loss: 0.2928692400455475, Accuracy: 89.2683334350586, Test Loss: 0.2999486029148102, Test Accuracy: 89.04000091552734
Epoch 4, Loss: 0.2605818510055542, Accuracy: 90.41999816894531, Test Loss: 0.28474125266075134, Test Accuracy: 89.47000122070312
Epoch 5, Loss: 0.23641237616539001, Accuracy: 91.32166290283203, Test Loss: 0.26421546936035156, Test Accuracy: 90.41000366210938
Epoch 6, Loss: 0.2192477434873581, Accuracy: 91.90499877929688, Test Loss: 0.2650589942932129, Test Accuracy: 90.4800033569336
Epoch 7, Loss: 0.20016911625862122, Accuracy: 92.66999816894531, Test Loss: 0.25025954842567444, Test Accuracy: 90.9000015258789
Epoch 8, Loss: 0.18381091952323914, Accuracy: 93.26499938964844, Test Loss: 0.2585820257663727, Test Accuracy: 90.95999908447266
Epoch 9, Loss: 0.1699329912662506, Accuracy: 93.67500305175781, Test Loss: 0.26234227418899536, Test Accuracy: 91.0199966430664
Epoch 10, Loss: 0.15756534039974213, Accuracy: 94.16333770751953, Test Loss: 0.25516414642333984, Test Accuracy: 90.93000030517578

Choses à noter dans l'exemple ci-dessus :

Restaurer le dernier point de contrôle et tester

Un modèle contrôlé avec un tf.distribute.Strategy peut être restauré avec ou sans stratégie.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.0199966430664

Autres façons d'itérer sur un ensemble de données

Utiliser des itérateurs

Si vous souhaitez itérer sur un nombre donné d'étapes et non sur l'intégralité de l'ensemble de données, vous pouvez créer un itérateur à l'aide de l'appel iter et appeler explicitement next sur l'itérateur. Vous pouvez choisir d'itérer sur l'ensemble de données à l'intérieur et à l'extérieur de la fonction tf. Voici un petit extrait démontrant l'itération de l'ensemble de données en dehors de la tf.function à l'aide d'un itérateur.

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.17486707866191864, Accuracy: 93.4375
Epoch 10, Loss: 0.12386945635080338, Accuracy: 95.3125
Epoch 10, Loss: 0.16411852836608887, Accuracy: 93.90625
Epoch 10, Loss: 0.10728752613067627, Accuracy: 96.40625
Epoch 10, Loss: 0.11865834891796112, Accuracy: 95.625
Epoch 10, Loss: 0.12875251471996307, Accuracy: 95.15625
Epoch 10, Loss: 0.1189488023519516, Accuracy: 95.625
Epoch 10, Loss: 0.1456708014011383, Accuracy: 95.15625
Epoch 10, Loss: 0.12446556240320206, Accuracy: 95.3125
Epoch 10, Loss: 0.1380888819694519, Accuracy: 95.46875

Itérer à l'intérieur d'un tf.function

Vous pouvez également itérer sur l'ensemble de l'entrée train_dist_dataset à l'intérieur d'une fonction tf en utilisant la construction for x in ... ou en créant des itérateurs comme nous l'avons fait ci-dessus. L'exemple ci-dessous illustre l'encapsulation d'une époque de formation dans un tf.function et l'itération sur train_dist_dataset à l'intérieur de la fonction.

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:449: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
Epoch 1, Loss: 0.14398494362831116, Accuracy: 94.63999938964844
Epoch 2, Loss: 0.13246288895606995, Accuracy: 94.97333526611328
Epoch 3, Loss: 0.11922841519117355, Accuracy: 95.63833618164062
Epoch 4, Loss: 0.11084160208702087, Accuracy: 95.99333190917969
Epoch 5, Loss: 0.10420522093772888, Accuracy: 96.0816650390625
Epoch 6, Loss: 0.09215126931667328, Accuracy: 96.63500213623047
Epoch 7, Loss: 0.0878651961684227, Accuracy: 96.67666625976562
Epoch 8, Loss: 0.07854588329792023, Accuracy: 97.09333038330078
Epoch 9, Loss: 0.07217177003622055, Accuracy: 97.34833526611328
Epoch 10, Loss: 0.06753655523061752, Accuracy: 97.48999786376953

Suivi de la perte d'entraînement entre les répliques

Nous vous déconseillons d'utiliser tf.metrics.Mean pour suivre la perte d'entraînement sur différents réplicas, en raison du calcul de mise à l'échelle des pertes qui est effectué.

Par exemple, si vous exécutez une tâche d'entraînement avec les caractéristiques suivantes :

  • Deux répliques
  • Deux échantillons sont traités sur chaque réplique
  • Valeurs de perte résultantes : [2, 3] et [4, 5] sur chaque réplique
  • Taille de lot globale = 4

Avec la mise à l'échelle des pertes, vous calculez la valeur de perte par échantillon sur chaque réplique en ajoutant les valeurs de perte, puis en divisant par la taille de lot globale. Dans ce cas : (2 + 3) / 4 = 1.25 et (4 + 5) / 4 = 2.25 .

Si vous utilisez tf.metrics.Mean pour suivre la perte sur les deux réplicas, le résultat est différent. Dans cet exemple, vous vous retrouvez avec un total de 3,50 et un count de 2, ce qui donne total / count = 1,75 lorsque result() est appelé sur la métrique. La perte calculée avec tf.keras.Metrics est mise à l'échelle par un facteur supplémentaire égal au nombre de répliques synchronisées.

Guide et exemples

Voici quelques exemples d'utilisation d'une stratégie de distribution avec des boucles d'entraînement personnalisées :

  1. Guide de formation distribué
  2. Exemple DenseNet utilisant MirroredStrategy .
  3. Exemple BERT formé à l'aide de MirroredStrategy et TPUStrategy . Cet exemple est particulièrement utile pour comprendre comment charger à partir d'un point de contrôle et générer des points de contrôle périodiques pendant la formation distribuée, etc.
  4. Exemple NCF formé à l'aide de MirroredStrategy qui peut être activé à l'aide de l'indicateur keras_use_ctl .
  5. Exemple NMT formé à l'aide de MirroredStrategy .

Plus d'exemples répertoriés dans le Guide de stratégie de distribution .

Prochaines étapes

  • Essayez la nouvelle API tf.distribute.Strategy sur vos modèles.
  • Consultez la section Performances du guide pour en savoir plus sur les autres stratégies et outils que vous pouvez utiliser pour optimiser les performances de vos modèles TensorFlow.