Groups trackable objects, saving and restoring them.
tf.train.Checkpoint(**kwargs)
Checkpoint's constructor accepts keyword arguments whose values are objects
that contain trackable state, such as tf.keras.Model implementations. It saves
these values with a checkpoint and maintains a save_counter for numbering
checkpoints.
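```python
import os

import tensorflow as tf

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Illustrative stand-ins, not part of the original example: a tiny model,
# an optimizer, toy data and a step count.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
x, y = tf.ones((4, 3)), tf.zeros((4, 1))
num_training_steps = 10

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
latest = tf.train.latest_checkpoint(checkpoint_directory)  # None on the first run
status = checkpoint.restore(latest)
for _ in range(num_training_steps):
    # Variables created here (e.g. on the model's first call) are restored on creation.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
if latest is not None:
    status.assert_consumed()  # Optional sanity check; only meaningful after a real restore.
checkpoint.save(file_prefix=checkpoint_prefix)
```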
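The save_counter mentioned above is what numbers the files that save() writes. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; the temporary directory and the variable named "step" are illustrative choices, not part of the API. Each call to save() increments save_counter and appends it to the prefix.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: a throwaway temporary directory stands in for a real checkpoint directory.
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# A single tf.Variable gives the Checkpoint some trackable state to save.
ckpt = tf.train.Checkpoint(step=tf.Variable(0))

first = ckpt.save(prefix)   # save_counter is incremented to 1 before writing
second = ckpt.save(prefix)  # and incremented to 2 on this second save

print(int(ckpt.save_counter.numpy()))  # 2
print(os.path.basename(first), os.path.basename(second))  # ckpt-1 ckpt-2
```

Because the number comes from a variable stored in the checkpoint itself, restoring a checkpoint also restores the counter, so numbering continues where a previous run left off.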
Checkpoint.save() and Checkpoint.restore() write and read object-based
checkpoints, in contrast to TensorFlow 1.x's variable.name-based checkpoints.
Object-based checkpointing saves a graph of dependencies between Python objects
(Layers, Optimizers, Variables, etc.) with named edges, and this graph is used
to match variables when restoring a checkpoint. Because matching follows these
named edges rather than variable names, it can be more robust to changes in the
Python program, and it helps to support restore-on-create for variables.
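Both properties can be observed in a short sketch, assuming TensorFlow 2.x in eager mode. The edge names "v", "holder" and "w" and the variable names "original_name" and "some_other_name" are hypothetical, chosen only for illustration. The first half restores into a Variable whose variable.name differs from the saved one; the second half calls restore() before the target Variable exists and attaches it afterwards, which is the restore-on-create behavior described above.

```python
import os
import tempfile

import tensorflow as tf

tmp = tempfile.mkdtemp()

# 1. Matching by dependency name, not variable.name.
# The saved Variable is reachable from the root through the edge "v".
saved = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0], name="original_name"))
path = saved.save(os.path.join(tmp, "renamed"))

# The new Variable has a different variable.name, but it sits behind the same
# edge "v", so restore() still finds its value.
renamed = tf.Variable([0.0, 0.0], name="some_other_name")
tf.train.Checkpoint(v=renamed).restore(path)
print(renamed.numpy())  # [1. 2.]

# 2. Restore-on-create.
# Save a nested graph: root --"holder"--> holder --"w"--> Variable.
holder = tf.train.Checkpoint(w=tf.Variable([3.0, 4.0]))
path2 = tf.train.Checkpoint(holder=holder).save(os.path.join(tmp, "nested"))

# Restore while the new holder has no "w" dependency yet; the saved value is
# kept pending until an object is attached under that name.
new_holder = tf.train.Checkpoint()
tf.train.Checkpoint(holder=new_holder).restore(path2)
new_holder.w = tf.Variable([0.0, 0.0])  # attaching it applies the pending restore
print(new_holder.w.numpy())  # [3. 4.]
```

A name-based checkpoint would have looked up "original_name" and failed to find a match for "some_other_name"; here only the path of named edges from the root has to agree.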
Checkpoint objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is
identical to the name of the keyword argument for which it was created.
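One way to see those dependency names is to list the keys stored in a checkpoint file. This is a sketch assuming TensorFlow 2.x; the keyword names "model", "step" and "weights" are hypothetical, and the nested Checkpoint merely stands in for a real model. The listed keys are expected to look like step/.ATTRIBUTES/VARIABLE_VALUE and model/weights/.ATTRIBUTES/VARIABLE_VALUE, next to bookkeeping entries such as save_counter, but that key format is an implementation detail and may vary between TensorFlow versions.

```python
import os
import tempfile

import tensorflow as tf

prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# The keyword names "model" and "step" become the names of the root's
# dependencies; "weights" names an edge one level further down.
inner = tf.train.Checkpoint(weights=tf.Variable([1.0, 2.0]))
ckpt = tf.train.Checkpoint(model=inner, step=tf.Variable(0))
path = ckpt.save(prefix)

# list_variables returns (key, shape) pairs for the values stored in the checkpoint.
keys = [key for key, _shape in tf.train.list_variables(path)]
for key in keys:
    print(key)
```

Each key is the path of keyword-argument names from the root Checkpoint down to the stored value, which is why renaming a keyword argument (unlike renaming a variable) breaks the match on restore.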
TensorFlow classes like