

Manages saving/restoring trackable values to disk.


TensorFlow objects may contain trackable state, such as tf.Variables, tf.keras.optimizers.Optimizer implementations, tf.data.Dataset iterators, tf.keras.Layer implementations, or tf.keras.Model implementations. These are called trackable objects.
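As a rough illustration of trackable state, the sketch below checkpoints a tf.Variable together with an iterator over a tf.data.Dataset, advances both past the saved point, and restores them. It assumes TensorFlow 2.x with eager execution; the names step, iterator, ckpt, and the temporary directory are illustrative choices, not part of the API.

```python
import os
import tempfile

import tensorflow as tf

# Two kinds of trackable state: a variable and a dataset iterator.
step = tf.Variable(0, dtype=tf.int64)
iterator = iter(tf.data.Dataset.range(10))

# The keyword names ("step", "iterator") label these objects in the checkpoint.
ckpt = tf.train.Checkpoint(step=step, iterator=iterator)

# Consume elements 0 and 1, record the step, and save.
next(iterator)
next(iterator)
step.assign(2)
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Advance both objects past the saved state.
next(iterator)
step.assign(100)

# Restoring rewinds both the variable and the iterator position.
ckpt.restore(save_path)
restored_step = int(step.numpy())
restored_next = int(next(iterator).numpy())
print(restored_step)  # 2
print(restored_next)  # 2, the element after the two consumed before saving
```

The iterator position is saved alongside the variable, so training input can resume where it stopped rather than restarting the dataset.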

A Checkpoint object can be constructed to save either a single trackable object or a group of trackable objects to a checkpoint file. It maintains a save_counter for numbering checkpoints.
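To make the numbering concrete, here is a small sketch, assuming TensorFlow 2.x in eager mode, in which a single tf.Variable named v stands in for a model. Each call to save() increments save_counter and appends it to the prefix; the directory and prefix name are arbitrary.

```python
import os
import tempfile

import tensorflow as tf

# A single variable stands in for a model with trackable state.
v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)
prefix = os.path.join(tempfile.mkdtemp(), "training_checkpoints")

first = checkpoint.save(prefix)   # save_counter becomes 1
v.assign(2.0)
second = checkpoint.save(prefix)  # save_counter becomes 2

print(os.path.basename(first))               # training_checkpoints-1
print(os.path.basename(second))              # training_checkpoints-2
print(int(checkpoint.save_counter.numpy()))  # 2

# Restoring the older numbered checkpoint brings back the value it recorded.
checkpoint.restore(first)
print(float(v.numpy()))  # 1.0
```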


Example 1:

import tensorflow as tf

model = tf.keras.Model(...)  # placeholder: construct your model here
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

Example 2:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
# `optimizer`, `model`, and `num_training_steps` are assumed to be defined
# earlier in the program.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.save() and Checkpoint.restore() write and read object-based checkpoints, in contrast to TensorFlow 1.x's tf.compat.v1.train.Saver, which writes and reads variable.name-based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program and helps support restore-on-create for variables.
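Restore-on-create can be seen in a minimal sketch, assuming TensorFlow 2.x in eager mode: restore into a Checkpoint that does not yet have a dependency named "v", then attach a freshly created variable under that name. The names saver, restorer, and v are illustrative only.

```python
import os
import tempfile

import tensorflow as tf

# Save a checkpoint containing one variable under the dependency name "v".
saver = tf.train.Checkpoint(v=tf.Variable(3.0))
path = saver.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Restore into a Checkpoint that has no "v" dependency yet; the saved value
# is held back until a matching dependency is added.
restorer = tf.train.Checkpoint()
status = restorer.restore(path)

# Creating and attaching the variable afterwards triggers the deferred
# restore, so the initial 0.0 is overwritten with the saved 3.0.
restorer.v = tf.Variable(0.0)
print(float(restorer.v.numpy()))  # 3.0
```

Matching happens by the dependency name "v" in the object graph, not by the variable's own name, which is what lets the variable be created after restore() was called.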

Checkpoint objects have dependencies on the objects passed as keyword arguments to their constructors, and each dependency is given a name that is identical to the name of the keyword argument for which it was created. TensorFlow classes like Layers and Optimizers will automatically add dependencies on their own variables (e.g. "kernel" and "bias" for tf.keras.layers.Dense). Inheriting from tf.keras.Model makes managing dependencies easy in user-defined classes, since Model hooks into attribute assignment. For example:

import tensorflow as tf

class Regress(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...
    return x

This Model has a dependency named "input_transform" on its Dense layer, which in turn depends on its variables. As a result, saving an instance of Regress using tf.train.Checkpoint will also save all the variables created by the Dense layer.
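The dependency names can be inspected directly. The sketch below is an assumption-laden stand-in: it uses tf.Module instead of tf.keras.Model (tf.Module also tracks attribute assignment, and avoids differences between Keras versions), with a hypothetical HypotheticalDense module in place of tf.keras.layers.Dense. It saves through tf.train.Checkpoint(model=...) and lists the stored keys with tf.train.list_variables.

```python
import os
import tempfile

import tensorflow as tf


class HypotheticalDense(tf.Module):
    """Minimal stand-in for a Dense layer (illustrative only)."""

    def __init__(self, in_features, units):
        super().__init__()
        # Attribute assignment registers "kernel" and "bias" as dependencies.
        self.kernel = tf.Variable(tf.zeros([in_features, units]))
        self.bias = tf.Variable(tf.zeros([units]))

    def __call__(self, x):
        return x @ self.kernel + self.bias


class Regress(tf.Module):
    def __init__(self):
        super().__init__()
        # Registered as the dependency named "input_transform".
        self.input_transform = HypotheticalDense(in_features=3, units=10)

    def __call__(self, inputs):
        return self.input_transform(inputs)


# The keyword "model" becomes the first path segment of every key beneath it.
checkpoint = tf.train.Checkpoint(model=Regress())
path = checkpoint.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

keys = sorted(name for name, _shape in tf.train.list_variables(path))
model_keys = [k for k in keys if k.startswith("model/")]
for key in model_keys:
    print(key)
```

The printed keys trace the named edges: keyword argument, then attribute names, down to each variable.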

When variables are assigned to multiple workers, each worker writes its own section of the checkpoint. These sections are then merged/re-indexed to behave as a single checkpoint. This avoids copying all variables to one worker, but does require that all workers see a common filesystem.

Checkpoint.save() differs slightly from the Keras Model.save_weights method: tf.keras.Model.save_weights creates a checkpoint file with the name specified in filepath, while tf.train.Checkpoint numbers the checkpoints, using filepath as the prefix for the checkpoint file names. Aside from this, model.save_weights() and tf.train.Checkpoint(model).save() are equivalent.
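The naming difference can be sketched without Keras, assuming TensorFlow 2.x: here tf.train.Checkpoint.write is swapped in as the unnumbered counterpart (it is not save_weights itself) to contrast a numbered save() path with a path that uses the prefix exactly as given. The prefix "weights" is arbitrary.

```python
import os
import tempfile

import tensorflow as tf

checkpoint = tf.train.Checkpoint(v=tf.Variable(1.0))
prefix = os.path.join(tempfile.mkdtemp(), "weights")

numbered = checkpoint.save(prefix)     # appends "-<save_counter>" to the prefix
unnumbered = checkpoint.write(prefix)  # writes to the prefix as given

print(os.path.basename(numbered))  # weights-1
print(os.path.basename(unnumbered))
```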

See the guide to training checkpoints for details.

Args:
  root: The root object to checkpoint.
  **kwargs: Keyword arguments are set as attributes of this object and are saved with the checkpoint. Values must be trackable objects.

Raises:
  ValueError: If root or the objects in kwargs are not trackable. A ValueError is also raised if the root object tracks different objects from the ones listed in attributes in kwargs (e.g. root.child = A and tf.train.Checkpoint(root, child=B) are incompatible).
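A quick way to see the first ValueError case, assuming TensorFlow 2.x: passing a plain Python int as a keyword argument fails because it carries no trackable state, while wrapping the same value in a tf.Variable succeeds. The flag raised is only there to record the outcome.

```python
import tensorflow as tf

raised = False
try:
    # A plain Python int is not a trackable object.
    tf.train.Checkpoint(step=3)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Wrapping the value in a tf.Variable makes it trackable, so this succeeds.
checkpoint = tf.train.Checkpoint(step=tf.Variable(3))
```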

Attributes:
  save_counter: Incremented when