Use TF1.x models in TF2 workflows


This guide provides an overview and examples of a modeling code shim that you can employ to use your existing TF1.x models in TF2 workflows such as eager execution, tf.function, and distribution strategies with minimal changes to your modeling code.

Scope of usage

The shim described in this guide is designed for TF1.x models that rely on:

  1. tf.compat.v1.get_variable and tf.compat.v1.variable_scope to control variable creation and reuse, and
  2. Graph-collection based APIs such as tf.compat.v1.global_variables(), tf.compat.v1.trainable_variables(), tf.compat.v1.losses.get_regularization_losses(), and tf.compat.v1.get_collection() to keep track of weights and regularization losses.

This includes most models built on top of tf.compat.v1.layers, the tf.contrib.layers APIs, and TensorFlow-Slim.

The shim is NOT necessary for the following TF1.x models:

  1. Stand-alone Keras models that already track all of their trainable weights and regularization losses via model.trainable_weights and model.losses respectively.
  2. tf.Modules that already track all of their trainable weights via module.trainable_variables, and only create weights if they have not already been created.

These models are likely to work in TF2 with eager execution and tf.function out-of-the-box.

Setup

Import TensorFlow and other dependencies.

pip uninstall -y -q tensorflow
# Install tf-nightly, because DeterministicRandomTestTool is only available in
# TensorFlow 2.8 or later

pip install -q tf-nightly

import tensorflow as tf
import tensorflow.compat.v1 as v1
import sys
import numpy as np

from contextlib import contextmanager

The track_tf1_style_variables decorator

The key shim described in this guide is tf.compat.v1.keras.utils.track_tf1_style_variables, a decorator that you can apply to methods of tf.keras.layers.Layer and tf.Module to track TF1.x-style weights and capture regularization losses.

Decorating a tf.keras.layers.Layer's or tf.Module's call methods with tf.compat.v1.keras.utils.track_tf1_style_variables allows variable creation and reuse via tf.compat.v1.get_variable (and by extension tf.compat.v1.layers) to work correctly inside of the decorated method rather than always creating a new variable on each call. It will also cause the layer or module to implicitly track any weights created or accessed via get_variable inside the decorated method.

In addition to tracking the weights themselves under the standard layer.variables/module.variables/etc. properties, if the method belongs to a tf.keras.layers.Layer, then any regularization losses specified via the get_variable or tf.compat.v1.layers regularizer arguments will be tracked by the layer under the standard layer.losses property.

This tracking mechanism enables using large classes of TF1.x-style model-forward-pass code inside of Keras layers or tf.Modules in TF2 even with TF2 behaviors enabled.
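To make the reuse semantics concrete, the sketch below models them in plain Python. ToyVariableStore, dense_forward, and every name in them are hypothetical illustrations, not TensorFlow APIs; the sketch assumes only the behavior described above, namely that get_variable looks a variable up by its full variable-scope path and hands back the existing object on later calls instead of creating a new one.

```python
from contextlib import contextmanager


class ToyVariableStore:
    """Hypothetical stand-in for the per-layer store behind the shim.

    Variables are keyed by their full variable-scope path. Real TF names
    would also carry a ":0" suffix, which is omitted here.
    """

    def __init__(self):
        self.variables = {}  # full scoped name -> list of floats
        self._scopes = []    # stack of active variable_scope names

    @contextmanager
    def variable_scope(self, name):
        self._scopes.append(name)
        try:
            yield
        finally:
            self._scopes.pop()

    def get_variable(self, name, shape, initializer):
        full_name = "/".join(self._scopes + [name])
        if full_name not in self.variables:
            # First call: create the variable.
            self.variables[full_name] = [initializer() for _ in range(shape)]
        # Every call: return the stored object, so later calls reuse it.
        return self.variables[full_name]


def dense_forward(store, units):
    # Mirrors the dense example: two get_variable calls inside a "dense"
    # variable_scope (the shape is simplified to a single length).
    with store.variable_scope("dense"):
        kernel = store.get_variable("kernel", units, initializer=lambda: 1.0)
        bias = store.get_variable("bias", units, initializer=lambda: 0.0)
    return kernel, bias


store = ToyVariableStore()
kernel_1, bias_1 = dense_forward(store, 3)
kernel_2, bias_2 = dense_forward(store, 3)
print(kernel_1 is kernel_2)     # True: the second call reused the variable
print(sorted(store.variables))  # ['dense/bias', 'dense/kernel']
```

Keying on the scoped name, rather than on the Python call site, is what lets the same forward-pass code run many times without creating duplicate weights.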

Usage examples

The usage examples below demonstrate the modeling shims used to decorate tf.keras.layers.Layer methods, but, except where they specifically interact with Keras features, they also apply when decorating tf.Module methods.

Layer built with tf.compat.v1.get_variable

Imagine you have a layer implemented directly on top of tf.compat.v1.get_variable as follows:

def dense(inputs, units):
  out = inputs
  with tf.compat.v1.variable_scope("dense"):
    # The weights are created with a `regularizer`,
    # so their regularization losses are added to a graph collection
    kernel = tf.compat.v1.get_variable(
        shape=[out.shape[-1], units],
        regularizer=tf.keras.regularizers.L2(),
        initializer=tf.compat.v1.initializers.glorot_normal,
        name="kernel")
    bias = tf.compat.v1.get_variable(
        shape=[units,],
        initializer=tf.compat.v1.initializers.zeros,
        name="bias")
    out = tf.linalg.matmul(out, kernel)
    out = tf.compat.v1.nn.bias_add(out, bias)
  return out

Use the shim to turn it into a layer and call it on inputs.

class DenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

layer = DenseLayer(10)
x = tf.random.normal(shape=(8, 20))
layer(x)

Access the tracked variables and the captured regularization losses like a standard Keras layer.

layer.trainable_variables
layer.losses

To see that the weights get reused each time you call the layer, set all the weights to zero and call the layer again.

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

# Note: layer.losses is not a live view and
# will get reset only at each layer call
print("layer.losses:", layer.losses)
print("calling layer again.")
out = layer(x)
print("layer.losses: ", layer.losses)
out

You can use the converted layer directly in Keras functional model construction as well.

inputs = tf.keras.Input(shape=(20,))
outputs = DenseLayer(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 20))
model(x)

# Access the model variables and regularization losses
model.weights
model.losses

Model built with tf.compat.v1.layers

Imagine you have a layer or model implemented directly on top of tf.compat.v1.layers as follows:

def model(inputs, units):
  with tf.compat.v1.variable_scope('model'):
    out = tf.compat.v1.layers.conv2d(
        inputs, 3, 3,
        kernel_regularizer="l2")
    out = tf.compat.v1.layers.flatten(out)
    out = tf.compat.v1.layers.dense(
        out, units,
        kernel_regularizer="l2")
    return out

Use the shim to turn it into a layer and call it on inputs.

class CompatV1LayerModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

layer = CompatV1LayerModel(10)
x = tf.random.normal(shape=(8, 5, 5, 5))
layer(x)

Access the tracked variables and captured regularization losses like a standard Keras layer.

layer.trainable_variables
layer.losses

To see that the weights get reused each time you call the layer, set all the weights to zero and call the layer again.

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

out = layer(x)
print("layer.losses: ", layer.losses)
out

You can use the converted layer directly in Keras functional model construction as well.

inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1LayerModel(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 5, 5, 5))
model(x)
# Access the model variables and regularization losses
model.weights
model.losses

Capture batch normalization updates and model training args

In TF1.x, you perform batch normalization like this:

  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
  train_op = optimizer.minimize(loss)
  train_op = tf.group([train_op, update_ops])

Note that:

  1. The batch normalization moving-average updates are tracked in the UPDATE_OPS collection, which is read by a get_collection call made separately from the layer.
  2. tf.compat.v1.layers.batch_normalization requires a training argument (generally called is_training when using TF-Slim batch normalization layers).

In TF2, due to eager execution and automatic control dependencies, the batch normalization moving average updates will be executed right away. There is no need to separately collect them from the updates collection and add them as explicit control dependencies.

Additionally, if you give your tf.keras.layers.Layer's forward pass method a training argument, Keras will pass the current training phase to it, and to any nested layers, just as it does for any other layer. See the API docs for tf.keras.Model for more information on how Keras handles the training argument.

If you are decorating tf.Module methods, you need to make sure to manually pass all training arguments as needed. However, the batch normalization moving average updates will still be applied automatically with no need for explicit control dependencies.
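The moving-average update itself can be sketched in plain Python. This is a hypothetical toy (ToyBatchNormMean is not a TensorFlow class) that tracks only the moving mean and assumes momentum=0.99, the default for tf.compat.v1.layers.batch_normalization. It shows the update being applied immediately during a training-mode call, the way TF2 eager execution applies it, with nothing gathered into a collection for later.

```python
class ToyBatchNormMean:
    """Hypothetical sketch: tracks only the moving mean (variance omitted)."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.moving_mean = 0.0  # non-trainable statistic, like moving_mean

    def __call__(self, values, training):
        if training:
            batch_mean = sum(values) / len(values)
            # The update runs right here, as part of the call, instead of
            # being collected as an update op to run later.
            self.moving_mean = (self.momentum * self.moving_mean
                                + (1.0 - self.momentum) * batch_mean)
            center = batch_mean
        else:
            # Inference uses the stored statistic and leaves it unchanged.
            center = self.moving_mean
        return [v - center for v in values]


bn = ToyBatchNormMean()
bn([1.0, 2.0, 3.0], training=False)
before = bn.moving_mean  # still 0.0: inference does not update statistics
bn([1.0, 2.0, 3.0], training=True)
after = bn.moving_mean   # about 0.02, i.e. 0.01 * batch mean of 2.0
```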

The following code snippets demonstrate how to embed batch normalization layers in the shim and how using it in a Keras model works (applicable to tf.keras.layers.Layer).

class CompatV1BatchNorm(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    print("Forward pass called with `training` =", training)
    with v1.variable_scope('batch_norm_layer'):
      # Normalize the layer's own inputs (not a global `x`).
      return v1.layers.batch_normalization(inputs, training=training)

print("Constructing model")
inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1BatchNorm()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

print("Calling model in inference mode")
x = tf.random.normal(shape=(8, 5, 5, 5))
model(x, training=False)

print("Moving average variables before training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})

# Notice that when running TF2 and eager execution, the batchnorm layer directly
# updates the moving averages while training without needing any extra control
# dependencies
print("calling model in training mode")
model(x, training=True)

print("Moving average variables after training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})

Variable-scope based variable reuse

Any variable creations in the forward pass based on get_variable will maintain the same variable naming and reuse semantics that variable scopes have in TF1.x. This holds as long as any tf.compat.v1.layers with auto-generated names are nested inside at least one non-empty outer variable scope.

Eager execution & tf.function

As seen above, decorated methods for tf.keras.layers.Layer and tf.Module run inside of eager execution and are also compatible with tf.function. This means you can use pdb and other interactive tools to step through your forward pass as it is running.

Distribution strategies

Calls to get_variable inside of @track_tf1_style_variables-decorated layer or module methods create standard tf.Variable objects under the hood. This means you can use them with the various distribution strategies available with tf.distribute, such as MirroredStrategy and TPUStrategy.

Nesting tf.Variables, tf.Modules, tf.keras.layers & tf.keras.models in decorated calls

Decorating your layer call with tf.compat.v1.keras.utils.track_tf1_style_variables only adds automatic implicit tracking of variables created (and reused) via tf.compat.v1.get_variable. It does not capture weights created directly by tf.Variable calls, such as those used by typical Keras layers and most tf.Modules. This section describes how to handle these nested cases.

(Pre-existing usages) tf.keras.layers and tf.keras.models

For pre-existing usages of nested Keras layers and models, use tf.compat.v1.keras.utils.get_or_create_layer. This is only recommended for easing migration of existing TF1.x nested Keras usages; new code should use explicit attribute setting as described below for tf.Variables and tf.Modules.

To use tf.compat.v1.keras.utils.get_or_create_layer, wrap the code that constructs your nested model in a method, and pass that method to get_or_create_layer. Example:

class NestedModel(tf.keras.Model):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    # Each example is a length-5 vector, so a (5, 5) input is a batch of 5
    inp = tf.keras.Input(shape=(5,))
    dense_layer = tf.keras.layers.Dense(
        self.units, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Get or create a nested model without assigning it as an explicit property
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

layer = NestedModel(10)
layer(tf.ones(shape=(5,5)))

This method ensures that these nested layers are correctly reused and tracked by TensorFlow. Note that the @track_tf1_style_variables decorator is still required on the appropriate method. The model builder method passed into get_or_create_layer (in this case, self.build_model) should take no arguments.
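The get-or-create behavior can be pictured as a cache keyed by name that calls the zero-argument builder only once. The sketch below is a hypothetical plain-Python model of that behavior (ToyLayerCache and build_model here are illustrative stand-ins), not the implementation of get_or_create_layer.

```python
class ToyLayerCache:
    """Hypothetical sketch of get-or-create-by-name semantics."""

    def __init__(self):
        self._layers = {}    # name -> object built by the builder
        self.build_calls = 0

    def get_or_create_layer(self, name, create_layer_method):
        if name not in self._layers:
            # The builder takes no arguments and runs only the first time.
            self.build_calls += 1
            self._layers[name] = create_layer_method()
        return self._layers[name]


def build_model():
    # Placeholder for a builder such as self.build_model in the example.
    return {"kind": "dense_model"}


cache = ToyLayerCache()
first = cache.get_or_create_layer("dense_model", build_model)
second = cache.get_or_create_layer("dense_model", build_model)
```

Because the builder is only a callable, it cannot receive per-call inputs, which is one way to see why it should take no arguments.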

Weights are tracked:

assert len(layer.weights) == 2
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"dense/bias:0", "dense/kernel:0"}

layer.weights

And regularization loss as well:

tf.add_n(layer.losses)

Incremental migration: tf.Variables and tf.Modules

If you need to embed tf.Variable calls or tf.Modules in your decorated methods (for example, if you are following the incremental migration to non-legacy TF2 APIs described later in this guide), you still need to explicitly track these, with the following requirements:

  • Explicitly make sure that the variable/module/layer is only created once
  • Explicitly attach them as instance attributes just as you would when defining a typical module or layer
  • Explicitly reuse the already-created object in follow-on calls

This ensures that weights are not re-created on each call and are correctly reused. It also ensures that existing weights and regularization losses get tracked.

Here is an example of how this could look:

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("inner_dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

class WrappedDenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    # Only create the nested tf.Variable/module/layer/model
    # once, and then reuse it each time!
    self._dense_layer = NestedLayer(self.units)

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('outer'):
      outputs = tf.compat.v1.layers.dense(inputs, 3)
      # Chain the second dense layer on the first one's output
      outputs = tf.compat.v1.layers.dense(outputs, 4)
      return self._dense_layer(outputs)

layer = WrappedDenseLayer(10)

layer(tf.ones(shape=(5, 5)))

Note that explicit tracking of the nested module is needed even though it is decorated with the track_tf1_style_variables decorator. This is because each module/layer with decorated methods has its own variable store associated with it.

The weights are correctly tracked:

assert len(layer.weights) == 6
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"outer/inner_dense/bias:0",
                               "outer/inner_dense/kernel:0",
                               "outer/dense/bias:0",
                               "outer/dense/kernel:0",
                               "outer/dense_1/bias:0",
                               "outer/dense_1/kernel:0"}

layer.trainable_weights

As well as regularization loss:

layer.losses

Note that if the NestedLayer were a non-Keras tf.Module instead, variables would still be tracked but regularization losses would not be automatically tracked, so you would have to explicitly track them separately.

Guidance on variable names

Explicit tf.Variable calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of get_variable and variable_scopes. Although the shim will make your variable names match for variables created by get_variable even when going from TF1.x graphs to TF2 eager execution & tf.function, it cannot guarantee the same for the variable names generated for tf.Variable calls and Keras layers that you embed within your decorated methods. It is even possible for multiple variables to share the same name in TF2 eager execution and tf.function.

You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide.

Using tf.compat.v1.make_template in the decorated method

It is highly recommended that you use tf.compat.v1.keras.utils.track_tf1_style_variables directly instead of tf.compat.v1.make_template, because it is a thinner layer on top of TF2.

Follow the guidance in this section for prior TF1.x code that was already relying on tf.compat.v1.make_template.

Because tf.compat.v1.make_template wraps code that uses get_variable, the track_tf1_style_variables decorator allows you to use these templates in layer calls and successfully track the weights and regularization losses.

However, do make sure to call make_template only once and then reuse the same template in each layer call. Otherwise, a new template will be created each time you call the layer along with a new set of variables.
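The difference between reusing one template and creating a new one per call can be sketched in plain Python. ToyTemplate is a hypothetical stand-in that assumes the behavior described above (a template creates its variables on its first call and reuses them afterwards); it is not the make_template implementation.

```python
class ToyTemplate:
    """Hypothetical stand-in for a make_template result: the wrapped
    function's variable is created on the first call and reused later."""

    variables_created = 0  # counts creations across all ToyTemplate objects

    def __init__(self, fn, initial_value=1.5):
        self._fn = fn
        self._initial_value = initial_value
        self._y = None

    def __call__(self, x):
        if self._y is None:
            ToyTemplate.variables_created += 1
            self._y = self._initial_value
        return self._fn(x, self._y)


def scale_by_y(x, y):
    return x * y


# Correct: build the template once (for example in __init__) and reuse it.
ToyTemplate.variables_created = 0
template = ToyTemplate(scale_by_y)
outputs = [template(2.0) for _ in range(3)]
created_when_reused = ToyTemplate.variables_created  # 1

# Incorrect: a fresh template per call creates a fresh variable every time.
ToyTemplate.variables_created = 0
for _ in range(3):
    ToyTemplate(scale_by_y)(2.0)
created_when_rebuilt = ToyTemplate.variables_created  # 3
```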

For example,

class CompatV1TemplateScaleByY(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    def my_op(x, scalar_name):
      var1 = tf.compat.v1.get_variable(scalar_name,
                            shape=[],
                            regularizer=tf.compat.v1.keras.regularizers.L2(),
                            initializer=tf.compat.v1.constant_initializer(1.5))
      return x * var1
    self.scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('layer'):
      # Using a scope ensures the `scale_by_y` name will not be incremented
      # for each instantiation of the layer.
      return self.scale_by_y(inputs)

layer = CompatV1TemplateScaleByY()

out = layer(tf.ones(shape=(2, 3)))
print("weights:", layer.weights)
print("regularization loss:", layer.losses)
print("output:", out)

Incremental migration to Native TF2

As mentioned earlier, track_tf1_style_variables allows you to mix TF2-style object-oriented tf.Variable/tf.keras.layers.Layer/tf.Module usage with legacy tf.compat.v1.get_variable/tf.compat.v1.layers-style usage inside of the same decorated module/layer.

This means that after you have made your TF1.x model fully-TF2-compatible, you can write all new model components with native (non-tf.compat.v1) TF2 APIs and have them interoperate with your older code.

However, if you continue to modify your older model components, you may also choose to incrementally switch your legacy-style tf.compat.v1 usage over to the purely-native object-oriented APIs that are recommended for newly written TF2 code.

tf.compat.v1.get_variable usage can be replaced with either self.add_weight calls if you are decorating a Keras layer/model, or with tf.Variable calls if you are decorating Keras objects or tf.Modules.

Both functional-style and object-oriented tf.compat.v1.layers can generally be replaced with the equivalent tf.keras.layers layer with no argument changes required.

You may also consider splitting parts of your model, or common patterns, into individual layers/modules during your incremental move to purely-native APIs. These layers/modules may themselves use track_tf1_style_variables.

A note on Slim and contrib.layers

A large amount of older TF1.x code uses the Slim library, which was packaged with TF1.x as tf.contrib.layers. Converting code using Slim to native TF2 is more involved than converting v1.layers. In fact, it may make sense to convert your Slim code to v1.layers first, then convert to Keras. Below is some general guidance for converting Slim code.

  • Ensure all arguments are explicit. Remove arg_scopes if possible. If you still need to use them, split normalizer_fn and activation_fn into their own layers.
  • Separable conv layers map to one or more different Keras layers (depthwise, pointwise, and separable Keras layers).
  • Slim and v1.layers have different argument names and default values.
  • Some arguments that look equivalent use different scales, so compare the underlying formulas before copying values across.
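As one concrete case of differing scales, stated as our reading of the two APIs and worth checking against your TF version: tf.contrib.layers.l2_regularizer(scale) computes scale * tf.nn.l2_loss(w), where tf.nn.l2_loss(w) = sum(w**2) / 2, whereas tf.keras.regularizers.L2(l2) computes l2 * sum(w**2) with no factor of 1/2. Under that assumption, a Slim scale corresponds to a Keras l2 of scale / 2. The plain-Python sketch below checks this arithmetic on made-up weights; the helper names are hypothetical.

```python
import math


def slim_l2_loss(weights, scale):
    # Assumed Slim formula: scale * tf.nn.l2_loss(w), with l2_loss = sum(w**2) / 2.
    return scale * sum(w * w for w in weights) / 2.0


def keras_l2_loss(weights, l2):
    # Assumed Keras L2 formula: l2 * sum(w**2), with no 1/2 factor.
    return l2 * sum(w * w for w in weights)


def keras_l2_from_slim_scale(scale):
    # Halve the Slim scale so both formulas give the same penalty.
    return scale / 2.0


weights = [0.5, -1.0, 2.0]  # made-up example weights
slim_scale = 4e-5
slim_penalty = slim_l2_loss(weights, slim_scale)
keras_penalty = keras_l2_loss(weights, keras_l2_from_slim_scale(slim_scale))
print(math.isclose(slim_penalty, keras_penalty))  # True
```

Copying the Slim value unchanged into a Keras regularizer would double the penalty under these formulas.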

Migration to Native TF2 ignoring checkpoint compatibility

The following code sample demonstrates an incremental move of a model to purely-native APIs without considering checkpoint compatibility.

class CompatModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

Next, replace the compat.v1 APIs with their native object-oriented equivalents in a piecewise manner. Start by switching the convolution layer to a Keras object created in the layer constructor.

class PartiallyMigratedModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

Use the v1.keras.utils.DeterministicRandomTestTool class to verify that this incremental change leaves the model with the same behavior as before.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = CompatModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  original_output = layer(inputs)

  # Grab the regularization loss as well
  original_regularization_loss = tf.math.add_n(layer.losses)

print(original_regularization_loss)
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = PartiallyMigratedModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())
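The reason this comparison works can be sketched with a toy. As we understand the documented mode='num_random_ops' behavior, each random op is seeded with the number of random ops created before it, so two implementations that issue the same random ops in the same order draw the same values. ToyNumRandomOpsTool below is a hypothetical plain-Python analogue of that idea, not the real tool; it also shows that reordering weight creation breaks the match.

```python
import random


class ToyNumRandomOpsTool:
    """Hypothetical analogue of mode='num_random_ops': the i-th random op is
    seeded with i, so the same ops in the same order give the same values."""

    def __init__(self):
        self.num_random_ops = 0

    def normal(self, n):
        rng = random.Random(self.num_random_ops)  # seed = ops issued so far
        self.num_random_ops += 1
        return [rng.gauss(0.0, 1.0) for _ in range(n)]


def legacy_style_init(tool):
    return {"kernel": tool.normal(4), "bias": tool.normal(2)}


def native_style_init(tool):
    # Different code structure, same order of random ops.
    kernel = tool.normal(4)
    bias = tool.normal(2)
    return {"kernel": kernel, "bias": bias}


def reordered_init(tool):
    # Creating the bias first shifts every seed, so values no longer match.
    bias = tool.normal(2)
    kernel = tool.normal(4)
    return {"kernel": kernel, "bias": bias}


legacy = legacy_style_init(ToyNumRandomOpsTool())
native = native_style_init(ToyNumRandomOpsTool())
reordered = reordered_init(ToyNumRandomOpsTool())
print(legacy == native)     # True
print(legacy == reordered)  # False
```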

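Next, replace all of the remaining individual compat.v1.layers with native Keras layers. Note that this version also drops the dropout layer and the training argument; because the comparisons here call the layer without training=True, dropout was inactive in the original as well, so the outputs still match.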

class NearlyFullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = self.flatten_layer(out)
      out = self.dense_layer(out)
      return out

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = NearlyFullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

Finally, remove any remaining (no-longer-needed) variable_scope usage as well as the track_tf1_style_variables decorator itself.

You are now left with a version of the model that uses entirely native APIs.

class FullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  def call(self, inputs):
    out = self.conv_layer(inputs)
    out = self.flatten_layer(out)
    out = self.dense_layer(out)
    return out

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = FullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

Maintaining checkpoint compatibility during migration to Native TF2

The migration process above changed both the variable names (Keras APIs produce very different weight names) and the object-oriented paths that point to the different weights in the model. As a result, these changes break both existing TF1-style name-based checkpoints and TF2-style object-oriented checkpoints.

However, in some cases, you might be able to take your original name-based checkpoint and find a mapping of the variables to their new names with approaches like the one detailed in the Reusing TF1.x checkpoints guide.

Some tips for making this feasible are as follows:

  • Variables still all have a name argument you can set.
  • Keras models also take a name argument, which they use as the prefix for their variables.
  • The v1.name_scope function can be used to set variable name prefixes. This is very different from tf.compat.v1.variable_scope: it only affects names, and does not track variables or handle reuse.
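Once you have worked out a name mapping, the renaming step itself is simple dictionary work. The sketch below is a hypothetical helper (remap_checkpoint and the example variable names are illustrative, not a TensorFlow API) that renames {old_name: value} pairs and fails loudly on unmapped or colliding names. In real code the old values would come from a checkpoint reader such as tf.train.load_checkpoint(path).get_tensor(name).

```python
def remap_checkpoint(old_values, name_map):
    """Rename {old_name: value} entries using a hand-written name map.

    Raises instead of silently dropping or merging variables.
    """
    unmapped = sorted(set(old_values) - set(name_map))
    if unmapped:
        raise KeyError(f"no new name for: {unmapped}")
    new_values = {}
    for old_name, value in old_values.items():
        new_name = name_map[old_name]
        if new_name in new_values:
            raise ValueError(f"two variables map to {new_name!r}")
        new_values[new_name] = value
    return new_values


# Hypothetical names: left side as in a TF1 name-based checkpoint, right side
# as chosen via explicit `name` arguments in the migrated model.
name_map = {
    "model/conv2d/kernel": "model/conv_1/kernel",
    "model/conv2d/bias": "model/conv_1/bias",
}
old_values = {"model/conv2d/kernel": [0.1, 0.2], "model/conv2d/bias": [0.0]}
new_values = remap_checkpoint(old_values, name_map)
```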

With the above pointers in mind, the following code samples demonstrate a workflow you can adapt to your code to incrementally update part of a model while simultaneously updating checkpoints.

  1. Begin with a model that uses functional-style tf.compat.v1.layers. The steps that follow switch these calls to their object-oriented versions.

class FunctionalStyleCompatModel(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 4, 4,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = FunctionalStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
  2. Next, assign the compat.v1.layer objects and any variables created by compat.v1.get_variable as properties of the tf.keras.layers.Layer/tf.Module object whose method is decorated with track_tf1_style_variables. (Note that any object-oriented TF2-style checkpoints will now save out both a path by variable name and the new object-oriented path.)

class OOStyleCompatModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.compat.v1.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.compat.v1.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = OOStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
  3. At this point, resave a loaded checkpoint so that it contains paths both by variable name (for compat.v1.layers) and by the object-oriented object graph.

weights = {v.name: v for v in layer.weights}
assert weights['model/conv2d/kernel:0'] is layer.conv_1.kernel
assert weights['model/conv2d_1/bias:0'] is layer.conv_2.bias
  4. You can now swap out the object-oriented compat.v1.layers for native Keras layers while still being able to load the recently-saved checkpoint. Ensure that you preserve variable names for the remaining compat.v1.layers by still recording the auto-generated variable_scopes of the replaced layers. These switched layers/variables will now only use the object attribute path to the variables in the checkpoint instead of the variable name path.

In general, you can replace usage of compat.v1.get_variable in variables attached to properties by:

  • Switching them to using tf.Variable, OR
  • Updating them by using tf.keras.layers.Layer.add_weight. Note that if you are not switching all layers in one go, this may change auto-generated layer/variable naming for the remaining compat.v1.layers that are missing a name argument. If that is the case, you must keep the variable names for the remaining compat.v1.layers the same by manually opening and closing a variable_scope corresponding to each removed compat.v1.layer's generated scope name. Otherwise the paths from existing checkpoints may conflict and checkpoint loading will behave incorrectly.

def record_scope(scope_name):
  """Record a variable_scope to make sure future ones get incremented."""
  with tf.compat.v1.variable_scope(scope_name):
    pass

class PartiallyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      record_scope('conv2d') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = self.conv_2(out)
      record_scope('conv2d_1') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = PartiallyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]

Saving a checkpoint out at this step after constructing the variables will make it contain only the currently-available object paths.

Ensure you record the scopes of the removed compat.v1.layers to preserve the auto-generated weight names for the remaining compat.v1.layers.

weights = set(v.name for v in layer.weights)
assert 'model/conv2d_2/kernel:0' in weights
assert 'model/conv2d_2/bias:0' in weights
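The bookkeeping that record_scope relies on can be modeled in plain Python. The sketch assumes, as a simplification of TF1 behavior, that an auto-named layer takes the first unused name among conv2d, conv2d_1, conv2d_2, and so on, and that explicitly opening a scope marks its name as used. ToyScopeStore is hypothetical and not TensorFlow code; it reproduces the conv2d_2 name asserted above.

```python
class ToyScopeStore:
    """Hypothetical model of scope-name bookkeeping under one outer scope."""

    def __init__(self):
        self.opened = set()  # scope names already used under this scope

    def open_scope(self, name):
        # Explicitly opening a scope (what record_scope() does) marks it used.
        self.opened.add(name)

    def auto_scope(self, base):
        # Auto-naming picks the first unused name: base, base_1, base_2, ...
        name, idx = base, 0
        while name in self.opened:
            idx += 1
            name = f"{base}_{idx}"
        self.opened.add(name)
        return name


# Original model: three auto-named conv layers.
original = ToyScopeStore()
original_names = [original.auto_scope("conv2d") for _ in range(3)]
# -> ['conv2d', 'conv2d_1', 'conv2d_2']

# Two layers replaced by Keras objects that no longer open v1 scopes.
without_record = ToyScopeStore()
forgot = without_record.auto_scope("conv2d")  # 'conv2d': the name changed

with_record = ToyScopeStore()
with_record.open_scope("conv2d")    # like record_scope('conv2d')
with_record.open_scope("conv2d_1")  # like record_scope('conv2d_1')
kept = with_record.auto_scope("conv2d")  # 'conv2d_2': the name is preserved
```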
  5. Repeat the above steps until you have replaced all the compat.v1.layers and compat.v1.get_variable calls in your model with fully-native equivalents.

class FullyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")
    self.conv_3 = tf.keras.layers.Conv2D(
          5, 5,
          kernel_regularizer="l2")


  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = self.conv_3(out)
      return out

layer = FullyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]

Remember to test to make sure the newly updated checkpoint still behaves as you expect. Apply the techniques described in the validate numerical correctness guide at every incremental step of this process to ensure your migrated code runs correctly.

Handling TF1.x to TF2 behavior changes not covered by the modeling shims

The modeling shims described in this guide can make sure that variables, layers, and regularization losses created with get_variable, tf.compat.v1.layers, and variable_scope semantics continue to work as before when using eager execution and tf.function, without having to rely on collections.

This does not cover all TF1.x-specific semantics that your model forward passes may be relying on. In some cases, the shims might be insufficient to get your model forward pass running in TF2 on their own. Read the TF1.x vs TF2 behaviors guide to learn more about the behavioral differences between TF1.x and TF2.