Built-in model blowup recovery module.
tfm.core.base_trainer.Recovery(
loss_upper_bound: float,
checkpoint_manager: tf.train.CheckpointManager,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3
)
Checks the loss value against the given threshold. If applicable, recovers the model by restoring it from the checkpoint on disk.
Methods
maybe_recover
maybe_recover(
loss_value, global_step
)
Conditionally recovers the training by triggering checkpoint restoration.
Args | |
---|---|
`loss_value` | The loss value as a float. |
`global_step` | The number of global training steps. |
Raises | |
---|---|
`RuntimeError` | Raised when recovery has happened more than the maximum number of trials; at that point the job should crash. |
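The recovery flow described above can be sketched in plain Python. This is an illustrative stand-in, not the library's implementation: `RecoverySketch` and `restore_fn` are hypothetical names, with `restore_fn` taking the place of restoration through the `checkpoint_manager`. The sketch also assumes that the threshold only applies once `global_step` reaches `recovery_begin_steps` and that a NaN loss counts as a blowup; this page states neither.

```python
import math
from typing import Callable


class RecoverySketch:
    """Hypothetical, TensorFlow-free illustration of the documented behavior."""

    def __init__(self,
                 loss_upper_bound: float,
                 restore_fn: Callable[[], None],
                 recovery_begin_steps: int = 0,
                 recovery_max_trials: int = 3):
        self.loss_upper_bound = loss_upper_bound
        # Assumption: stands in for restoring from a checkpoint manager.
        self.restore_fn = restore_fn
        self.recovery_begin_steps = recovery_begin_steps
        self.recovery_max_trials = recovery_max_trials
        self.recover_counter = 0

    def should_recover(self, loss_value: float, global_step: int) -> bool:
        # Assumption: a NaN loss is treated as a blowup at any step.
        if math.isnan(loss_value):
            return True
        # Assumption: the threshold check starts at recovery_begin_steps.
        return (global_step >= self.recovery_begin_steps
                and loss_value > self.loss_upper_bound)

    def maybe_recover(self, loss_value: float, global_step: int) -> bool:
        """Restores if needed; returns True when a restore was triggered."""
        if not self.should_recover(loss_value, global_step):
            return False
        self.recover_counter += 1
        if self.recover_counter > self.recovery_max_trials:
            # Too many recoveries: give up and let the job crash.
            raise RuntimeError(
                f"Recovery attempted {self.recover_counter} times, "
                f"exceeding the limit of {self.recovery_max_trials}.")
        self.restore_fn()
        return True
```

In a training loop, a call such as `maybe_recover(loss, step)` would typically run after each step or group of steps, so that a diverged run rolls back to the last saved checkpoint instead of continuing from bad weights.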
should_recover
should_recover(
loss_value, global_step
)