TF2 ワークフローで TF1.x モデルを使用する

TensorFlow.org で表示 Colab で実行 GitHub で表示 ノートブックをダウンロード

このガイドでは、モデリングコードへの変更を最小限に抑えて、Eager execution、tf.function、分散ストラテジーなどの TF2 ワークフローで既存の TF1.x モデルを使用するために使用できるモデリングコード Shim の概要と例を示します。

利用範囲

このガイドで説明されている Shim は、以下に依存する TF1.x モデルのために設計されています。

  1. 変数の作成と再利用を制御する tf.compat.v1.get_variabletf.compat.v1.variable_scope、および
  2. 重みと正則化損失を追跡するための tf.compat.v1.global_variables()tf.compat.v1.trainable_variablestf.compat.v1.lossesget_regularization_losses()、および tf.compat.v1.get_collection() などのグラフコレクションベースの API。

tf.compat.v1.layertf.contrib.layers API、および TensorFlow-Slim の上に構築されたほとんどのモデルで使用できます。

Shim は、次の TF1.x モデルでは必要ありません

  1. model.trainable_weightsmodel.losses を介して、すべてのトレーニング可能な重みと正則化損失を既に追跡しているスタンドアロンの Keras モデル。
  2. module.trainable_variables を介してトレーニング可能なすべての重みを既に追跡し、まだ作成されていない場合にのみ重みを作成する tf.Module

これらのモデルは、Eager execution と tf.function を使用して TF2 ですぐに使える可能性があります。

セットアップ

TensorFlow と他の依存関係をインポートします。

pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8

pip install -q tf-nightly
import tensorflow as tf
import tensorflow.compat.v1 as v1
import sys
import numpy as np

from contextlib import contextmanager
2022-12-14 22:45:45.731902: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay

track_tf1_style_variables デコレータ

このガイドで説明されている重要な Shim は tf.keras.layers.Layer に属するメソッド内で使用できるデコレータ tf.compat.v1.keras.utils.track_tf1_style_variables とTF1.x スタイルの重みを追跡し、正則化の損失を捕捉する tf.Module です。

tf.keras.layers.Layer または tf.Module の呼び出しメソッドを tf.compat.v1.keras.utils.track_tf1_style_variables でデコレートすると、tf.compat.v1.get_variable(および拡張 tf.compat.v1.layers)を介して変数の作成と再利用がデコレートされたメソッド内で正しく機能します。呼び出しごとに常に新しい変数を作成する必要はありません。また、レイヤーまたはモジュールは、デコレートされたメソッド内で get_variable を介して作成またはアクセスされた重みを暗黙的に追跡します。

標準の layer.variable/module.variable などで重み自体を追跡することに加えて、 メソッドが tf.keras.layers.Layer に属する場合、get_variable または tf.compat.v1.layers 正則化引数は、標準の layer.losses プロパティの下でレイヤーによって追跡されます。

この追跡メカニズムにより、TF2 の動作が有効になっている場合でも、Keras レイヤーまたは TF2 の tf.Module 内で TF1.x スタイルのモデルフォワードパスコードの大規模なクラスを使用できます。

使用例

以下の使用例は、tf.keras.layers.Layer メソッドをデコレートするために使用されるモデリング Shim を示していますが、Keras 機能と特に相互作用する場合を除き tf.Module をデコレートするときも適用できます。

tf.compat.v1.get_variable で構築されたレイヤー

以下のように、tf.compat.v1.get_variable の上に直接実装されたレイヤーがあるとします。

def dense(self, inputs, units):
  out = inputs
  with tf.compat.v1.variable_scope("dense"):
    # The weights are created with a `regularizer`,
    kernel = tf.compat.v1.get_variable(
        shape=[out.shape[-1], units],
        regularizer=tf.keras.regularizers.L2(),
        initializer=tf.compat.v1.initializers.glorot_normal,
        name="kernel")
    bias = tf.compat.v1.get_variable(
        shape=[units,],
        initializer=tf.compat.v1.initializers.zeros,
        name="bias")
    out = tf.linalg.matmul(out, kernel)
    out = tf.compat.v1.nn.bias_add(out, bias)
  return out

Shim を使用してレイヤーに変換し、入力で呼び出します。

class DenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

