
Migrate checkpoint saving



Continually saving the "best" model, or only the model weights/parameters, has several benefits: you can track training progress and load saved models from different saved states.

In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to TensorFlow 2 Keras APIs.

In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:

  • Save the "best" version according to a monitored metric by setting save_best_only=True, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
  • Save continually at a certain frequency (using the save_freq argument).
  • Save the weights/parameters only instead of the whole model by setting save_weights_only to True.
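The three options above might look like the following minimal sketch, assuming TensorFlow 2 with tf.keras. The file names and the save_freq value of 500 are illustrative assumptions, not values from this guide; the .keras and .weights.h5 suffixes are used because recent Keras versions expect them for full-model and weights-only checkpoints, respectively.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()

# 1. Keep only the "best" model according to a monitored metric.
#    mode='max' because higher validation accuracy is better.
best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, 'best.keras'),
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# 2. Save at a fixed frequency: an integer save_freq counts batches,
#    while the default save_freq='epoch' saves once per epoch.
freq_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, 'every_500_batches.keras'),
    save_freq=500)

# 3. Save only the weights/parameters instead of the whole model.
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, 'weights.weights.h5'),
    save_weights_only=True)

# Any combination can then be passed to Model.fit, for example:
# model.fit(..., callbacks=[best_cb, freq_cb, weights_cb])
```

Because an integer save_freq counts batches rather than epochs, the default save_freq='epoch' is usually the safer choice when monitoring validation metrics such as 'val_accuracy', which are computed at the end of each epoch.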

For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the Checkpoint format in the TF Checkpoint format section in the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
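The BackupAndRestore option mentioned above can be sketched as follows, assuming a TensorFlow 2 release in which it is available as tf.keras.callbacks.BackupAndRestore (older releases exposed it under tf.keras.callbacks.experimental). The tiny synthetic dataset, the model, and the InterruptAtEpoch helper are hypothetical stand-ins used only to simulate a failure: the first Model.fit call is interrupted at the start of epoch 4 (0-based), and the second call, given the same callback, should resume from there instead of from epoch 0.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Tiny synthetic regression data, used only to make the example run quickly.
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

model = tf.keras.Sequential([
  tf.keras.Input(shape=(4,)),
  tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

backup_dir = tempfile.mkdtemp()
backup_cb = tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir)


class InterruptAtEpoch(tf.keras.callbacks.Callback):
  """Hypothetical helper that simulates a failure at the start of one epoch."""

  def __init__(self, fail_epoch):
    super().__init__()
    self.fail_epoch = fail_epoch

  def on_epoch_begin(self, epoch, logs=None):
    if epoch == self.fail_epoch:
      raise RuntimeError('Simulated interruption')


# First run: epochs 0-3 complete and are backed up, then training
# stops with an error at the start of epoch 4.
try:
  model.fit(x, y, epochs=10, verbose=0,
            callbacks=[backup_cb, InterruptAtEpoch(4)])
except RuntimeError:
  pass

# Second run: BackupAndRestore restores the backed-up state and training
# continues from the interrupted epoch instead of starting over.
history = model.fit(x, y, epochs=10, verbose=0, callbacks=[backup_cb])
print(history.epoch)
```

Once the second run finishes all epochs, BackupAndRestore deletes its backup by default, so a later call to Model.fit starts fresh.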

Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.
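To make these hook points concrete, here is a minimal sketch of a custom callback, assuming TensorFlow 2. EventRecorder is a hypothetical helper that subclasses tf.keras.callbacks.Callback and overrides a few of its hook methods to record when Model.fit, Model.evaluate, and Model.predict call it; the data and model are tiny synthetic stand-ins.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
  """Hypothetical callback that records which hooks Keras calls."""

  def __init__(self):
    super().__init__()
    self.events = []

  def on_train_begin(self, logs=None):
    self.events.append('train_begin')

  def on_epoch_end(self, epoch, logs=None):
    self.events.append(f'epoch_end:{epoch}')

  def on_test_begin(self, logs=None):
    self.events.append('test_begin')

  def on_predict_begin(self, logs=None):
    self.events.append('predict_begin')


x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')

model = tf.keras.Sequential([
  tf.keras.Input(shape=(4,)),
  tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

recorder = EventRecorder()
model.fit(x, y, epochs=2, verbose=0, callbacks=[recorder])
model.evaluate(x, y, verbose=0, callbacks=[recorder])
model.predict(x, verbose=0, callbacks=[recorder])
print(recorder.events)
```

Because this Model.fit call has no validation_data, on_test_begin fires only during Model.evaluate; passing validation data to Model.fit would also trigger the test hooks at the end of each epoch.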

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step

TensorFlow 1: Save checkpoints with tf.estimator APIs

This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmplrkjo9in', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1 or save_checkpoints_secs None.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:47
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-1
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.26374s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:47
INFO:tensorflow:Saving dict for global step 1: accuracy = 0.1765625, average_loss = 2.2546134, global_step = 1, loss = 288.5905
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1: /tmp/tmplrkjo9in/model.ckpt-1
INFO:tensorflow:loss = 118.3231, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-2
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.36662s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48
INFO:tensorflow:Saving dict for global step 2: accuracy = 0.2859375, average_loss = 2.1868849, global_step = 2, loss = 279.92126
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2: /tmp/tmplrkjo9in/model.ckpt-2
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22792s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48
INFO:tensorflow:Saving dict for global step 3: accuracy = 0.35078126, average_loss = 2.1220195, global_step = 3, loss = 271.6185
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmp/tmplrkjo9in/model.ckpt-3
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-4
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22387s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49
INFO:tensorflow:Saving dict for global step 4: accuracy = 0.40234375, average_loss = 2.0655982, global_step = 4, loss = 264.39658
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4: /tmp/tmplrkjo9in/model.ckpt-4
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmplrkjo9in/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-5
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22548s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49
INFO:tensorflow:Saving dict for global step 5: accuracy = 0.42421874, average_loss = 2.0072064, global_step = 5, loss = 256.92242
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5: /tmp/tmplrkjo9in/model.ckpt-5
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-6
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22806s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50
INFO:tensorflow:Saving dict for global step 6: accuracy = 0.43984374, average_loss = 1.9473753, global_step = 6, loss = 249.26404
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6: /tmp/tmplrkjo9in/model.ckpt-6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-7
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.23091s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50
INFO:tensorflow:Saving dict for global step 7: accuracy = 0.44296876, average_loss = 1.8903366, global_step = 7, loss = 241.96309
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7: /tmp/tmplrkjo9in/model.ckpt-7
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-8
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22453s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51
INFO:tensorflow:Saving dict for global step 8: accuracy = 0.44453126, average_loss = 1.8294731, global_step = 8, loss = 234.17256
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8: /tmp/tmplrkjo9in/model.ckpt-8
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-9
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22271s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51
INFO:tensorflow:Saving dict for global step 9: accuracy = 0.47734374, average_loss = 1.7674354, global_step = 9, loss = 226.23174
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9: /tmp/tmplrkjo9in/model.ckpt-9
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmplrkjo9in/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:52
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.38483s
INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:52
INFO:tensorflow:Saving dict for global step 10: accuracy = 0.5140625, average_loss = 1.7108486, global_step = 10, loss = 218.98862
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmplrkjo9in/model.ckpt-10
INFO:tensorflow:Loss for final step: 96.2236.
({'accuracy': 0.5140625,
  'average_loss': 1.7108486,
  'loss': 218.98862,
  'global_step': 10},
 [])
%ls {classifier.model_dir}
checkpoint
eval/
events.out.tfevents.1642127326.kokoro-gcp-ubuntu-prod-837339153
graph.pbtxt
model.ckpt-10.data-00000-of-00001
model.ckpt-10.index
model.ckpt-10.meta
model.ckpt-6.data-00000-of-00001
model.ckpt-6.index
model.ckpt-6.meta
model.ckpt-7.data-00000-of-00001
model.ckpt-7.index
model.ckpt-7.meta
model.ckpt-8.data-00000-of-00001
model.ckpt-8.index
model.ckpt-8.meta
model.ckpt-9.data-00000-of-00001
model.ckpt-9.index
model.ckpt-9.meta
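Only checkpoints 6 through 10 remain because the Estimator keeps at most five recent checkpoints by default ('_keep_checkpoint_max': 5 in the logged RunConfig), deleting older ones as new ones are written. If you checkpoint manually in TensorFlow 2 with tf.train.Checkpoint, tf.train.CheckpointManager offers a comparable retention policy through max_to_keep. The following is a minimal sketch that tracks a single hypothetical step variable instead of a real model and optimizer.

```python
import os
import tempfile

import tensorflow as tf

# Track a single variable; a real program would typically also track the
# model and optimizer, e.g. tf.train.Checkpoint(model=model, optimizer=opt).
ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
ckpt_dir = tempfile.mkdtemp()

# Keep at most the five most recent checkpoints, similar to the
# Estimator default of keep_checkpoint_max=5.
manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=5)

for step in range(1, 11):
  ckpt.step.assign(step)
  manager.save(checkpoint_number=step)  # Writes files with prefix ckpt-<step>.

# Older checkpoints were deleted as newer ones were saved.
print([os.path.basename(path) for path in manager.checkpoints])

# Restore the most recent checkpoint into the tracked variable.
ckpt.step.assign(0)
ckpt.restore(manager.latest_checkpoint)
print(int(ckpt.step.numpy()))
```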

TensorFlow 2: Save checkpoints with a Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and the Using callbacks section in the Training and evaluation with the built-in methods guide.)

In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
Epoch 1/10
1840/1875 [============================>.] - ETA: 0s - loss: 0.2224 - accuracy: 0.9348
2022-01-14 02:28:56.714889: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2208 - accuracy: 0.9354 - val_loss: 0.1132 - val_accuracy: 0.9669
Epoch 2/10
1870/1875 [============================>.] - ETA: 0s - loss: 0.0961 - accuracy: 0.9706INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0962 - accuracy: 0.9706 - val_loss: 0.0784 - val_accuracy: 0.9753
Epoch 3/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0696 - accuracy: 0.9781INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0695 - accuracy: 0.9782 - val_loss: 0.0684 - val_accuracy: 0.9788
Epoch 4/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0529 - accuracy: 0.9826INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0531 - accuracy: 0.9826 - val_loss: 0.0671 - val_accuracy: 0.9791
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0423 - accuracy: 0.9860INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0424 - accuracy: 0.9860 - val_loss: 0.0772 - val_accuracy: 0.9757
Epoch 6/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0345 - accuracy: 0.9888INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0345 - accuracy: 0.9888 - val_loss: 0.0669 - val_accuracy: 0.9811
Epoch 7/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0314 - accuracy: 0.9895INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0313 - accuracy: 0.9895 - val_loss: 0.0718 - val_accuracy: 0.9800
Epoch 8/10
1870/1875 [============================>.] - ETA: 0s - loss: 0.0298 - accuracy: 0.9899INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0298 - accuracy: 0.9899 - val_loss: 0.0632 - val_accuracy: 0.9825
Epoch 9/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0230 - accuracy: 0.9925INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0231 - accuracy: 0.9924 - val_loss: 0.0748 - val_accuracy: 0.9800
Epoch 10/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0220 - accuracy: 0.9920INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0222 - accuracy: 0.9920 - val_loss: 0.0703 - val_accuracy: 0.9825
<keras.callbacks.History at 0x7f638c204410>
%ls {model_checkpoint_callback.filepath}
assets/  keras_metadata.pb  saved_model.pb  variables/
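Because filepath is a fixed directory and save_weights_only defaults to False, each epoch overwrote the same full-model SavedModel (note the repeated "Assets written to" lines), so only one copy remains. A minimal sketch of keeping one weights file per epoch puts an {epoch:02d} format field in filepath and then restores a file with Model.load_weights. The tiny synthetic data and the make_model helper are hypothetical stand-ins for MNIST and create_model(); the .weights.h5 suffix is used because recent Keras versions expect it for weights-only checkpoints.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def make_model():
  # Hypothetical small stand-in for create_model() from this guide.
  return tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(1),
  ])


x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

ckpt_dir = tempfile.mkdtemp()
# ModelCheckpoint fills {epoch:02d} with the 1-based epoch number, so each
# epoch writes a separate file instead of overwriting a single one.
per_epoch_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, 'ckpt-{epoch:02d}.weights.h5'),
    save_weights_only=True)

model = make_model()
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=3, verbose=0, callbacks=[per_epoch_cb])
print(sorted(os.listdir(ckpt_dir)))

# Load the last epoch's weights into a freshly built model; its predictions
# should match the trained model, since no training followed that save.
restored = make_model()
restored.load_weights(os.path.join(ckpt_dir, 'ckpt-03.weights.h5'))
np.testing.assert_allclose(restored.predict(x, verbose=0),
                           model.predict(x, verbose=0),
                           rtol=1e-5, atol=1e-6)
```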

Next steps

Learn more about checkpointing in:

Learn more about callbacks in:

You may also find the following migration-related resources useful: