Groups trackable objects, saving and restoring them.

Checkpoint's constructor accepts keyword arguments whose values are types that contain trackable state, such as tf.keras.optimizers.Optimizer implementations, tf.Variables, iterators, tf.keras.Layer implementations, or tf.keras.Model implementations. It saves these values with a checkpoint, and maintains a save_counter for numbering checkpoints.

Example usage:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.save() and Checkpoint.restore() write and read object-based checkpoints, in contrast to TensorFlow 1.x's tf.compat.v1.train.Saver, which writes and reads variable.name-based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. This can be more robust to changes in the Python program, and helps to support restore-on-create for variables.

Checkpoint objects have dependencies on the objects passed as keyword arguments to their constructors, and each dependency is given a name that is identical to the name of the keyword argument for which it was created. TensorFlow classes like Layers and Optimizers will automatically add dependencies on their own variables (e.g. "kernel" and "bias" for tf.keras.layers.Dense). Inheriting from tf.keras.Model makes managing dependencies easy in user-defined classes, since Model hooks into attribute assignment. For example:

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...

This Model has a dependency named "input_transform" on its Dense layer, which in turn depends on its variables. As a result, saving an instance of Regress using tf.train.Checkpoint will also save all the variables created by the Dense layer.
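To make the named edges visible, the following minimal sketch saves a small object graph and lists the keys stored in the checkpoint. It is an illustration under stated assumptions: the classes TinyDense and TinyRegress are hypothetical stand-ins built on tf.Module (which also tracks attribute assignment) rather than tf.keras.Model, so the key names do not depend on Keras internals, and the temporary directory is arbitrary.

```python
import os
import tempfile

import tensorflow as tf


class TinyDense(tf.Module):
  """Hypothetical stand-in for tf.keras.layers.Dense."""

  def __init__(self, in_dim, out_dim):
    super().__init__()
    self.kernel = tf.Variable(tf.ones([in_dim, out_dim]))
    self.bias = tf.Variable(tf.zeros([out_dim]))


class TinyRegress(tf.Module):
  """Hypothetical stand-in for the Regress model."""

  def __init__(self):
    super().__init__()
    # Attribute assignment adds a dependency named "input_transform".
    self.input_transform = TinyDense(3, 10)


checkpoint = tf.train.Checkpoint(model=TinyRegress())
save_path = checkpoint.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Each stored key follows the chain of named edges from the root object.
variable_keys = [name for name, _ in tf.train.list_variables(save_path)]
for key in variable_keys:
  print(key)
```

The printed keys should include entries beginning with model/input_transform/kernel and model/input_transform/bias, mirroring the keyword argument name and the attribute names along the path.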

When variables are assigned to multiple workers, each worker writes its own section of the checkpoint. These sections are then merged/re-indexed to behave as a single checkpoint. This avoids copying all variables to one worker, but does require that all workers see a common filesystem.

While tf.keras.Model.save_weights and tf.train.Checkpoint.save save in the same format, note that the root of the resulting checkpoint is the object the save method is attached to. This means saving a tf.keras.Model using save_weights and loading into a tf.train.Checkpoint with a Model attached (or vice versa) will not match the Model's variables. See the guide to training checkpoints for details. Prefer tf.train.Checkpoint over tf.keras.Model.save_weights for training checkpoints.

Args:
  **kwargs: Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must be trackable objects.

Raises:
  ValueError: If objects in kwargs are not trackable.

Attributes:
  save_counter: Incremented when save() is called. Used to number checkpoints.
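As a rough illustration of the constructor contract and of save_counter, the sketch below saves twice and then tries to pass a non-trackable value. The variable names and the temporary directory are assumptions made for the example.

```python
import os
import tempfile

import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0))
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# Each save() increments save_counter and appends it to the prefix.
first_path = ckpt.save(prefix)
second_path = ckpt.save(prefix)
counter_value = int(ckpt.save_counter.numpy())
print(first_path, second_path, counter_value)

# A plain Python integer carries no trackable state, so it is rejected.
try:
  tf.train.Checkpoint(step=3)
  rejected = False
except ValueError:
  rejected = True
print(rejected)
```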



Read a training checkpoint written with write.

Reads this Checkpoint and any objects it depends on.

This method is just like restore() but does not expect the save_counter variable in the checkpoint. It only restores the objects that the checkpoint already depends on.

The method is primarily intended for use by higher level checkpoint management utilities that use write() instead of save() and have their own mechanisms to number and track checkpoints.

Example usage:

# Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')

# Later, load the checkpoint with read()
# With restore() assert_consumed() would have failed.
ckpt.read(path).assert_consumed()

# You can also pass options to read(). For example this
# runs the IO ops on the localhost:
options = tf.train.CheckpointOptions(
    experimental_io_device="/job:localhost")
ckpt.read(path, options=options)

Args:
  save_path: The path to the checkpoint as returned by write.
  options: Optional tf.train.CheckpointOptions object.

Returns:
  A load status object, which can be used to make assertions about the status of a checkpoint restoration. See restore for details.
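One way to check the difference between write() and save() is to list the keys each one stores. The sketch below is illustrative only: the file names "written" and "saved" and the temporary directory are assumptions, and tf.train.list_variables is used purely for inspection.

```python
import os
import tempfile

import tensorflow as tf

directory = tempfile.mkdtemp()

written = tf.train.Checkpoint(v=tf.Variable(1.0))
write_path = written.write(os.path.join(directory, "written"))

saved = tf.train.Checkpoint(v=tf.Variable(1.0))
save_path = saved.save(os.path.join(directory, "saved"))

write_keys = [name for name, _ in tf.train.list_variables(write_path)]
save_keys = [name for name, _ in tf.train.list_variables(save_path)]

# Only the save() checkpoint is expected to carry a save_counter entry.
print(any(key.startswith("save_counter") for key in write_keys))
print(any(key.startswith("save_counter") for key in save_keys))

# read() restores a write() checkpoint into a fresh object.
target = tf.train.Checkpoint(v=tf.Variable(0.0))
target.read(write_path).assert_consumed()
restored_value = float(target.v.numpy())
print(restored_value)
```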


Restore a training checkpoint.

Restores this Checkpoint and any objects it depends on.

This method is intended to be used to load checkpoints created by save(). For checkpoints created by write() use the read() method which does not expect the save_counter variable added by save().

restore() either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any trackable object waiting for the expected dependency to be added).
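Deferred restoration can be sketched as follows. This is a minimal example under assumptions: plain tf.Module objects stand in for layers, the attribute name v and the temporary directory are arbitrary, and the variable is attached by hand rather than created inside a layer's build step.

```python
import os
import tempfile

import tensorflow as tf

# Save a checkpoint whose "holder" object owns one variable named "v".
holder = tf.Module()
holder.v = tf.Variable(5.0)
save_path = tf.train.Checkpoint(holder=holder).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

# Request the restore before the matching variable exists.
new_holder = tf.Module()
new_checkpoint = tf.train.Checkpoint(holder=new_holder)
status = new_checkpoint.restore(save_path)

late_variable = tf.Variable(0.0)
value_before = float(late_variable.numpy())  # Still the initial value.

# Adding the dependency triggers the queued restore for "v".
new_holder.v = late_variable
value_after = float(late_variable.numpy())
print(value_before, value_after)
```

The restore request waits on new_holder until an attribute named v is assigned, at which point the checkpointed value replaces the variable's initial value.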

To ensure that loading is complete and no more assignments will take place, use the assert_consumed() method of the status object returned by restore():

checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()

# You can additionally pass options to restore():
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()

An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.
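The failure case can be illustrated with a small sketch. It assumes a checkpoint holding a single variable v, and a second Checkpoint that adds a hypothetical variable named extra with no counterpart in the saved file. The exception type caught here (AssertionError) reflects observed behavior rather than anything stated above.

```python
import os
import tempfile

import tensorflow as tf

save_path = tf.train.Checkpoint(v=tf.Variable(1.0)).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

# Same structure as the saved checkpoint: every object finds a match.
matching = tf.train.Checkpoint(v=tf.Variable(0.0))
matching.restore(save_path).assert_consumed()

# "extra" has no counterpart in the checkpoint, so the check should fail.
mismatched = tf.train.Checkpoint(v=tf.Variable(0.0), extra=tf.Variable(0.0))
status = mismatched.restore(save_path)
try:
  status.assert_consumed()
  raised = False
except AssertionError:
  raised = True
print(raised)
```

Note that the matched variable v is still restored in the second case; the assertion only reports that the Python object graph and the checkpoint do not line up completely.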

Name-based tf.compat.v1.train.Saver checkpoints from TensorFlow 1.x can be loaded using this method. In that case, names are used to match variables. Re-encode name-based checkpoints using tf.train.Checkpoint.save as soon as possible.

Args:
  save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint. If the checkpoint was written by the name-based tf.compat.v1.train.Saver, names are used to match variables.
  options: Optional tf.train.CheckpointOptions object.

Returns:
  A load status object, which can be used to make assertions about the status of a checkpoint restoration.

The returned status object has the following methods:

  • assert_consumed(): Raises an exception if any variables are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with other assertions.

  • assert_existing_objects_matched(): Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike assert_consumed, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example a tf.keras.Layer object which has not yet been built, and so has not created any variables, will pass this assertion but fail