tf.keras.callbacks.ModelCheckpoint

Class ModelCheckpoint

Save the model after every epoch, or at the frequency set by save_freq.

Inherits From: Callback

Aliases:

  • Class tf.compat.v1.keras.callbacks.ModelCheckpoint
  • Class tf.compat.v2.keras.callbacks.ModelCheckpoint
  • Class tf.keras.callbacks.ModelCheckpoint

Defined in python/keras/callbacks.py.

filepath can contain named formatting options, which will be filled with the value of epoch and the keys in logs (passed in on_epoch_end).

For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}.hdf5, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
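The placeholders follow Python's str.format syntax, so the substitution can be sketched as below. format_checkpoint_path is a hypothetical helper, not part of tf.keras. The sketch assumes the epoch value is passed in exactly as it should appear in the filename; the real callback may offset its internal 0-based epoch index.

```python
def format_checkpoint_path(filepath, epoch, logs):
    """Fill {epoch} and any {<log key>} placeholders in filepath.

    Hypothetical helper illustrating the documented behaviour; it is not
    the tf.keras implementation.
    """
    # str.format resolves each named field by keyword and applies the
    # format spec after the colon (e.g. 02d, .2f).
    return filepath.format(epoch=epoch, **(logs or {}))


logs = {"loss": 0.5123, "val_loss": 0.4567}
print(format_checkpoint_path("weights.{epoch:02d}-{val_loss:.2f}.hdf5", 3, logs))
# weights.03-0.46.hdf5
```

With plain str.format, a placeholder whose name is missing from logs raises KeyError, so only reference keys that logs actually contains at the end of an epoch (for example, val_loss is present only when validation data is supplied).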

Arguments:

  • filepath: string, path to save the model file.
  • monitor: quantity to monitor, given as a key in logs (for example, val_loss or val_acc).
  • verbose: verbosity mode, 0 or 1.
  • save_best_only: if save_best_only=True, a checkpoint is written only when the monitored quantity improves, so the latest best model according to that quantity is never overwritten by a worse one.
  • mode: one of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max; for val_loss it should be min, and so on. In auto mode, the direction is inferred automatically from the name of the monitored quantity.
  • save_weights_only: if True, then only the model's weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
  • save_freq: 'epoch' or integer. When using 'epoch', the callback saves the model after each epoch. When using an integer, the callback saves the model at the end of the batch at which this many samples have been seen since the last save. Note that if saving isn't aligned to epochs, the monitored metric may be less reliable (it could reflect as little as 1 batch, since the metrics get reset every epoch). Defaults to 'epoch'.
  • load_weights_on_restart: Whether training should restore the model. If True, the model will attempt to load the checkpoint file from filepath at the start of model.fit(). This removes the need to call model.load_weights() manually before model.fit(). In multi-worker distributed training, this provides fault tolerance and loads the model automatically when workers recover. The callback skips loading if filepath does not exist, and raises ValueError if the format does not match. Defaults to False.
  • **kwargs: Additional arguments for backwards compatibility. The recognized key is period, the number of epochs between checkpoints, which is deprecated in favor of save_freq.
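The interaction between save_best_only and mode can be modelled roughly as follows. This is a simplified sketch, not the tf.keras source: infer_mode and BestCheckpointDecider are hypothetical names, and the auto rule used here (maximize anything whose name contains acc, minimize everything else) is an assumption that matches the val_acc and val_loss examples above but may be less complete than the real heuristic.

```python
import math


def infer_mode(monitor, mode="auto"):
    """Return 'min' or 'max' for the monitored quantity.

    Simplified assumption: in 'auto' mode (or with an unrecognised mode),
    names containing 'acc' are maximised and everything else is minimised.
    """
    if mode in ("min", "max"):
        return mode
    return "max" if "acc" in monitor else "min"


class BestCheckpointDecider:
    """Hypothetical model of the save_best_only / mode decision."""

    def __init__(self, monitor="val_loss", mode="auto", save_best_only=False):
        self.monitor = monitor
        self.mode = infer_mode(monitor, mode)
        self.save_best_only = save_best_only
        # Start from the worst possible value so the first result counts as best.
        self.best = math.inf if self.mode == "min" else -math.inf

    def should_save(self, logs):
        if not self.save_best_only:
            return True  # Save unconditionally.
        current = logs.get(self.monitor)
        if current is None:
            return False  # Monitored quantity missing: skip this save.
        if self.mode == "min":
            improved = current < self.best
        else:
            improved = current > self.best
        if improved:
            self.best = current
        return improved


decider = BestCheckpointDecider("val_loss", save_best_only=True)
print([decider.should_save({"val_loss": v}) for v in (0.9, 0.7, 0.8, 0.6)])
# [True, True, False, True]
```

When the monitored key is missing from logs, the sketch simply skips the save; the real callback may additionally emit a warning.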

__init__

__init__(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    load_weights_on_restart=False,
    **kwargs
)
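An integer save_freq counts samples rather than batches. The sketch below models that counting with a hypothetical SampleCountingTrigger class; the choice to reset the counter to zero after each save belongs to this sketch and is not taken from the tf.keras source.

```python
class SampleCountingTrigger:
    """Hypothetical model of an integer save_freq.

    A save fires at the end of the first batch after which at least
    save_freq samples have been seen since the previous save; the
    counter then restarts from zero.
    """

    def __init__(self, save_freq):
        if not isinstance(save_freq, int) or save_freq <= 0:
            raise ValueError("save_freq must be a positive integer")
        self.save_freq = save_freq
        self.samples_since_last_save = 0

    def record_batch(self, batch_size):
        """Return True if a checkpoint would be written after this batch."""
        self.samples_since_last_save += batch_size
        if self.samples_since_last_save >= self.save_freq:
            self.samples_since_last_save = 0
            return True
        return False


trigger = SampleCountingTrigger(save_freq=100)
print([trigger.record_batch(32) for _ in range(8)])
# [False, False, False, True, False, False, False, True]
```

In this sketch the surplus samples (28 here, since 4 × 32 = 128) are discarded when the counter resets, so saves land on batch boundaries rather than at exact multiples of save_freq.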

Methods

on_batch_begin

on_batch_begin(
    batch,
    logs=None
)

A backwards compatibility alias for on_train_batch_begin.
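One way to keep such an alias working, and the assumption behind this sketch, is for the newer hook to forward to the legacy one, so a subclass that only overrides on_batch_begin still runs during training. LegacyCallbackBase and LegacyStyleCallback are hypothetical classes, not tf.keras types.

```python
class LegacyCallbackBase:
    """Hypothetical base class: the new-style hook forwards to the legacy one."""

    def on_batch_begin(self, batch, logs=None):
        pass  # Legacy hook; no-op by default.

    def on_train_batch_begin(self, batch, logs=None):
        # Backwards compatibility: delegate to the older name.
        self.on_batch_begin(batch, logs=logs)


class LegacyStyleCallback(LegacyCallbackBase):
    """Overrides only the old name, yet still sees training batches."""

    def __init__(self):
        self.seen = []

    def on_batch_begin(self, batch, logs=None):
        self.seen.append(batch)


cb = LegacyStyleCallback()
for batch in range(3):
    cb.on_train_batch_begin(batch, logs={"batch": batch, "size": 32})
print(cb.seen)  # [0, 1, 2]
```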

on_batch_end

on_batch_end(
    batch,
    logs=None
)

on_epoch_begin

on_epoch_begin(
    epoch,
    logs=None
)

on_epoch_end

on_epoch_end(
    epoch,
    logs=None
)

on_predict_batch_begin

on_predict_batch_begin(
    batch,
    logs=None
)

Called at the beginning of a batch in predict methods.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Has keys batch and size representing the current batch number and the size of the batch.

on_predict_batch_end

on_predict_batch_end(
    batch,
    logs=None
)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Metric results for this batch.

on_predict_begin

on_predict_begin(logs=None)

Called at the beginning of prediction.

Subclasses should override for any actions to run.

Arguments:

  • logs: dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_predict_end

on_predict_end(logs=None)

Called at the end of prediction.

Subclasses should override for any actions to run.

Arguments:

  • logs: dict. Currently no data is passed to this argument for this method, but that may change in the future.
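The predict hooks fire in a fixed order around each batch. The simplified loop below is hypothetical (run_predict and RecordingCallback are not tf.keras APIs); it only illustrates that ordering and the batch and size keys passed to on_predict_batch_begin, including a smaller final batch. The logs contents for the other hooks are placeholders.

```python
class RecordingCallback:
    """Hypothetical callback that records every predict hook call it receives."""

    def __init__(self):
        self.calls = []

    def on_predict_begin(self, logs=None):
        self.calls.append("on_predict_begin")

    def on_predict_batch_begin(self, batch, logs=None):
        self.calls.append(("on_predict_batch_begin", batch, logs["size"]))

    def on_predict_batch_end(self, batch, logs=None):
        self.calls.append(("on_predict_batch_end", batch))

    def on_predict_end(self, logs=None):
        self.calls.append("on_predict_end")


def run_predict(callback, num_samples, batch_size):
    """Simplified, hypothetical predict loop showing the hook order."""
    callback.on_predict_begin(logs={})
    num_batches = -(-num_samples // batch_size)  # Ceiling division.
    for batch in range(num_batches):
        size = min(batch_size, num_samples - batch * batch_size)
        callback.on_predict_batch_begin(batch, logs={"batch": batch, "size": size})
        # The model's forward pass would run here.
        callback.on_predict_batch_end(batch, logs={})
    callback.on_predict_end(logs={})


cb = RecordingCallback()
run_predict(cb, num_samples=5, batch_size=2)
for call in cb.calls:
    print(call)
```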

on_test_batch_begin

on_test_batch_begin(
    batch,
    logs=None
)

Called at the beginning of a batch in evaluate methods.

Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Has keys batch and size representing the current batch number and the size of the batch.

on_test_batch_end

on_test_batch_end(
    batch,
    logs=None
)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Metric results for this batch.

on_test_begin

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.

Arguments:

  • logs: dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_end

on_test_end(logs=None)

Called at the end of evaluation or validation.

Subclasses should override for any actions to run.

Arguments:

  • logs: dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_train_batch_begin

on_train_batch_begin(
    batch,
    logs=None
)

Called at the beginning of a training batch in fit methods.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Has keys batch and size representing the current batch number and the size of the batch.

on_train_batch_end

on_train_batch_end(
    batch,
    logs=None
)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Arguments:

  • batch: integer, index of batch within the current epoch.
  • logs: dict. Metric results for this batch.
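Per the descriptions above, training batches use the on_train_batch_* hooks, validation batches inside fit reuse the on_test_batch_* hooks, and the batch index restarts each epoch. The hypothetical run_fit loop below sketches that interleaving for a FitRecorder callback. Real fit also calls on_train_begin, on_test_begin, and other hooks that the sketch omits, and the logs values here are placeholders.

```python
class FitRecorder:
    """Hypothetical callback recording which hooks fire, with epoch/batch indices."""

    def __init__(self):
        self.calls = []

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(("epoch_begin", epoch))

    def on_train_batch_end(self, batch, logs=None):
        self.calls.append(("train_batch_end", batch))

    def on_test_batch_end(self, batch, logs=None):
        self.calls.append(("test_batch_end", batch))

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(("epoch_end", epoch))


def run_fit(callback, epochs, train_batches, val_batches):
    """Simplified, hypothetical fit loop; only FitRecorder's hooks are called."""
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(train_batches):  # Batch index restarts every epoch.
            callback.on_train_batch_end(batch, logs={"loss": 0.0})
        for batch in range(val_batches):  # Validation reuses the test hooks.
            callback.on_test_batch_end(batch, logs={"loss": 0.0})
        callback.on_epoch_end(epoch, logs={"loss": 0.0, "val_loss": 0.0})


rec = FitRecorder()
run_fit(rec, epochs=2, train_batches=2, val_batches=1)
for call in rec.calls:
    print(call)
```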

on_train_begin

on_train_begin(logs=None)

on_train_end

on_train_end(logs=None)

set_model

set_model(model)

set_params

set_params(params)