
Migrate the fault tolerance mechanism


Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in case your program or machine fails during training.
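
Under the hood, every approach in this guide relies on checkpoint files. As a minimal sketch of the basic mechanism (separate from the training APIs covered below), a tf.train.Checkpoint can save and restore the state of any trackable object:

import os
import tempfile

import tensorflow as tf

# Save and restore a trackable object (here, a single tf.Variable).
step = tf.Variable(0, name='step')
ckpt = tf.train.Checkpoint(step=step)

step.assign_add(5)
prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')
save_path = ckpt.save(prefix)  # Returns a path like '.../ckpt-1'.

step.assign(0)           # Simulate losing the in-memory state.
ckpt.restore(save_path)  # Recover the saved state from disk.
assert step.numpy() == 5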

This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by specifying checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:

  • If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.experimental.BackupAndRestore callback to it.
  • If you use a custom training loop, you can save checkpoints arbitrarily with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

Both methods will back up and restore the training state in checkpoint files.

Setup

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
# Load the MNIST dataset and scale the pixel values to the [0, 1] range.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can configure a tf.estimator to save checkpoints every step by configuring tf.estimator.RunConfig.

In this example, start by writing a hook that artificially throws an error during the fifth checkpoint:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Next, configure tf.estimator.Estimator to save every checkpoint and use the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpk5_4cfvv', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_13774/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_13774/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Begin training the model. The hook you defined earlier will raise an artificial exception.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:907: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 117.26719, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpk5_4cfvv/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:971: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption
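
Before rebuilding, you can optionally confirm which checkpoint the interrupted run left behind. A quick check, reusing the path model directory defined above:

# Look up the most recent checkpoint written by the interrupted run.
latest_ckpt = tf1.train.latest_checkpoint(path)
print(latest_ckpt)  # For example: '/tmp/.../model.ckpt-6'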

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpk5_4cfvv', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpk5_4cfvv/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1078: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 103.11247, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpk5_4cfvv/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 86.68358.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f9e820cff50>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.experimental.BackupAndRestore callback to add the fault tolerance functionality.

To help demonstrate this, first define a callback class that artificially throws an error at the end of the fifth epoch:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, log=None):
    if epoch == 4:
      raise RuntimeError('Interruption')

Next, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.experimental.BackupAndRestore callback that will save the checkpoints in a temporary directory:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir=log_dir
)

Now, start training the model with Model.fit. During training, checkpoints will be saved thanks to the backup_restore_callback defined above, while the InterruptingCallback will raise an artificial exception to simulate a failure.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2167 - accuracy: 0.9352 - val_loss: 0.0944 - val_accuracy: 0.9725
Epoch 2/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0965 - accuracy: 0.9703 - val_loss: 0.0823 - val_accuracy: 0.9735
Epoch 3/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0684 - accuracy: 0.9780 - val_loss: 0.0727 - val_accuracy: 0.9756
Epoch 4/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0542 - accuracy: 0.9829 - val_loss: 0.0676 - val_accuracy: 0.9790
Epoch 5/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0442 - accuracy: 0.9855 - val_loss: 0.0634 - val_accuracy: 0.9807
RuntimeError:Interruption
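
Because training was interrupted partway through, BackupAndRestore leaves its temporary checkpoint behind in log_dir; after Model.fit later runs to completion, the callback cleans it up. As an optional sanity check (the exact file layout inside the backup directory is an implementation detail):

import os

# The backup checkpoint files should still be present after the simulated failure.
print(os.listdir(log_dir))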

Next, instantiate the Keras model again, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0367 - accuracy: 0.9876 - val_loss: 0.0725 - val_accuracy: 0.9794
Epoch 7/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0313 - accuracy: 0.9894 - val_loss: 0.0787 - val_accuracy: 0.9779
Epoch 8/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0288 - accuracy: 0.9905 - val_loss: 0.0820 - val_accuracy: 0.9782
Epoch 9/10
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0252 - accuracy: 0.9914 - val_loss: 0.0639 - val_accuracy: 0.9830
Epoch 10/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0235 - accuracy: 0.9924 - val_loss: 0.0788 - val_accuracy: 0.9803
<keras.callbacks.History at 0x7f9e82536990>

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to:

  • Use a tf.train.Checkpoint object to manually create a checkpoint, where the trackable objects you want to save are set as attributes.
  • Use a tf.train.CheckpointManager to manage multiple checkpoints.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager for logging and keeping several checkpoints in a temporary directory.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, log_dir, max_to_keep=2)

Now, implement a custom training loop where, from the second epoch onward, the last checkpoint is restored at the start of each new epoch:

for epoch in range(epochs):
  # From the second epoch onward, restore the training state from the most
  # recent checkpoint before continuing.
  if epoch > 0:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    # Record the forward pass on the tape so gradients can be computed.
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-1
Training loss at step 0: 2.4602103233337402
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-2
Training loss at step 1: 2.4579155445098877
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-3
Training loss at step 2: 2.4571962356567383
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-4
Training loss at step 3: 2.456108570098877
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-5
Training loss at step 4: 2.4541022777557373

Start of epoch 1
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-6
Training loss at step 0: 2.4518723487854004
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-7
Training loss at step 1: 2.451997995376587
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-8
Training loss at step 2: 2.450746774673462
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-9
Training loss at step 3: 2.4489808082580566
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-10
Training loss at step 4: 2.4467883110046387

Start of epoch 2
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-11
Training loss at step 0: 2.445439100265503
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-12
Training loss at step 1: 2.442873477935791
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-13
Training loss at step 2: 2.443373680114746
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-14
Training loss at step 3: 2.4398140907287598
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-15
Training loss at step 4: 2.4389309883117676

Start of epoch 3
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-16
Training loss at step 0: 2.437243938446045
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-17
Training loss at step 1: 2.4370715618133545
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-18
Training loss at step 2: 2.435986042022705
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-19
Training loss at step 3: 2.4329538345336914
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-20
Training loss at step 4: 2.431180953979492

Start of epoch 4
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-21
Training loss at step 0: 2.4317142963409424
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-22
Training loss at step 1: 2.43074631690979
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-23
Training loss at step 2: 2.428147077560425
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-24
Training loss at step 3: 2.4258265495300293
Checkpoint saved to /tmp/tmp0q79i7fx/ckpt-25
Training loss at step 4: 2.4255685806274414
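
If the program fails partway through this loop, recovery works the same way as in the earlier sections: rebuild the trackable objects, point a fresh Checkpoint and CheckpointManager at the same directory, and restore the latest checkpoint before resuming. A minimal sketch, reusing log_dir from above:

# Rebuild the trackable objects after a (simulated) failure.
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, log_dir, max_to_keep=2)

# Restore the most recent checkpoint, if one exists, then rerun the loop above.
if checkpoint_manager.latest_checkpoint:
  checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f'Restored from {checkpoint_manager.latest_checkpoint}')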

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:

You may also find the following material related to distributed training useful: