Migrating model checkpoints

Overview

This guide assumes that you have a model that saves and loads checkpoints with tf.compat.v1.Saver, and want to migrate the code to use the TF2 tf.train.Checkpoint API, or to use pre-existing TF1 checkpoints in your TF2 model.

Below are some common scenarios that you may encounter:

Scenario 1

There are existing TF1 checkpoints from previous training runs that need to be loaded or converted to TF2.

Scenario 2

You are adjusting your model in a way that risks changing variable names and paths (such as when incrementally migrating away from get_variable to explicit tf.Variable creation), and would like to maintain saving/loading of existing checkpoints along the way.

See the section on How to maintain checkpoint compatibility during model migration below.

Scenario 3

You are migrating your training code and checkpoints to TF2, but your inference pipeline continues to require TF1 checkpoints for now (for production stability).

Option 1

Save both TF1 and TF2 checkpoints when training.

Option 2

Convert the TF2 checkpoint to TF1.


The examples below show all the combinations of saving and loading checkpoints in TF1/TF2, so you have some flexibility in determining how to migrate your model.

Setup

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")

Changes from TF1 to TF2

Read this section if you are curious about what has changed between TF1 and TF2, and what is meant by "name-based" (TF1) versus "object-based" (TF2) checkpoints.

The two types of checkpoints are actually saved in the same format, which is essentially a key-value table. The difference lies in how the keys are generated.

The keys in name-based checkpoints are the names of the variables. The keys in object-based checkpoints refer to the path from the root object to the variable (the examples below will help to get a better sense of what this means).

First, save some checkpoints:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)

If you look at the keys in tf2-ckpt, they all refer to the object paths of each variable. For example, variable a is the first element in the variables list, so its key becomes variables/0/... (feel free to ignore the .ATTRIBUTES/VARIABLE_VALUE constant).

A closer inspection of the Checkpoint object below:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])

Try experimenting with the snippet below to see how the checkpoint keys change with the object structure:

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)

Why does TF2 use this mechanism?

Because there is no more global graph in TF2, variable names are unreliable and can be inconsistent between programs. TF2 encourages the object-oriented modelling approach where variables are owned by layers, and layers are owned by a model:

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer
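
As a concrete illustration, here is a minimal sketch of that idea (the attribute names `layer` and `kernel`, the variable's `name` argument, and the file prefix 'demo-obj-ckpt' are made up for this example). The checkpoint key comes from the chain of attributes leading to the variable, not from the variable's name:

model = tf.Module()
model.layer = tf.Module()
model.layer.kernel = tf.Variable(1.0, name='totally_unrelated_name')

ckpt = tf.train.Checkpoint(model=model)
path = ckpt.save('demo-obj-ckpt')

# Expect a key like 'model/layer/kernel/.ATTRIBUTES/VARIABLE_VALUE',
# regardless of the variable's `name` argument.
for key, shape in tf.train.list_variables(path):
  print(key, shape)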

How to maintain checkpoint compatibility during model migration

One important step in the migration process is ensuring that all variables are initialized to the correct values, which in turn allows you to validate that the ops/functions are doing the correct computations. To accomplish this, you must consider the checkpoint compatibility between models in the various stages of migration. Essentially, this section answers the question: how do I keep using the same checkpoint while changing the model?

Below are three ways of maintaining checkpoint compatibility, in order of increasing flexibility:

  1. The model has the same variable names as before.
  2. The model has different variable names, and maintains an assignment map that maps variable names in the checkpoint to the new names.
  3. The model has different variable names, and maintains a TF2 Checkpoint object that stores all of the variables.

When the variable names match

Long title: How to re-use checkpoints when the variable names match.

Short answer: You can directly load the pre-existing checkpoint with either tf1.train.Saver or tf.train.Checkpoint.


If you are using tf.compat.v1.keras.utils.track_tf1_style_variables, then it will ensure that your model variable names are the same as before. You can also manually ensure that variable names match.

When the variable names match in the migrated models, you may directly use either tf.train.Checkpoint or tf.compat.v1.train.Saver to load the checkpoint. Both APIs are compatible with eager and graph mode, so you can use them at any stage of the migration.

Below are examples of using the same checkpoint with different models. First, save a TF1 checkpoint with tf1.train.Saver:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)

The example below uses tf.compat.v1.Saver to load the checkpoint while in eager mode:

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

The next snippet loads the checkpoint using the TF2 API tf.train.Checkpoint:

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")

Variable names in TF2

  • Variables still all have a name argument you can set.
  • Keras models also take a name argument, which they set as the prefix for their variables.
  • The v1.name_scope function can be used to set variable name prefixes. This is very different from tf.variable_scope: it only affects names, and does not track variables or handle reuse.

The tf.compat.v1.keras.utils.track_tf1_style_variables decorator is a shim that helps you maintain variable names and TF1 checkpoint compatibility, by keeping the naming and reuse semantics of tf.variable_scope and tf.compat.v1.get_variable unchanged. See the Model mapping guide for more info.

Note 1: If you are using the shim, use TF2 APIs to load your checkpoints (even when using pre-trained TF1 checkpoints).

See the Checkpointing Keras section below.

Note 2: When migrating to tf.Variable from get_variable:

If your shim-decorated layer or module consists of some variables (or Keras layers/models) that use tf.Variable instead of tf.compat.v1.get_variable and are attached as properties or tracked in an object-oriented way, they may have different variable naming semantics in TF1.x graphs/sessions versus during eager execution.

In short, the names may not be what you expect them to be when running in TF2.
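
For reference, below is a minimal sketch of the shim described above (the module, scope, and variable names are invented for this example): a tf.Module method decorated with track_tf1_style_variables keeps the TF1-style variable name 'dense/kernel', so a name-based checkpoint containing that name could be restored into it.

class ShimmedModule(tf.Module):
  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, x):
    # Under the decorator, `get_variable` keeps its TF1 naming/reuse semantics.
    with tf1.variable_scope('dense'):
      kernel = tf1.get_variable('kernel', shape=[1, 1],
                                initializer=tf1.ones_initializer())
    return tf.matmul(x, kernel)

m = ShimmedModule()
m(tf.ones([1, 1]))  # The first call creates the variable.
print([v.name for v in m.variables])  # e.g. ['dense/kernel:0']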

Maintaining assignment maps

Assignment maps are commonly used to transfer weights between TF1 models, and can also be used during your model migration if the variable names change.

You can use these maps with tf.compat.v1.train.init_from_checkpoint, tf.compat.v1.train.Saver, and tf.train.load_checkpoint to load weights into models in which the variable or scope names may have changed.

The examples in this section will use a previously saved checkpoint:

print_checkpoint('tf1-ckpt')

Loading with init_from_checkpoint

tf1.train.init_from_checkpoint must be called while in a Graph/Session, because it places the values in the variable initializers instead of creating an assign op.

You can use the assignment_map argument to configure how the variables are loaded. From the documentation:

Assignment map supports following syntax:

  • 'checkpoint_scope_name/': 'scope_name/' - will load all variables in current scope_name from checkpoint_scope_name with matching tensor names.
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - will initialize scope_name/variable_name variable from checkpoint_scope_name/some_other_variable.
  • 'scope_variable_name': variable - will initialize given tf.Variable object with tensor 'scope_variable_name' from the checkpoint.
  • 'scope_variable_name': list(variable) - will initialize list of partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  • '/': 'scope_name/' - will load all variables in current scope_name from checkpoint's root (e.g. no scope).
# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

Loading with tf1.train.Saver

Unlike init_from_checkpoint, tf.compat.v1.train.Saver runs in both graph and eager mode. The var_list argument optionally accepts a dictionary, which must map checkpoint variable names to tf.Variable objects.

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

Loading with tf.train.load_checkpoint

This option is for you if you need precise control over the variable values. Again, this works in both graph and eager modes.

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier to write a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

Maintaining a TF2 Checkpoint object

If the variable and scope names may change a lot during the migration, then use tf.train.Checkpoint and TF2 checkpoints. TF2 uses the object structure instead of variable names (more details in Changes from TF1 to TF2).

In short, when creating a tf.train.Checkpoint to save or restore checkpoints, make sure it uses the same ordering (for lists) and keys (for dictionaries and keyword arguments to the Checkpoint initializer). Some examples of checkpoint compatibility:

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

The code samples below show how to use the "same" tf.train.Checkpoint to load variables with different names. First, save a TF2 checkpoint:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)

You can keep using tf.train.Checkpoint even if the variable/scope names change:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

And in eager mode:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

TF2 checkpoints in Estimator

The sections above describe how to maintain checkpoint compatibility while migrating your model. These concepts also apply to Estimator models, although the way the checkpoint is saved/loaded is slightly different. As you migrate your Estimator model to use TF2 APIs, you may want to switch from TF1 to TF2 checkpoints while the model is still using the Estimator. This section shows how to do so.

tf.estimator.Estimator and MonitoredSession have a saving mechanism called the scaffold, a tf.compat.v1.train.Scaffold object. The Scaffold can contain a tf1.train.Saver or tf.train.Checkpoint, which enables Estimator and MonitoredSession to save TF1- or TF2-style checkpoints.

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4

The final value of v should be 4: it is warm-started from est-tf1 (where it reached 2), then trained for one additional step, which adds another 2. The train step value doesn't carry over from the warm_start checkpoint.

Checkpointing Keras

Models built with Keras can still use tf1.train.Saver and tf.train.Checkpoint to load pre-existing weights. Once your model is fully migrated, switch to model.save_weights and model.load_weights, especially if you are using the ModelCheckpoint callback when training.
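
For instance, here is a minimal sketch of the fully migrated flow (the model architecture, the training data, and the file prefix 'keras-ckpt/weights' are placeholders for this example):

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

# Save TF2 (object-based) weight checkpoints while training...
callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='keras-ckpt/weights', save_weights_only=True)
model.fit(tf.ones([8, 4]), tf.ones([8, 1]), epochs=1, callbacks=[callback])

# ...and load them back with `load_weights`.
model.load_weights('keras-ckpt/weights')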

Some things you should know about checkpoints and Keras:

Initialization vs Building

Keras models and layers must go through two steps before being fully created. First is the initialization of the Python object: layer = tf.keras.layers.Dense(units). Second is the build step, in which most of the weights are actually created: layer.build(input_shape). You can also build a model by calling it or running a single train, eval, or predict step (the first time only).
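
To make the two steps concrete, here is a minimal sketch (the layer type and shapes are arbitrary for this example):

layer = tf.keras.layers.Dense(units=4)  # 1. Initialize the Python object.
print(layer.weights)                    # [] -- no variables have been created yet.

layer.build(input_shape=(None, 8))      # 2. Build: the kernel and bias are created.
print([w.name for w in layer.weights])  # e.g. ['dense/kernel:0', 'dense/bias:0']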

If you find that model.load_weights(path).assert_consumed() is raising an error, then it is likely that the model/layers have not been built.

Keras uses TF2 checkpoints

tf.train.Checkpoint(model).write is equivalent to model.save_weights. Same with tf.train.Checkpoint(model).read and model.load_weights. Note that Checkpoint(model) != Checkpoint(model=model).
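
A minimal sketch of this equivalence (the model and the file prefixes are placeholders; the key prefixes noted in the comments are what you would expect to see, since the keyword attachment adds a 'model/' level to the object path while the positional "root" attachment does not):

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 1))

model.save_weights('keras-weights')                # TF2 object-based checkpoint.
tf.train.Checkpoint(model).write('ckpt-root')      # Same keys as the line above.
tf.train.Checkpoint(model=model).write('ckpt-kw')  # Keys gain a 'model/' prefix.

print([key for key, _ in tf.train.list_variables('ckpt-kw')])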

TF2 checkpoints work with Keras's build() step

tf.train.Checkpoint.restore has a mechanism called deferred restoration, which allows tf.Module and Keras objects to store variable values if the variable has not yet been created. This allows initialized models to load weights first and build afterwards.

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

Because of this mechanism, we highly recommend that you use TF2 checkpoint loading APIs with Keras models (even when restoring pre-existing TF1 checkpoints into the model mapping shims). See more in the checkpoint guide.

Code Snippets

The snippets below show the TF1/TF2 version compatibility in the checkpoint saving APIs.

Save a TF1 checkpoint in TF2

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

Load a TF1 checkpoint in TF2

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

Save a TF2 checkpoint in TF1

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)

Load a TF2 checkpoint in TF1

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

Checkpoint conversion

You can convert checkpoints between TF1 and TF2 by loading and re-saving the checkpoints. An alternative is tf.train.load_checkpoint, shown in the code below.

Convert TF1 checkpoint to TF2

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})
  ckpt.restore(converted_ckpt_path)
  ```

  Args:
    checkpoint_path: Path to the TF1 checkpoint.
    output_prefix: Path prefix to the converted checkpoint.

  Returns:
    Path to the converted checkpoint.
  """
  vars = {}
  reader = tf.train.load_checkpoint(checkpoint_path)
  dtypes = reader.get_variable_to_dtype_map()
  for key in dtypes.keys():
    vars[key] = tf.Variable(reader.get_tensor(key))
  return tf.train.Checkpoint(vars=vars).save(output_prefix)

Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2` (make sure you have run that snippet first):

print_checkpoint('tf1-ckpt-saved-in-eager')
converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 
                                    'converted-tf1-to-tf2')
print("\n[Converted]")
print_checkpoint(converted_path)

Try loading the converted checkpoint.

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})
ckpt.restore(converted_path).assert_consumed()
print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

Convert TF2 checkpoint to TF1

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must have been saved with
  `tf.train.Checkpoint(var_list={name: variable})`.

  To load the converted checkpoint with `tf.compat.v1.train.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable})

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables])
  saver.restore(sess, output_path)
  ```
  """
  vars = {}
  reader = tf.train.load_checkpoint(checkpoint_path)
  dtypes = reader.get_variable_to_dtype_map()
  for key in dtypes.keys():
    # Get the "name" from the 'var_list/name/.ATTRIBUTES/VARIABLE_VALUE' key.
    if key.startswith('var_list/'):
      var_name = key.split('/')[1]
      # TF2 checkpoint keys use '/', so if '/' appears in the user-defined name,
      # it is escaped to '.S'.
      var_name = var_name.replace('.S', '/')
      vars[var_name] = tf.Variable(reader.get_tensor(key))

  return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)

Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1` (make sure you have run that snippet first):

print_checkpoint('tf2-ckpt-saved-in-session-1')
converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',
                                    'converted-tf2-to-tf1')
print("\n[Converted]")
print_checkpoint(converted_path)

Try loading the converted checkpoint.

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    saver = tf1.train.Saver([a, b, c])
    saver.restore(sess, converted_path)
    print("\nRestored [a, b, c]: ", sess.run([a, b, c]))