tf.distribute.experimental.PreemptionCheckpointHandler

Preemption and error handler for synchronous training.

A PreemptionCheckpointHandler coordinates all workers to save a checkpoint upon receiving a preemption signal. It also helps disseminate application error messages accurately across the cluster. When a PreemptionCheckpointHandler object is created, it restores values from the latest checkpoint file if any exists.

Right after initialization, a thread starts watching for a termination signal for any member in the cluster. If a signal is received, the next time the worker enters a PreemptionCheckpointHandler.run call, the PreemptionCheckpointHandler will align the workers' steps to save a checkpoint and possibly exit -- depending on the exit_fn in tf.distribute.experimental.TerminationConfig.

Example usage:

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  dataset, model, optimizer = ...

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_directory)

  # preemption_handler.total_run_calls will be restored to its saved value if
  # training is restored after interruption.
  for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH, num_epochs):
    iterator = iter(dataset)
    for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH, STEPS_PER_EPOCH):
      # distributed_train_step is a single-step training function wrapped by tf.distribute.Strategy.run.
      loss += preemption_handler.run(distributed_train_step, next(iterator))

Not all interruptions come with advance notice that the PreemptionCheckpointHandler can act on, e.g., those caused by hardware failure. For users who save checkpoints for these cases themselves outside the PreemptionCheckpointHandler: if you are using a tf.train.CheckpointManager, pass it as the checkpoint_or_checkpoint_manager argument to the PreemptionCheckpointHandler. If you do not have a tf.train.CheckpointManager but work directly with tf.train.Checkpoint, we advise saving the checkpoints in the directory passed as the checkpoint_dir argument. This way, at program start, the PreemptionCheckpointHandler can restore the latest checkpoint from that directory, whether it was saved by the user or by the PreemptionCheckpointHandler before the preemption.
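
For example, the tf.train.CheckpointManager variant could look like the following sketch (the directory name, the max_to_keep value, and the variable names are illustrative choices, not requirements of the API):

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  dataset, model, optimizer = ...

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  # The same manager serves the user's own periodic saves and the
  # preemption-triggered saves, so both land in checkpoint_directory.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=checkpoint_directory, max_to_keep=1)

  # checkpoint_dir is not needed when a CheckpointManager is passed.
  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(
      cluster_resolver, checkpoint_manager)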

If a user cannot infer the start epoch and start step from PreemptionCheckpointHandler.total_run_calls (e.g., if STEPS_PER_EPOCH is not known in advance or varies from epoch to epoch), we recommend tracking the epoch and step numbers directly and saving them in the passed-in checkpoint:

strategy = tf.distribute.MultiWorkerMirroredStrategy()

trained_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='step_in_epoch')

with strategy.scope():
  dataset, model, optimizer = ...

  checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                   model=model,
                                   trained_epoch=trained_epoch,
                                   step_in_epoch=step_in_epoch)

  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_dir)

while trained_epoch.numpy() < NUM_EPOCH:

  # A fresh iterator for each epoch.
  iterator = iter(dataset)

  while step_in_epoch.numpy() < STEPS_PER_EPOCH:

    loss += preemption_handler.run(train_step, next(iterator))
    step_in_epoch.assign_add(1)
    ...

  trained_epoch.assign_add(1)
  step_in_epoch.assign(0)

A note on the platform:

PreemptionCheckpointHandler can only handle terminations that come with advance notice. For now, the API recognizes Google Borg and Google Cloud Platform, where it automatically adopts the correct preemption/maintenance notification detection mechanism. Users of other platforms can configure it through a tf.distribute.experimental.TerminationConfig, which also allows customizing the exit behavior and the grace period length.
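
As a sketch of a custom platform configuration (the watcher and exit functions below are hypothetical placeholders for platform-specific logic, the grace period value is arbitrary, and the field names follow the tf.distribute.experimental.TerminationConfig documentation):

import sys

def termination_watcher_fn():
  # Hypothetical helper: return True when the platform reports an upcoming
  # preemption or maintenance event (e.g., by polling a metadata endpoint).
  return platform_reports_preemption()

def exit_fn():
  # Hypothetical exit: use an exit code that tells the cluster manager to
  # restart this task rather than fail the whole job.
  sys.exit(42)

termination_config = tf.distribute.experimental.TerminationConfig(
    termination_watcher_fn=termination_watcher_fn,
    exit_fn=exit_fn,
    grace_period=30)  # seconds between the notice and the actual termination

preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    cluster_resolver, checkpoint, checkpoint_dir,
    termination_config=termination_config)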

Args
cluster_resolver a tf.distribute.cluster_resolver.ClusterResolver object. You may also obtain it through the cluster_resolver attribute of the distribution strategy in use.
checkpoint_or_checkpoint_manager a tf.train.CheckpointManager or a tf.train.Checkpoint. If you are also using a tf.train.CheckpointManager to manage checkpoints outside the PreemptionCheckpointHandler for backup purposes, pass it as the checkpoint_or_checkpoint_manager argument. Otherwise, pass a tf.train.Checkpoint and the PreemptionCheckpointHandler will create a tf.train.CheckpointManager to manage it in checkpoint_dir.
checkpoint_dir a directory where the PreemptionCheckpointHandler saves and restores checkpoints. When a PreemptionCheckpointHandler is created, the latest checkpoint in checkpoint_dir is restored. (This argument is not needed if a tf.train.CheckpointManager, instead of a tf.train.Checkpoint, is passed as the checkpoint_or_checkpoint_manager argument.)
termination_config optional, a tf.distribute.experimental.TerminationConfig object to configure behavior on a platform other than Google Borg or GCP.

Attributes

total_run_calls Returns the number of times PreemptionCheckpointHandler.run is called.

This value tracks all calls to PreemptionCheckpointHandler.run, including those made before the program was restarted and the training was restored, by saving and reading the value from the checkpoint. A user can compute their total number of iterations as PreemptionCheckpointHandler.total_run_calls * number_of_steps_in_train_function, where number_of_steps_in_train_function should be one for tf.distribute.MultiWorkerMirroredStrategy users. They can also use this value to infer the starting epoch and step after training restores, as shown in the example above.
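
For instance, assuming a single-step train function and a fixed STEPS_PER_EPOCH (both assumptions, not requirements of the API), the restored value can be turned back into a starting position along these lines:

# One run call equals one training step for a single-step train function.
completed_steps = preemption_handler.total_run_calls
start_epoch = completed_steps // STEPS_PER_EPOCH
start_step_in_epoch = completed_steps % STEPS_PER_EPOCH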

Methods

run

View source

Runs a training function with error and preemption handling.

This function handles the preemption signal from any peer in the cluster by saving the training progress and exiting gracefully. It also broadcasts any program error encountered during the execution of distributed_train_function to all workers so that they can raise the same error.

The distributed_train_function argument should be a distributed train function (i.e., containing a call to tf.distribute.Strategy.run). For tf.distribute.MultiWorkerMirroredStrategy users, we recommend passing in a single-step distributed_train_function to PreemptionCheckpointHandler.run so that the checkpoint can be saved in time in case a preemption signal or maintenance notice is sent.

Besides the preemption and error handling part, PreemptionCheckpointHandler.run(distributed_train_function, *args, **kwargs) has the same effect and output as distributed_train_function(*args, **kwargs). distributed_train_function may or may not return a result. The following is a shortened example:


@tf.function
def distributed_train_step(iterator):
  # A distributed single-step training function.

  def step_fn(inputs):
    # A per-replica single-step training function.
    x, y = inputs
    ...
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                   EPOCHS_TO_RUN):
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                    STEPS_PER_EPOCH):
    total_loss += preemption_handler.run(distributed_train_step, iterator)
    num_batches += 1

  train_loss = total_loss / num_batches
  print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))

  train_accuracy.reset_states()

Args
distributed_train_function A (single-step) distributed training function.
*args args for distributed_train_function.
**kwargs kwargs for distributed_train_function.

Raises
Program error encountered by any member in the cluster while executing the distributed_train_function, or any error from the program error propagation process.

Returns
Result of running the distributed_train_function.