Groups checkpointable objects, saving and restoring them.
Checkpoint's constructor accepts keyword arguments whose values are types
that contain checkpointable state, such as
tf.train.Optimizer implementations, tf.keras.Layer implementations, or
tf.keras.Model implementations. It saves these values with a checkpoint, and
maintains a save_counter for numbering checkpoints.
Example usage when graph building:
```python
import os

import tensorflow as tf

# `optimizer`, `model`, and `num_training_steps` are assumed to be defined
# elsewhere; `...` stands for the arguments of your training objective.
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... )
status.assert_consumed()  # Optional sanity checks; raises if no checkpoint was found.
with tf.Session() as session:
    # Use the Session to restore variables, or initialize them if
    # tf.train.latest_checkpoint returned None.
    status.initialize_or_restore(session)
    for _ in range(num_training_steps):
        session.run(train_op)
    checkpoint.save(file_prefix=checkpoint_prefix)
```
Example usage with eager execution enabled:
```python
import os

import tensorflow as tf

tf.enable_eager_execution()

# `optimizer`, `model`, and `num_training_steps` are assumed to be defined
# elsewhere; `...` stands for the arguments of your training objective.
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks; raises if no checkpoint was found.
checkpoint.save(file_prefix=checkpoint_prefix)
```
Checkpoint.save and Checkpoint.restore write and read object-based
checkpoints, in contrast to tf.train.Saver, which writes and reads
variable.name-based checkpoints. Object-based checkpointing saves a graph of
dependencies between Python objects (Layers, Optimizers, Variables, etc.)
with named edges, and this graph is used to match variables when restoring a
checkpoint. It can be more robust to changes in the Python program, and helps
to support restore-on-create for variables when executing eagerly. Prefer
tf.train.Checkpoint over tf.train.Saver for new code.
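The robustness claim above can be illustrated with a toy model in plain Python. This is a hypothetical sketch, not TensorFlow's implementation: `Var`, `Node`, and the four save/restore helpers are invented for illustration. It assumes a program change that alters a variable's global name (here `dense/kernel:0` becomes `dense_1/kernel:0`) while keeping the object structure, and shows that matching by name misses while matching by the path of named edges still succeeds.

```python
# Toy model, hypothetical and not TensorFlow's implementation: compare
# matching saved values by global variable name versus by object-graph path.


class Var:
    """Stand-in for a variable: a global name plus a value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class Node:
    """Stand-in for a checkpointable object with named edges to children."""

    def __init__(self, **children):
        self.children = children


def walk(node, prefix=()):
    """Yield (path, Var) pairs by following named edges from node."""
    for edge, child in sorted(node.children.items()):
        path = prefix + (edge,)
        if isinstance(child, Var):
            yield "/".join(path), child
        else:
            yield from walk(child, path)


def save_by_path(root):
    return {path: var.value for path, var in walk(root)}


def restore_by_path(root, saved):
    for path, var in walk(root):
        if path in saved:
            var.value = saved[path]


def save_by_name(variables):
    return {var.name: var.value for var in variables}


def restore_by_name(variables, saved):
    for var in variables:
        if var.name in saved:
            var.value = saved[var.name]


# Program 1 writes both kinds of checkpoint.
old_kernel = Var("dense/kernel:0", 3.0)
by_path = save_by_path(Node(model=Node(dense=Node(kernel=old_kernel))))
by_name = save_by_name([old_kernel])

# Program 2 builds the same object structure, but the variable's global
# name changed (for example because another layer was created first).
new_kernel = Var("dense_1/kernel:0", 0.0)
restore_by_name([new_kernel], by_name)
print(new_kernel.value)  # 0.0: no saved name matches
restore_by_path(Node(model=Node(dense=Node(kernel=new_kernel))), by_path)
print(new_kernel.value)  # 3.0: the path "model/dense/kernel" still matches
```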
Checkpoint objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is
identical to the name of the keyword argument for which it was created.
TensorFlow classes like Layers and Optimizers will automatically add
dependencies on their variables (e.g. "kernel" and "bias" for
tf.keras.layers.Dense). Inheriting from tf.keras.Model makes managing
dependencies easy in user-defined classes, since Model hooks into attribute
assignment. For example:
```python
class Regress(tf.keras.Model):

    def __init__(self):
        super(Regress, self).__init__()
        self.input_transform = tf.keras.layers.Dense(10)
        # ...

    def call(self, inputs):
        x = self.input_transform(inputs)
        # ...
```
This Model has a dependency named "input_transform" on its Dense layer,
which in turn depends on its variables. As a result, saving an instance of
Regress using tf.train.Checkpoint will also save all the variables created
by the Dense layer.
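How attribute assignment can turn into named dependencies is sketched below in plain Python. `Trackable`, `Variable`, `Dense`, `Regress`, and `variable_paths` are hypothetical stand-ins, not the Keras classes; the only point is that overriding `__setattr__` lets the attribute name become the edge name, so the Dense layer's variables become reachable from the model through the path `input_transform/...`.

```python
# Hypothetical toy, not Keras: a base class that records a named dependency
# whenever a trackable value is assigned as an attribute.


class Trackable:
    def __init__(self):
        # Bypass our own __setattr__ so the registry itself is not tracked.
        object.__setattr__(self, "_deps", {})

    def __setattr__(self, name, value):
        if isinstance(value, Trackable):
            self._deps[name] = value  # the edge name is the attribute name
        object.__setattr__(self, name, value)


class Variable(Trackable):
    def __init__(self, value):
        super().__init__()
        self.value = value  # a plain list, so it is not tracked as a dependency


class Dense(Trackable):
    def __init__(self, units):
        super().__init__()
        self.kernel = Variable([0.0] * units)
        self.bias = Variable([0.0] * units)


class Regress(Trackable):
    def __init__(self):
        super().__init__()
        self.input_transform = Dense(10)


def variable_paths(obj, prefix=""):
    """List the paths from obj to every Variable reachable through named edges."""
    if isinstance(obj, Variable):
        return [prefix]
    paths = []
    for name, child in sorted(obj._deps.items()):
        child_prefix = prefix + "/" + name if prefix else name
        paths.extend(variable_paths(child, child_prefix))
    return paths


print(variable_paths(Regress()))  # ['input_transform/bias', 'input_transform/kernel']
```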
save_counter: An integer variable which starts at zero and is incremented each time save() is called. Used to number checkpoints.
Returns: The save counter variable.
__init__( **kwargs )
Group objects into a training checkpoint.
Args:
**kwargs: Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must be checkpointable objects.
Raises:
ValueError: If objects in kwargs are not checkpointable.
__setattr__( name, value )
Support self.foo = checkpointable syntax.
restore( save_path )
Restore a training checkpoint.
Restores this Checkpoint and any objects it depends on.
When executing eagerly, either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any checkpointable object waiting for the expected dependency to be added).
When graph building, restoration ops are added to the graph but not run immediately.
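The deferred restoration described for eager execution can be modelled with a small queue per object. This is a hypothetical toy, not TensorFlow's mechanism: `Node`, `add_dependency`, `apply`, and the nested-dict checkpoint format are all invented here. Saved values for a dependency that does not exist yet wait on the parent, and are applied the moment a dependency with the matching name is added; values for objects that already exist are assigned immediately.

```python
# Hypothetical toy of restore-on-create: saved subtrees for missing
# dependencies are queued on the parent until the dependency is added.


class Node:
    def __init__(self):
        self.value = None
        self.children = {}
        self._pending = {}  # edge name -> saved subtree still waiting

    def add_dependency(self, name, child):
        self.children[name] = child
        if name in self._pending:
            child.apply(self._pending.pop(name))

    def apply(self, saved):
        """saved has the form {"value": ..., "children": {name: subtree}}."""
        if "value" in saved:
            self.value = saved["value"]
        for name, subtree in saved.get("children", {}).items():
            if name in self.children:
                self.children[name].apply(subtree)  # already exists: assign now
            else:
                self._pending[name] = subtree  # queue until the child appears


checkpoint = {"children": {"model": {"children": {"kernel": {"value": 3.0}}}}}

root = Node()
root.apply(checkpoint)                  # nothing exists yet: "model" is queued
model = Node()
root.add_dependency("model", model)     # "kernel" is now queued on model
kernel = Node()
model.add_dependency("kernel", kernel)  # kernel is restored on creation
print(kernel.value)  # 3.0
```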
To ensure that loading is complete and no more assignments will take place,
use the assert_consumed() method of the status object returned by restore:
```python
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()
```
An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.
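The two-way check described above can be written as a set comparison. This is a hypothetical toy: `assert_consumed` here is a free function over lists of dependency paths, not the status-object method, and raising `AssertionError` is a choice made for the sketch. It fails if a checkpointed value has no matching object, or if an object in the graph has no checkpointed value.

```python
# Hypothetical toy of the assert_consumed() check: every checkpointed path
# must match an object path and vice versa.


def assert_consumed(checkpoint_paths, object_paths):
    checkpoint_paths, object_paths = set(checkpoint_paths), set(object_paths)
    unmatched_values = checkpoint_paths - object_paths
    unmatched_objects = object_paths - checkpoint_paths
    if unmatched_values or unmatched_objects:
        raise AssertionError(
            "Unmatched checkpoint values: %s; unmatched objects: %s"
            % (sorted(unmatched_values), sorted(unmatched_objects)))


saved = ["model/dense/kernel", "model/dense/bias"]
assert_consumed(saved, ["model/dense/bias", "model/dense/kernel"])  # passes

try:
    # A layer was added to the program after the checkpoint was written.
    assert_consumed(saved, saved + ["model/extra/kernel"])
except AssertionError as error:
    print(error)
```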
When graph building, assert_consumed() indicates that all of the restore
ops that will be created for this checkpoint have been created. They can be
run via the run_restore_ops() method of the status object, for example
checkpoint.restore(path).assert_consumed().run_restore_ops().
If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.
Name-based tf.train.Saver checkpoints can be loaded using this
method. Names are used to match variables. No restore ops are created/run
until run_restore_ops() or initialize_or_restore() is called on the
returned status object when graph building, but there is restore-on-creation
when executing eagerly. Re-encode name-based checkpoints using
tf.train.Checkpoint.save as soon as possible.
Args:
save_path: The path to the checkpoint, as returned by tf.train.latest_checkpoint. If None (as when there is no latest checkpoint for tf.train.latest_checkpoint to return), returns an object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based tf.train.Saver, names are used to match variables.
Returns: A load status object, which can be used to make assertions about the status of a checkpoint restoration and run initialization/restore ops.
The returned status object has the following methods:
assert_consumed(): Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with initialize_or_restore or run_restore_ops.
initialize_or_restore(session=None): When graph building, runs variable initializers if save_path is None, but otherwise runs restore operations. If no session is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly).
run_restore_ops(session=None): When graph building, runs restore operations. If no session is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be called when save_path is not None.
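The graph-building behaviour of the status methods above can be mimicked with recorded callables standing in for ops. This is a hypothetical toy: `ToyVariable`, `ToyLoadStatus`, and `track` are invented names, sessions are omitted, and raising `ValueError` from `run_restore_ops` when `save_path` is None is an assumption. It shows the restore-op list growing as objects are added, `assert_consumed()` returning the status so calls can be chained, and `initialize_or_restore()` falling back to initializers when there is no checkpoint.

```python
# Hypothetical toy of the load status object in "graph building" style:
# ops are recorded as callables and only executed when explicitly run.


class ToyVariable:
    def __init__(self, path, initial_value):
        self.path = path
        self.initial_value = initial_value
        self.value = None  # unset until an init or restore op runs


class ToyLoadStatus:
    def __init__(self, save_path, saved_values):
        self.save_path = save_path        # None means "no checkpoint found"
        self.saved_values = saved_values  # dependency path -> saved value
        self.restore_ops = []             # grows as matching objects are added
        self.variables = []

    def track(self, var):
        """Record a newly added object; queue a restore op if it matches."""
        self.variables.append(var)
        if self.save_path is not None and var.path in self.saved_values:
            value = self.saved_values[var.path]
            self.restore_ops.append(lambda v=var, x=value: setattr(v, "value", x))
        return var

    def assert_consumed(self):
        tracked = {v.path for v in self.variables}
        if self.save_path is None or tracked != set(self.saved_values):
            raise AssertionError("checkpoint not fully consumed")
        return self  # returned so calls can be chained

    def run_restore_ops(self):
        if self.save_path is None:
            raise ValueError("run_restore_ops requires a save_path")
        for op in self.restore_ops:
            op()

    def initialize_or_restore(self):
        if self.save_path is None:
            for v in self.variables:
                v.value = v.initial_value  # run initializers
        else:
            self.run_restore_ops()


status = ToyLoadStatus("/tmp/ckpt-1", {"dense/kernel": 3.0})
kernel = status.track(ToyVariable("dense/kernel", 0.0))
status.assert_consumed().run_restore_ops()  # chained, as described above
print(kernel.value)  # 3.0

fresh = ToyLoadStatus(None, {})
bias = fresh.track(ToyVariable("dense/bias", 0.5))
fresh.initialize_or_restore()
print(bias.value)  # 0.5
```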
save( file_prefix, session=None )
Save a training checkpoint.
The saved checkpoint includes variables created by this object and any
checkpointable objects it depends on at the time Checkpoint.save() is called.
Args:
file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and save_counter.
session: The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns: The full path to the checkpoint.
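The numbering described for save_counter and save() can be illustrated with a toy checkpoint that writes JSON files. This is a hypothetical sketch: `ToyCheckpoint` and its JSON format are invented, and the `<prefix>-<N>` filename scheme is an assumption for illustration. The counter starts at zero, is incremented on every save, numbers the file, and the full path is returned.

```python
# Hypothetical toy of save_counter numbering; the "<prefix>-<N>" naming and
# the JSON file format are assumptions made for this sketch.
import json
import os
import tempfile


class ToyCheckpoint:
    def __init__(self, **objects):
        self.objects = objects  # name -> JSON-serializable value
        self.save_counter = 0   # starts at zero

    def save(self, file_prefix):
        self.save_counter += 1  # incremented on every save
        path = "%s-%d" % (file_prefix, self.save_counter)
        with open(path, "w") as f:
            # The counter is written too, so numbering could resume after a restore.
            json.dump({"save_counter": self.save_counter,
                       "objects": self.objects}, f)
        return path  # full path to the written checkpoint


directory = tempfile.mkdtemp()
ckpt = ToyCheckpoint(step=0)
first = ckpt.save(os.path.join(directory, "ckpt"))
second = ckpt.save(os.path.join(directory, "ckpt"))
print(os.path.basename(first), os.path.basename(second))  # ckpt-1 ckpt-2
```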