layer = DenseLayer(10)
x = tf.random.normal(shape=(8, 20))
layer(x)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_203546/795621215.py:7: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[ 0.4686809 , -1.1716666 , -0.20547277, -0.6444645 ,  0.7130223 ,
        -0.00451446, -1.6933861 , -1.5964466 , -0.23282064, -0.08314091],
       [-1.1907233 , -3.2358186 ,  1.2351677 , -0.97008264, -1.2942133 ,
         1.7218621 , -2.1726809 , -0.2069498 ,  0.5149553 , -0.0072787 ],
       [ 0.9601712 ,  1.9170468 , -2.7247353 , -0.35632032, -0.5812271 ,
        -2.8715076 , -1.0919665 , -4.299157  , -1.5660018 , -1.6987164 ],
       [-0.16025108, -1.720531  , -0.41742757, -0.33518863,  1.0489087 ,
         0.6075638 , -0.01579201, -0.12971917,  1.7975792 ,  1.003011  ],
       [-0.20135996,  1.402952  , -0.3722505 ,  0.67277575, -0.01278794,
        -1.5798886 ,  1.7130005 , -2.730586  , -0.49501845, -1.2986678 ],
       [ 1.4841177 , -0.49532753, -0.01718462,  0.36023936, -0.08543369,
         0.7038621 ,  1.2424662 , -0.13748264, -0.84791577,  0.55169487],
       [-0.8036947 , -0.35385704, -0.70038223,  1.2686557 ,  0.37936342,
        -0.71443236, -0.95396805, -1.4449823 , -1.0922728 , -1.3701073 ],
       [ 0.18006533, -2.3372343 , -0.06073838,  1.3101625 ,  0.35873908,
         0.13722438, -1.675015  ,  1.286251  , -0.5555371 , -0.89756334]],
      dtype=float32)>

標準の Keras レイヤーのように、追跡された変数とキャプチャされた正則化損失にアクセスします。

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.13565482>]

レイヤーを呼び出すたびに重みが再利用されることを確認するには、すべての重みをゼロに設定して、レイヤーを再度呼び出します。

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

# Note: layer.losses is not a live view and
# will get reset only at each layer call
print("layer.losses:", layer.losses)
print("calling layer again.")
out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['dense/bias:0', 'dense/kernel:0']
layer.losses: [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
calling layer again.
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

変換されたレイヤーは、Keras 機能モデルの構築でも直接使用できます。

inputs = tf.keras.Input(shape=(20))
outputs = DenseLayer(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 20))
model(x)

# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.13877623>]

tf.compat.v1.layers で構築されたモデル

以下のように、レイヤーまたはモデルが tf.compat.v1.layers の上に直接実装されているとします。

def model(self, inputs, units):
  with tf.compat.v1.variable_scope('model'):
    out = tf.compat.v1.layers.conv2d(
        inputs, 3, 3,
        kernel_regularizer="l2")
    out = tf.compat.v1.layers.flatten(out)
    out = tf.compat.v1.layers.dense(
        out, units,
        kernel_regularizer="l2")
    return out

Shim を使用してレイヤーに変換し、入力で呼び出します。

class CompatV1LayerModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

layer = CompatV1LayerModel(10)
x = tf.random.normal(shape=(8, 5, 5, 5))
layer(x)
/tmpfs/tmp/ipykernel_203546/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_203546/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_203546/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[ 0.8995241 , -0.22181192,  0.16006124, -1.1080818 ,  0.6378982 ,
         1.850347  , -0.03218482, -0.32725865,  0.4217392 ,  1.7955269 ],
       [ 0.71358705,  0.7316396 ,  0.2763209 , -0.25087613, -0.66956085,
         0.03913212, -2.325179  ,  1.4672816 ,  0.697696  ,  0.03922887],
       [-0.21437314, -0.4441604 ,  1.287027  , -0.5215813 ,  0.96731186,
         2.1654997 , -1.888258  ,  3.0895581 , -1.636878  ,  0.6699464 ],
       [ 0.21570104,  0.9184797 , -1.8048897 ,  0.33267093,  0.02654031,
         2.764593  , -0.25900927, -0.7398262 , -1.1664623 , -0.16144866],
       [-1.8594477 , -0.01630855,  2.6787105 , -0.07748097,  1.5277824 ,
         0.654403  , -1.6919625 , -1.9184573 ,  4.0260487 ,  0.03399533],
       [-1.3357637 , -0.4211569 ,  0.87072176, -1.2892274 , -1.312388  ,
        -1.1570771 ,  0.96706843, -0.2821595 ,  1.9245236 ,  0.3725299 ],
       [-2.1367514 , -1.3135502 ,  1.3775344 , -0.5111486 , -0.4388187 ,
        -1.8627679 , -1.1438559 ,  0.4970094 ,  2.4051938 , -0.22061276],
       [ 0.16344474,  0.8936699 ,  0.54490006,  0.6468228 ,  1.8865356 ,
         1.0983261 , -0.7020545 ,  0.29437292,  0.19727364, -0.527163  ]],
      dtype=float32)>

