Migrate the fault tolerance mechanism

Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. It enables you to recover them in the event of a program or machine failure during training.

This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by specifying checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:

  • If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.BackupAndRestore callback to it.
  • If you use a custom training loop, you can save checkpoints arbitrarily with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

Both of these methods save and restore the training states in checkpoint files.

Setup

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can configure a tf.estimator to save checkpoints at every step by configuring tf.estimator.RunConfig.

In this example, start by writing a hook that artificially throws an error during the fifth checkpoint:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Next, configure tf.estimator.Estimator to save every checkpoint and use the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Begin training the model. An artificial exception will be raised by the hook you defined earlier.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 118.92192, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption
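
Optionally, before rebuilding the estimator, you can check which checkpoint survived the interruption and peek at the variables it contains. This small verification sketch is not part of the original workflow; the variable name latest is introduced here for illustration, while tf.train.latest_checkpoint and tf.train.list_variables are standard TensorFlow utilities.

# Locate the most recent checkpoint prefix written before the interruption.
latest = tf.train.latest_checkpoint(path)
print(latest)

# Show a few (variable name, shape) pairs stored in that checkpoint.
print(tf.train.list_variables(latest)[:5])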

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)
classifier.train(input_fn=train_input_fn,
                   max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 105.44863, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 100.47882.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.BackupAndRestore callback to add the fault tolerance functionality.

To help demonstrate this, first start by defining a Keras Callback class that artificially throws an error during the fifth epoch checkpoint:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, log=None):
    if epoch == 4:
      raise RuntimeError('Interruption')

Next, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that will save the checkpoints in a temporary directory:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir = log_dir
)

Now, begin training the model with Model.fit. During training, checkpoints will be saved thanks to the backup_restore_callback defined above, while the InterruptingCallback will raise an artificial exception to simulate a failure.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615
Epoch 2/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718
Epoch 3/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797
Epoch 4/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption

Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791
Epoch 7/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827
Epoch 8/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819
Epoch 9/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812
Epoch 10/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813
<keras.callbacks.History at 0x7fcfe0647350>
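
As an optional check (a sketch that is not part of the original notebook), you can list the contents of the backup directory after the run completes. The BackupAndRestore callback keeps its checkpoint only for recovery purposes and is documented to remove it once training finishes successfully, so the directory is expected to be empty, or nearly so, depending on the TensorFlow version.

import os

# After a successful, uninterrupted run, BackupAndRestore is expected to have
# cleaned up its temporary checkpoint files in the backup directory.
print(os.listdir(log_dir))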

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to:

  • Use a tf.train.Checkpoint object to manually create checkpoints, where the trackable objects you want to save are set as attributes.
  • Use a tf.train.CheckpointManager to manage multiple checkpoints.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager for logging and keeping several checkpoints in a temporary directory.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)

Now, implement a custom training loop where, after the first epoch, the latest checkpoint is restored every time a new epoch starts:

for epoch in range(epochs):
  # After the first epoch, restore the model and optimizer state from the
  # most recently saved checkpoint before continuing training.
  if epoch > 0:
    checkpoint.restore(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    # Compute and apply the gradients outside the tape context.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1
Training loss at step 0: 2.3636362552642822
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2
Training loss at step 1: 2.3626415729522705
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3
Training loss at step 2: 2.3613197803497314
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4
Training loss at step 3: 2.360600233078003
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5
Training loss at step 4: 2.3589422702789307

Start of epoch 1
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6
Training loss at step 0: 2.3563339710235596
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7
Training loss at step 1: 2.3568854331970215
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8
Training loss at step 2: 2.354109287261963
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9
Training loss at step 3: 2.3532731533050537
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10
Training loss at step 4: 2.351112127304077

Start of epoch 2
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11
Training loss at step 0: 2.348905563354492
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12
Training loss at step 1: 2.349478006362915
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13
Training loss at step 2: 2.3487260341644287
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14
Training loss at step 3: 2.345991611480713
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15
Training loss at step 4: 2.3451104164123535

Start of epoch 3
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16
Training loss at step 0: 2.3441312313079834
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17
Training loss at step 1: 2.341529130935669
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18
Training loss at step 2: 2.342329263687134
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19
Training loss at step 3: 2.340449571609497
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20
Training loss at step 4: 2.3367927074432373

Start of epoch 4
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21
Training loss at step 0: 2.3366076946258545
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22
Training loss at step 1: 2.335028886795044
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23
Training loss at step 2: 2.3338520526885986
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24
Training loss at step 3: 2.3345272541046143
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25
Training loss at step 4: 2.332385301589966
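
The loop above reloads its own checkpoint at the start of each epoch, but the same objects also support recovery after a real failure. Below is a minimal sketch, not part of the original notebook, of how a fresh process could rebuild the model and optimizer and resume from the most recent checkpoint kept by the CheckpointManager; the restored_* names are introduced here purely for illustration.

restored_model = create_model()
restored_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

restored_checkpoint = tf.train.Checkpoint(model=restored_model,
                                          optimizer=restored_optimizer)
restored_manager = tf.train.CheckpointManager(
    restored_checkpoint, log_dir, max_to_keep=2)

if restored_manager.latest_checkpoint:
  # Restore the model and optimizer state from the newest checkpoint on disk.
  restored_checkpoint.restore(restored_manager.latest_checkpoint)
  print(f"Restored from {restored_manager.latest_checkpoint}")
else:
  print("No checkpoint found; training would start from scratch.")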

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, check out the following documentation:

You may also find the following material related to distributed training useful: