Groups trackable objects, saving and restoring them.
tf.train.Checkpoint( **kwargs )
Checkpoint's constructor accepts keyword arguments whose values are types
that contain trackable state, such as
tf.keras.Model implementations. It saves these values
with a checkpoint, and maintains a
save_counter for numbering checkpoints.
```python
import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
```
Checkpoint.save() and Checkpoint.restore() write and read object-based
checkpoints, in contrast to TensorFlow 1.x's
tf.compat.v1.train.Saver, which writes and reads variable.name based
checkpoints. Object-based checkpointing saves a
graph of dependencies between Python objects (Layers, Optimizers,
Variables, etc.) with named edges, and this graph is used to match variables
when restoring a checkpoint. It can be more robust to changes in the Python
program, and helps to support restore-on-create for variables.
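A minimal sketch of the matching behavior (the path here is illustrative):
values are matched through the named edge "v", not through the Python
variable's own name.

```python
import tensorflow as tf

# Save a variable under the edge name "v".
ckpt = tf.train.Checkpoint(v=tf.Variable(3.))
path = ckpt.save("/tmp/edge_demo")

# Restore into a structurally matching Checkpoint; the edge name "v"
# is what lines the values up, regardless of variable.name.
restored = tf.train.Checkpoint(v=tf.Variable(0.))
restored.restore(path).assert_consumed()
assert restored.v.numpy() == 3.0
```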
Checkpoint objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is
identical to the name of the keyword argument for which it was created.
TensorFlow classes like
Layers and Optimizers will automatically add
dependencies on their own variables (e.g. "kernel" and "bias" for
tf.keras.layers.Dense). Inheriting from
tf.keras.Model makes managing
dependencies easy in user-defined classes, since
Model hooks into attribute
assignment. For example:
```python
class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...
```
This Model has a dependency named "input_transform" on its Dense layer,
which in turn depends on its variables. As a result, saving an instance of
Regress using tf.train.Checkpoint will also save all the variables created
by the Dense layer.
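For instance, a brief sketch (assuming Regress.call returns the transformed
tensor, and with an illustrative path): saving through a Checkpoint rooted at
the model also captures the Dense layer's kernel and bias, without naming
them explicitly.

```python
model = Regress()
model(tf.ones([1, 4]))  # Build input_transform's kernel and bias.

checkpoint = tf.train.Checkpoint(model=model)
save_path = checkpoint.save("/tmp/regress_ckpt")
# The checkpoint now holds the variables reachable via the edges
# "model" -> "input_transform" -> {"kernel", "bias"}.
```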
When variables are assigned to multiple workers, each worker writes its own section of the checkpoint. These sections are then merged/re-indexed to behave as a single checkpoint. This avoids copying all variables to one worker, but does require that all workers see a common filesystem.
While tf.keras.Model.save_weights and tf.train.Checkpoint.save save in the
same format, note that the root of the resulting checkpoint is the object the
save method is attached to. This means saving a tf.keras.Model using
save_weights and loading into a
tf.train.Checkpoint with a Model
attached (or vice versa) will not match the
Model's variables. See the
guide to training checkpoints for details.
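A small sketch of the mismatch (model and paths are illustrative),
contrasting the two roots:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.save_weights("/tmp/weights_demo")  # Root of this checkpoint: the Model.

# Loading through a Checkpoint puts the Model one edge below the root
# ("model/..."), so the dependency graphs do not line up.
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore("/tmp/weights_demo")
try:
  status.assert_consumed()
except AssertionError:
  pass  # Expected: the roots differ.

# Matching roots restore cleanly.
model.load_weights("/tmp/weights_demo")
```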
Args
|**kwargs|Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must be trackable objects.|

Raises
|ValueError|If objects in kwargs are not trackable.|
read( save_path, options=None )
Reads a training checkpoint written with write.
Reads this Checkpoint and any objects it depends on.
This method is just like
restore() but does not expect the save_counter
variable in the checkpoint. It only restores the objects that the checkpoint
already depends on.
The method is primarily intended for use by higher level checkpoint
management utilities that use
write() instead of
save() and have their
own mechanisms to number and track checkpoints.
```python
# Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')

# Later, load the checkpoint with read().
# With restore(), assert_consumed() would have failed.
ckpt.read(path).assert_consumed()

# You can also pass options to read(). For example this
# runs the IO ops on the localhost:
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
ckpt.read(path, options=options)
```
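For example, tf.train.CheckpointManager is one such higher-level utility; a
brief sketch of the usual pairing (the directory is illustrative):

```python
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(ckpt, "/tmp/managed_ckpts", max_to_keep=3)

ckpt.restore(manager.latest_checkpoint)  # No-op on the first run (None).
for _ in range(5):
  ckpt.step.assign_add(1)
  manager.save()  # CheckpointManager numbers and prunes checkpoints itself.
```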
Args
|save_path|The path to the checkpoint as returned by write.|
|options|Optional tf.train.CheckpointOptions object.|

Returns
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See restore for details.
restore( save_path, options=None )
Restore a training checkpoint.
Restores this Checkpoint and any objects it depends on.
This method is intended to be used to load checkpoints created by save().
For checkpoints created by
write() use the
read() method instead, which does not expect the
save_counter variable added by save().
restore() either assigns values immediately if variables to restore have
been created already, or defers restoration until the variables are
created. Dependencies added after this call will be matched if they have a
corresponding object in the checkpoint (the restore request will queue in
any trackable object waiting for the expected dependency to be added).
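A brief sketch of deferred restoration (the path is illustrative): the value
is queued until the matching variable is attached.

```python
import tensorflow as tf

src = tf.train.Checkpoint(v=tf.Variable(2.))
path = src.save("/tmp/deferred_demo")

dst = tf.train.Checkpoint()
status = dst.restore(path)   # "v" has no match yet; restoration is queued.
dst.v = tf.Variable(0.)      # Dependency added after restore()...
assert dst.v.numpy() == 2.0  # ...and the queued value was applied.
status.assert_consumed()
```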
To ensure that loading is complete and no more assignments will take place,
use the assert_consumed() method of the status object returned by restore():
```python
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()

# You can additionally pass options to restore():
options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()
```
An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.
Name-based tf.compat.v1.train.Saver checkpoints from TensorFlow 1.x can be
loaded using this method. Names are used to match variables. Re-encode
name-based checkpoints using
tf.train.Checkpoint.save as soon as possible.
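A hedged sketch of that re-encoding (the model and paths are illustrative,
and name matching is best-effort):

```python
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore("/tmp/tf1_name_based.ckpt")  # Matched by variable name.
checkpoint.save("/tmp/object_based/ckpt")       # Re-encoded object-based.
```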
Args
|save_path|The path to the checkpoint, as returned by save or tf.train.latest_checkpoint. If the checkpoint was written by the name-based tf.compat.v1.train.Saver, names are used to match variables.|
|options|Optional tf.train.CheckpointOptions object.|