警告: 安全上の理由から、空でない文字列 variable_scope 内にすべての tf.compat.v1.layers を配置してください。これは、自動生成された名前を持つ tf.compat.v1.layers が、変数スコープの外で常に名前を自動インクリメントするためです。これは、レイヤー/モジュールを呼び出すたびに、要求された変数名が一致しないことを意味します。したがって、既に作成された重みを再利用するのではなく、呼び出しごとに新しい変数のセットを作成します。

標準の Keras レイヤーのように、追跡された変数とキャプチャされた正則化損失にアクセスします。

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.03467329>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.14976557>]

レイヤーを呼び出すたびに重みが再利用されることを確認するには、すべての重みをゼロに設定して、レイヤーを再度呼び出します。

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['model/conv2d/bias:0', 'model/conv2d/kernel:0', 'model/dense/bias:0', 'model/dense/kernel:0']
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
/tmpfs/tmp/ipykernel_203546/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_203546/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_203546/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

変換されたレイヤーは、Keras 機能モデルの構築でも直接使用できます。

inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1LayerModel(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 5, 5, 5))
model(x)
/tmpfs/tmp/ipykernel_203546/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
/tmpfs/tmp/ipykernel_203546/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_203546/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[ 0.6098672 ,  0.97505224,  0.65215063, -3.064506  , -1.5792439 ,
         0.15776676,  1.6259285 , -3.2481084 , -1.2247274 ,  1.8893697 ],
       [-0.44124675,  1.1242774 , -1.0716134 , -0.75544125, -0.97522974,
         3.3298619 ,  0.6500565 ,  0.23895854, -1.3651588 , -0.37240303],
       [-0.47285593, -0.3575668 ,  1.9535047 , -0.98463476,  0.62686014,
        -2.615439  , -1.5498544 ,  0.3025058 , -1.4009817 ,  1.7124442 ],
       [-0.9377699 , -0.99163735,  2.5051246 ,  1.6605344 ,  2.167434  ,
        -1.3772843 , -0.9390738 , -1.8639736 ,  4.0830255 ,  1.8045989 ],
       [-1.5968541 , -1.0940666 ,  2.2500196 , -0.5151663 ,  1.1103777 ,
        -1.5564997 , -2.5432277 ,  0.25040847,  1.2886575 ,  0.84138167],
       [-0.8662586 , -0.8557035 , -1.3427409 ,  0.6566953 ,  1.4976443 ,
        -1.6714022 , -3.023892  ,  0.41466054,  0.11158413,  0.8047291 ],
       [ 0.17170793, -0.79710865,  0.20337993,  1.6620314 ,  0.3963801 ,
        -1.1913121 ,  1.7417206 ,  2.588733  , -0.05603534, -1.3200552 ],
       [-0.13649988,  1.162786  ,  0.27264208, -0.9247199 , -0.61450845,
         1.17097   ,  1.3044865 , -0.92457485, -1.1188524 , -0.1664241 ]],
      dtype=float32)>
# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.039181717>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.1415849>]

バッチ正則化の更新とモデルの training 引数をキャプチャする

TF1.x では、次のようにバッチ正則化を実行します。

  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)
  train_op = optimizer.minimize(loss)
  train_op = tf.group([train_op, update_ops])

注意点:

  1. バッチ正則化移動平均の更新は、レイヤーとは別に呼び出された get_collection によって追跡されます
  2. tf.compat.v1.layers.batch_normalization には training 引数が必要です(通常、TF-Slim バッチ正則化レイヤーを使用する場合は is_training と呼ばれます)

TF2 では、Eager execution と自動制御の依存関係により、バッチ正則化移動平均更新がすぐに実行されます。これらを更新コレクションから個別に収集し、明示的な制御の依存関係として追加する必要はありません。

さらに、tf.keras.layers.Layer のフォワードパスメソッドに training 引数を与えると、Keras はその時点のトレーニングフェーズとネストされたレイヤーを他のレイヤーと同じように渡すことができます。Keras が training 引数を処理する方法の詳細については、tf.keras.Model の API ドキュメントを参照してください。

tf.Module メソッドをデコレートする場合は、必要に応じてすべての training 引数を手動で渡す必要があります。ただし、バッチ正則化移動平均更新は、明示的な制御依存関係を必要とせずに自動的に適用されます。

次のコードスニペットは、Shim にバッチ正則化レイヤーを埋め込む方法と、それを Keras モデルで使用する方法を示しています (tf.keras.layers.Layer に適用可能)。

class CompatV1BatchNorm(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    print("Forward pass called with `training` =", training)
    with v1.variable_scope('batch_norm_layer'):
      return v1.layers.batch_normalization(x, training=training)
print("Constructing model")
inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1BatchNorm()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

print("Calling model in inference mode")
x = tf.random.normal(shape=(8, 5, 5, 5))
model(x, training=False)

print("Moving average variables before training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})

# Notice that when running TF2 and eager execution, the batchnorm layer directly
# updates the moving averages while training without needing any extra control
# dependencies
print("calling model in training mode")
model(x, training=True)

print("Moving average variables after training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})
Constructing model
Forward pass called with `training` = None
/tmpfs/tmp/ipykernel_203546/3053504896.py:7: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  return v1.layers.batch_normalization(x, training=training)
Calling model in inference mode
Forward pass called with `training` = False
Moving average variables before training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 0., 0.], dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([1., 1., 1., 1., 1.], dtype=float32)>}
calling model in training mode
Forward pass called with `training` = True
Moving average variables after training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([-0.0001415 , -0.00039963, -0.00023034,  0.00011454,  0.000515  ],
      dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([1.0010952 , 0.99965096, 1.000891  , 0.99932873, 1.000744  ],
      dtype=float32)>}

変数スコープに基づく変数の再利用

get_variable に基づくフォワードパスでの変数の作成は、TF1.x の変数スコープと同じ変数の命名と再利用のセマンティクスを維持します。上記のように、自動生成された名前を持つ任意の tf.compat.v1.layers に対して少なくとも 1 つの空でない外部スコープがある必要があります。

注意: 命名と再利用は、単一のレイヤー/モジュールインスタンス内に限定されます。1 つの Shim でデコレートされたレイヤーまたはモジュール内の get_variable への呼び出しは、レイヤーまたはモジュール内で作成された変数を参照できません。get_variable を介して変数にアクセスするのではなく、必要に応じて他の変数への Python 参照を直接使用することで、これを回避できます。

Eager execution と tf.function

上記のように、tf.keras.layers.Layertf.Module のデコレートされたメソッドは、Eager execution の内部で実行され、tf.function とも互換性があります。これは、pdb やその他の対話型ツールを使用して、実行中のフォワードパスをステップ実行できることを意味します

警告: tf.functionから Shim でデコレートされたレイヤー/モジュール メソッドを呼び出すことは完全に安全ですが、これらの tf.functionget_variable 呼び出しが含まれている場合、Shim でデコレートされたメソッド内に tf.function を配置することは安全ではありません。tf.function を入力すると、variable_scope がリセットされ、Shim が模倣する TF1.x スタイルの変数スコープに基づく変数の再利用が、この設定で失敗します。

分散ストラテジー

@track_tf1_style_variables でデコレートされたレイヤーまたはモジュールメソッド内の get_variable への呼び出しは、内部で標準の tf.Variable 変数の作成を使用します。これは、MirroredStrategyTPUStrategy など、tf.distribute で利用可能なさまざまな分散ストラテジーでそれらを使用できることを意味します。

デコレートされた呼び出しで tf.Variabletf.Moduletf.keras.layers および tf.keras.models をネストする

tf.compat.v1.keras.utils.track_tf1_style_variables でレイヤー呼び出しをデコレートすると、tf.compat.v1.get_variable を介して作成された(および再利用された)変数の自動暗黙的追跡のみが追加されます。典型的な Keras レイヤーやほとんどの tf.Module で使用される tf.Variable 呼び出しによって直接作成された重みはキャプチャしません。このセクションでは、これらのネストされたケースを処理する方法について説明します。

(既存の使用法)tf.keras.layers および tf.keras.models

ネストされた Keras レイヤーとモデルの既存の使用法については、tf.compat.v1.keras.utils.get_or_create_layer を使用します。これは、既存の TF1.x にネストされた Keras の使用法の移行を容易にするためにのみ推奨されます。tf.Variable および tf.Module の新しいコードは、以下で説明するように明示的な属性設定を使用する必要があります。

tf.compat.v1.keras.utils.get_or_create_layer を使用するには、ネストされたモデルを構築するコードをメソッドにラップし、メソッドに渡します。以下に例を示します。

class NestedModel(tf.keras.Model):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Get or create a nested model without assigning it as an explicit property
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

layer = NestedModel(10)
layer(tf.ones(shape=(5,5)))
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]], dtype=float32)>

このメソッドにより、これらのネストされたレイヤーが正しく再利用され、TensorFlow によって追跡されることが保証されます。適切なメソッドでは @track_tf1_style_variables デコレータが引き続き必要であることに注意してください。get_or_create_layer に渡されるモデルビルダーメソッド(この場合は self.build_model)は、引数を取りません。

重みが追跡されます。

assert len(layer.weights) == 2
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"dense/bias:0", "dense/kernel:0"}

layer.weights
[<tf.Variable 'dense/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]

正則化損失も同様に追跡されます。

tf.add_n(layer.losses)
<tf.Tensor: shape=(), dtype=float32, numpy=0.5>

増分移行: tf.Variables および tf.Modules

デコレートされたメソッドに tf.Variable 呼び出し、または tf.Module を埋め込む必要がある場合(たとえば、このガイドの後半で説明されている非レガシー TF2 API への段階的な移行に従っている場合など)、次の要件に従って、これらを明示的に追跡する必要があります。

  • 変数/モジュール/レイヤーが一度だけ作成されることを明示的に確認する
  • 典型的なモジュールまたはレイヤーを定義するときと同じように、それらをインスタンス属性として明示的に添付する
  • 後続の呼び出しで、作成済みのオブジェクトを明示的に再利用する

これにより、重みが呼び出しごとに新しく作成されず、正しく再利用されることが保証されます。さらに、既存の重みと正則化の損失が追跡されることも保証されます。

以下に例を示します。

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("inner_dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

class WrappedDenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    # Only create the nested tf.variable/module/layer/model
    # once, and then reuse it each time!
    self._dense_layer = NestedLayer(self.units)

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('outer'):
      outputs = tf.compat.v1.layers.dense(inputs, 3)
      outputs = tf.compat.v1.layers.dense(inputs, 4)
      return self._dense_layer(outputs)

layer = WrappedDenseLayer(10)

layer(tf.ones(shape=(5, 5)))
/tmpfs/tmp/ipykernel_203546/2765428776.py:38: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 3)
/tmpfs/tmp/ipykernel_203546/2765428776.py:39: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 4)
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[-0.11340484,  0.5805201 ,  0.86871445,  1.8958894 ,  0.2752131 ,
        -0.01736549,  1.656469  ,  1.6632295 , -0.00841344, -1.8026792 ],
       [-0.11340484,  0.5805201 ,  0.86871445,  1.8958894 ,  0.2752131 ,
        -0.01736549,  1.656469  ,  1.6632295 , -0.00841344, -1.8026792 ],
       [-0.11340484,  0.5805201 ,  0.86871445,  1.8958894 ,  0.2752131 ,
        -0.01736549,  1.656469  ,  1.6632295 , -0.00841344, -1.8026792 ],
       [-0.11340484,  0.5805201 ,  0.86871445,  1.8958894 ,  0.2752131 ,
        -0.01736549,  1.656469  ,  1.6632295 , -0.00841344, -1.8026792 ],
       [-0.11340484,  0.5805201 ,  0.86871445,  1.8958894 ,  0.2752131 ,
        -0.01736549,  1.656469  ,  1.6632295 , -0.00841344, -1.8026792 ]],
      dtype=float32)>

track_tf1_style_variables デコレータで装飾されている場合でも、ネストされたモジュールを明示的に追跡する必要があることに注意してください。デコレートされたメソッドを持つ各モジュール/レイヤーには、それに関連付けられた独自の変数ストアがあるためです。

重みは正しく追跡されます。

assert len(layer.weights) == 6
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"outer/inner_dense/bias:0",
                               "outer/inner_dense/kernel:0",
                               "outer/dense/bias:0",
                               "outer/dense/kernel:0",
                               "outer/dense_1/bias:0",
                               "outer/dense_1/kernel:0"}

layer.trainable_weights
[<tf.Variable 'outer/inner_dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/inner_dense/kernel:0' shape=(4, 10) dtype=float32, numpy=
 array([[ 0.3075705 , -0.40741736, -0.39825478,  0.6889204 ,  0.6287682 ,
          0.2633951 ,  0.11980183, -0.18669793, -0.27482644,  0.22058026],
        [-0.4075396 ,  0.03225906, -0.3195349 ,  0.02189444, -0.00137738,
         -0.26833618,  0.16752942,  0.6268268 , -0.11936384, -0.551885  ],
        [-0.61430264, -0.64841527, -0.67951083, -0.43272364, -0.19525278,
          0.1670384 , -0.49141192, -0.32757416, -0.6216911 ,  0.43401024],
        [-0.2758267 , -0.124303  ,  0.4036006 ,  0.8149748 , -0.09092603,
          0.35782033,  0.54953164,  0.29575136, -0.39113423, -0.38115194]],
       dtype=float32)>,
 <tf.Variable 'outer/dense/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense/kernel:0' shape=(5, 3) dtype=float32, numpy=
 array([[ 0.06724197,  0.12351465, -0.6902447 ],
        [ 0.623815  , -0.11178988, -0.07436419],
        [ 0.14670032,  0.54447657,  0.7763539 ],
        [ 0.41450244, -0.5223446 , -0.16240823],
        [ 0.19045538,  0.8633928 , -0.46085978]], dtype=float32)>,
 <tf.Variable 'outer/dense_1/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense_1/kernel:0' shape=(5, 4) dtype=float32, numpy=
 array([[-0.54109144,  0.01866847, -0.15464032, -0.0679152 ],
        [ 0.80793977,  0.33000815, -0.15994817,  0.5561161 ],
        [-0.20882791, -0.05373597, -0.38073334,  0.22148895],
        [-0.19787806,  0.66540897, -0.44353148,  0.03214008],
        [ 0.39443994,  0.44928133, -0.11598724,  0.6651341 ]],
       dtype=float32)>]

正則化損失も同様に追跡されます。

layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.067925505>]

NestedLayer が Keras 以外の tf.Module である場合、変数は引き続き追跡されますが、正則化の損失は自動的に追跡されないため、別に明示的に追跡する必要があります。

変数名に関するガイダンス

明示的な tf.Variable 呼び出しと Keras レイヤーは、馴染まれている get_variablevariable_scopes の組み合わせとは異なるレイヤー名/変数名自動生成メカニズムを使用します。TF1.x グラフから TF2 の Eager execution と tf.function に移行する場合でも、Shim は、変数名を get_variable により作成された変数と一致させますが、tf.Variable 呼び出しのために生成された変数名と、メソッドデコレータ内に埋め込む Keras レイヤーでは同じことを保証することはできません。TF2 の Eager execution と tf.function では、複数の変数が同じ名前を共有することも可能です。

このガイドで後述する、正確性の検証と TF1.x チェックポイントのマッピングに関するセクションに従うときは、このことに特に注意する必要があります。

デコレートされたメソッドで tf.compat.v1.make_template を使用する

tf.compat.v1.make_template を使用する代わりに、TF2 の上の薄いレイヤーであるtf.compat.v1.keras.utils.track_tf1_style_variables を直接使用することを強くお勧めします

すでに tf.compat.v1.make_template に依存する以前の TF1.x コードについては、このセクションのガイダンスに従ってください。

tf.compat.v1.make_templateget_variable を使用するコードをラップするため、track_tf1_style_variables デコレータを使用すると、レイヤー呼び出しでこれらのテンプレートを使用して、重みと正則化損失を正常に追跡できます。

ただし、必ず make_template を 1 回だけ呼び出してから、各レイヤー呼び出しで同じテンプレートを再利用してください。そうしないと、新しい変数セットとともにレイヤーを呼び出すたびに、新しいテンプレートが作成されます。

以下に例を示します。

class CompatV1TemplateScaleByY(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    def my_op(x, scalar_name):
      var1 = tf.compat.v1.get_variable(scalar_name,
                            shape=[],
                            regularizer=tf.compat.v1.keras.regularizers.L2(),
                            initializer=tf.compat.v1.constant_initializer(1.5))
      return x * var1
    self.scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('layer'):
      # Using a scope ensures the `scale_by_y` name will not be incremented
      # for each instantiation of the layer.
      return self.scale_by_y(inputs)

layer = CompatV1TemplateScaleByY()

out = layer(tf.ones(shape=(2, 3)))
print("weights:", layer.weights)
print("regularization loss:", layer.losses)
print("output:", out)
weights: [<tf.Variable 'layer/scale_by_y/y:0' shape=() dtype=float32, numpy=1.5>]
regularization loss: [<tf.Tensor: shape=(), dtype=float32, numpy=0.022499999>]
output: tf.Tensor(
[[1.5 1.5 1.5]
 [1.5 1.5 1.5]], shape=(2, 3), dtype=float32)

警告: make_template で作成された同じテンプレートを複数のレイヤーインスタンスで共有しないでください。Shim デコレータの変数および正則化損失追跡メカニズムが壊れる可能性があるためです。さらに、複数のレイヤーインスタンス内で同じ make_template 名を使用する場合は、作成したテンプレートの使用法を variable_scope 内にネストする必要があります。そうでない場合、テンプレートの variable_scope の生成された名前は、レイヤーの新しいインスタンスごとに増加します。これにより、重みの名前が予期しない方法で変更される可能性があります。

ネイティブ TF2 への段階的な移行

前述のように、track_tf1_style_variables を使用すると、TF2 スタイルのオブジェクト指向の tf.Variable/tf.keras.layers.Layer/tf.Module の使用法と同じデコレートされたモジュール/レイヤー内の従来の tf.compat.v1.get_variable/tf.compat.v1.layers スタイルの使用法を併用できます。

これは、TF1.x モデルを TF2 と完全に互換性のあるものにした後、すべての新しいモデルコンポーネントをネイティブ(非 tf.compat.v1)TF2 API で記述し、以前のコードと相互運用できることを意味します。

ただし、以前モデルコンポーネントを変更し続ける場合は、従来のスタイルの tf.compat.v1 の使用法を新しく記述された TF2 コードに推奨される純粋にネイティブなオブジェクト指向 API に段階的に切り替えることもできます。

tf.compat.v1.get_variable の使用法は、Keras レイヤー/モデルをデコレートしている場合は self.add_weight 呼び出し、Keras オブジェクトまたは tf.Module をデコレートしている場合は tf.Variable 呼び出しに置き換えることがます。

関数型とオブジェクト指向の両方の tf.compat.v1.layers は、通常、引数を変更することなく、同等の tf.keras.layers レイヤーに置き換えることができます。

track_tf1_style_variables を使用する純粋なネイティブ API への段階的な移行中に、モデルの一部または共通パターンを個々のレイヤー/モジュールに含めることもできます。

Slim と contrib.layers に関する注意

以前の TF 1.x コードの多くは、TF 1.x に tf.contrib.layers としてパッケージ化された Slim ライブラリを使用しています。Slim を使用したコードをネイティブ TF 2 に変換することは、v1.layers を変換するよりも複雑です。そのため、まず Slim コードを v1.layers に変換してから、Keras に変換する方が理にかなっています。以下は、Slim コードを変換するための一般的なガイダンスです。

  • すべての引数が明示的であることを確認します。可能であれば arg_scopes を削除します。使用する必要がある場合は、normalizer_fnactivation_fn をそれぞれのレイヤーに分割します。
  • 分離可能な畳み込みレイヤーは 1 つまたはそれ以上の異なる Keras レイヤー(深さ、点、分離可能な Keras レイヤー)にマップします。
  • Slim と v1.layers には異なる引数名とデフォルト値があります。
  • 一部の引数には異なるスケールがあります。

チェックポイントの互換性を無視したネイティブ TF2 への移行

次のコードサンプルは、チェックポイントの互換性を考慮せずにモデルを純粋なネイティブ API に段階的に移行する方法を示しています。

class CompatModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

次に、compat.v1 API をネイティブオブジェクト指向の同等のものに部分的に置き換えます。レイヤーコンストラクタで作成された Keras オブジェクトに畳み込みレイヤーを切り替えることから始めます。

class PartiallyMigratedModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

v1.keras.utils.DeterministicRandomTestTool クラスを使用して、この段階的な変更によってモデルが以前と同じ動作をすることを確認します。

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = CompatModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  original_output = layer(inputs)

  # Grab the regularization loss as well
  original_regularization_loss = tf.math.add_n(layer.losses)

print(original_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_203546/355611412.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_203546/355611412.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_203546/355611412.py:14: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_203546/355611412.py:15: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = PartiallyMigratedModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_203546/3237389364.py:14: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_203546/3237389364.py:15: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_203546/3237389364.py:16: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

個々の compat.v1.layers をすべてネイティブ Keras レイヤーに置き換えました。

class NearlyFullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = self.flatten_layer(out)
      out = self.dense_layer(out)
      return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = NearlyFullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

最後に、残りの(不要になった)variable_scope の使用と track_tf1_style_variables デコレータ自体を削除します。

完全にネイティブ API を使用するモデルのバージョンになりました。

class FullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  def call(self, inputs):
    out = self.conv_layer(inputs)
    out = self.flatten_layer(out)
    out = self.dense_layer(out)
    return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = FullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

ネイティブ TF2 への移行中にチェックポイントの互換性を維持する

上記のネイティブ TF2 API への移行プロセスでは、変数名(Keras API は非常に異なる重み名を生成するため)と、モデル内の異なる重みを指すオブジェクト指向パスの両方が変更されました。これらの変更により、既存の TF1 スタイルの名前ベースのチェックポイントと TF2 スタイルのオブジェクト指向チェックポイントの両方が機能しなくなります。

ただし、場合によっては、元の名前ベースのチェックポイントを使用して、TF1.x チェックポイントの再利用ガイドで詳しく説明されているようなアプローチを使用して、新しい名前への変数のマッピングを見つけられる場合があります。

以下にヒントを示します。

  • 変数はすべて設定が可能な name 引数を持ちます。
  • また、Keras モデルは name 引数を取り、それらの変数のためのプレフィックスとして設定されます。
  • v1.name_scope 関数は、変数名のプレフィックスの設定に使用できます。これは tf.variable_scope とは大きく異なります。これは名前だけに影響するもので、変数と再利用の追跡はしません。

上記を念頭に置いて、次のコードサンプルは、チェックポイントを同時に更新しながら、モデルの一部を段階的に更新するためにコードに適応できるワークフローを示しています。

注意: Keras レイヤーでの変数の命名は複雑であるため、これがすべてのユースケースで機能するとは限りません。

  1. まず、関数型の tf.compat.v1.layers をオブジェクト指向のバージョンに切り替えます。
class FunctionalStyleCompatModel(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 4, 4,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = FunctionalStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_203546/1716504801.py:6: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_203546/1716504801.py:9: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_203546/1716504801.py:12: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/bias:0',
 'model/conv2d/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. 次に、compat.v1.layer オブジェクトと compat.v1.get_variable によって作成された変数をメソッドが track_tf1_style_variables でデコレートされた tf.keras.layers.Layer/tf.Module オブジェクトのプロパティとして割り当てます。(オブジェクト指向の TF2 スタイルのチェックポイントは、変数名によるパスと新しいオブジェクト指向のパスの両方を保存することに注意してください)。
class OOStyleCompatModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.compat.v1.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.compat.v1.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = OOStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_203546/1693875107.py:17: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/kernel:0',
 'model/conv2d/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. この時点で読み込まれたチェックポイントを再保存して、変数名(compat.v1.layers の場合)またはオブジェクト指向オブジェクトグラフの両方でパスを保存します。
weights = {v.name: v for v in layer.weights}
assert weights['model/conv2d/kernel:0'] is layer.conv_1.kernel
assert weights['model/conv2d_1/bias:0'] is layer.conv_2.bias
  1. オブジェクト指向の compat.v1.layers をネイティブ Keras レイヤーに置き換えながら、最近保存したチェックポイントを読み込めるようになりました。置き換えられたレイヤーの自動生成された variable_scopes を引き続き記録することにより、残りの compat.v1.layers の変数名を確実に保持します。これらの置き換えられたレイヤー/変数は、変数名パスの代わりに、チェックポイント内の変数へのオブジェクト属性パスのみを使用するようになりました。

一般に、プロパティにアタッチされた変数での compat.v1.get_variable の使用を次のように置き換えることができます。

  • tf.Variable を使用するように置き換えます。または
  • tf.keras.layers.Layer.add_weight を使用してそれらを更新します。一度にすべてのレイヤーを切り替えない場合、name 引数がない残りの compat.v1.layers の自動生成されたレイヤー/変数の命名が変更される可能性があることに注意してください。その場合、削除された compat.v1.layer の生成されたスコープ名に対応する variable_scope を手動で開閉し、残りの compat.v1.layers の変数名を同じにしておく必要があります。そうしないと、既存のチェックポイントからのパスが競合する可能性があり、チェックポイントの読み込みが正しく動作しなくなります。
def record_scope(scope_name):
  """Record a variable_scope to make sure future ones get incremented."""
  with tf.compat.v1.variable_scope(scope_name):
    pass

class PartiallyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      record_scope('conv2d') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = self.conv_2(out)
      record_scope('conv2d_1') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = PartiallyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_203546/3143218429.py:24: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['partially_native_keras_layers_model/model/conv2d_13/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_13/bias:0',
 'partially_native_keras_layers_model/model/conv2d_14/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_14/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']

変数を構築した後、このステップでチェックポイントを保存すると、最新の利用可能なオブジェクトパスのみが含まれるようになります。

削除された compat.v1.layers のスコープを記録して、残りの compat.v1.layers の自動生成された重み名を保持します。

weights = set(v.name for v in layer.weights)
assert 'model/conv2d_2/kernel:0' in weights
assert 'model/conv2d_2/bias:0' in weights
  1. モデル内のすべての compat.v1.layerscompat.v1.get_variable を完全にネイティブな同等のものに置き換えるまで、上記の手順を繰り返します。
class FullyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")
    self.conv_3 = tf.keras.layers.Conv2D(
          5, 5,
          kernel_regularizer="l2")


  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = self.conv_3(out)
      return out

layer = FullyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
['fully_native_keras_layers_model/model/conv2d_16/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_16/bias:0',
 'fully_native_keras_layers_model/model/conv2d_17/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_17/bias:0',
 'fully_native_keras_layers_model/model/conv2d_18/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_18/bias:0']

新しく更新されたチェックポイントが期待どおりに動作することを確認するためにテストすることを忘れないでください。このプロセスの段階的なステップごとに、数値の正確性を検証するガイドに記載されている手法を適用して、移行したコードが正しく実行されるようにします。

モデリング Shim で処理されない TF1.x から TF2 への動作変更

このガイドで説明されているモデリング Shim を使用することにより、Eager execution と tf.function を使用する際にコレクションに依存することなく、get_variabletf.compat.v1.layers、および variable_scope で作成された変数、レイヤー、正則化の損失のセマンティクスが以前と同様に機能できるようになります。

これは、モデルのフォワードパスが依存している可能性があるすべての TF1.x 固有のセマンティクスを網羅しているわけではありません。場合によっては、モデルのフォワードパスを TF2 で単独で実行するには、Shim が不十分な場合があります。TF1.x と TF2 の動作の違いについて詳しくは、TF1.x と TF2 の動作ガイドを参照してください。