モデルチェックポイントの移行

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

注意: tf.compat.v1.Saver で保存されたチェックポイントは、多くの場合、TF1 または名前ベースのチェックポイントと呼ばれます。tf.train.Checkpoint で保存されたチェックポイントは、TF2 またはオブジェクトベースのチェックポイントと呼ばれます。

概要

このガイドでは、tf.compat.v1.Saver を使用してチェックポイントを保存および読み込むモデルがあり、TF2 tf.train.Checkpoint API を使用してコードを移行するか、TF2 モデルで既存のチェックポイントを使用する方法を実演します。

以下に、一般的なシナリオをいくつか示します。

シナリオ 1

以前に実行したトレーニングからの既存の TF1 チェックポイントを TF2 に読み込むまたは変換する必要があります。

シナリオ 2

モデルを調整する際に変数名とパスを変更するリスクがある場合(get_variable から明示的な tf.Variable の作成に段階的に移行する場合など)、途中で既存のチェックポイントの保存/読み込みを維持したいと考えています。

モデルの移行中にチェックポイントの互換性を維持する方法のセクションを参照してください。

シナリオ 3

トレーニングコードとチェックポイントを TF2 に移行していますが、推論パイプラインには引き続き TF1 チェックポイントが必要です(本番環境の安定性のため)。

オプション 1

トレーニング時に TF1 と TF2 の両方のチェックポイントを保存します。

オプション 2

TF2 チェックポイントを TF1 に変換します。


以下の例は、モデルの移行方法を柔軟に決定できるように TF1/TF2 でのチェックポイントの保存と読み込みのすべての組み合わせを示しています。

セットアップ

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")
2022-12-14 22:28:44.682638: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:28:44.682724: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:28:44.682733: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

TF1 から TF2 への変更

このセクションは、TF1 と TF2 の間で何が変更されたか、および「名前ベース」(TF1)と「オブジェクトベース」(TF2)のチェックポイントの意味について説明します。

2 種類のチェックポイントは、実際には同じ形式(基本的にはキーと値の表)で保存されます。違いは、キーの生成方法にあります。

名前ベースのチェックポイントのキーは、変数の名前です。オブジェクトベースのチェックポイントのキーは、ルートオブジェクトから変数へのパスを参照します(以下の例は、これが何を意味するかをよりよく理解するのに役立ちます)。

まず、いくつかのチェックポイントを保存します。

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)
Checkpoint at 'tf2-ckpt-1':
  (key='variables/2/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=7.0)
  (key='variables/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=6.0)
  (key='variables/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=5.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n%\n\r\x08\x01\x12\tvariables\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n\x19\n\x05\x08\x03\x12\x010\n\x05\x08\x04\x12\x011\n\x05\x08\x05\x12\x012*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01a\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01b\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nH\x12B\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

tf2-ckpt のキーを見ると、それらはすべて各変数のオブジェクトパスを参照しています。たとえば、変数 avariables リストの最初の要素であるため、そのキーは variables/0/... になります (.ATTRIBUTES/VARIABLE_VALUE 定数は無視できます)。

以下では Checkpoint オブジェクトを詳しく見てみます。

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
root type = Checkpoint
root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>])
root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>

以下のスニペットを試してみて、オブジェクト構造によってチェックポイントキーがどのように変化するかを確認してください。

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)
Checkpoint at 'root-tf2-ckpt-1':
  (key='v/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='v/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='module/d/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='c/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\x05\x08\x01\x12\x01c\n\n\x08\x02\x12\x06module\n\x05\x08\x03\x12\x01v\n\x10\x08\x04\x12\x0csave_counter*\x02\x08\x01\n>\x128\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\x0b\n\x05\x08\x05\x12\x01d*\x02\x08\x01\n\x12\n\x05\x08\x06\x12\x01a\n\x05\x08\x07\x12\x01b*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

なぜ TF2 はこのメカニズムを使用するのでしょうか。

TF2 にはグローバルグラフがないため、変数名は信頼できず、プログラム間で矛盾する可能性があります。TF2 は、変数がレイヤーによって所有され、レイヤーがモデルによって所有されるオブジェクト指向モデリングアプローチを推奨します。

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer

モデルの移行中にチェックポイントの互換性を維持する方法

移行プロセスの重要なステップの 1 つは、すべての変数が正しい値に初期化されていることを確認することです。これにより、演算や関数が正しい計算を行っていることを検証できます。そのためには、移行のさまざまな段階でモデル間のチェックポイントの互換性を考慮する必要があります。基本的に、このセクションでは、モデルを変更しながら同じチェックポイントを使い続けるにはどうすればよいかという質問に答えます。

以下に、柔軟性を高めるために、チェックポイントの互換性を維持する 3 つの方法を示します。

  1. モデルには以前と同じ変数名があります。
  2. モデルにはさまざまな変数名があり、チェックポイント内の変数名を新しい名前にマッピングする割り当てマップを維持します。
  3. モデルにはさまざまな変数名があり、すべての変数を格納する TF2 チェックポイントオブジェクトを維持しています。

変数名が一致する場合

長いタイトル: 変数名が一致する場合にチェックポイントを再利用する方法。

簡単な答え: tf1.train.Saver または tf.train.Checkpoint のいずれかを使用して、既存のチェックポイントを直接読み込むことができます。


tf.compat.v1.keras.utils.track_tf1_style_variables を使用するとモデル変数名が以前と同じであることを保証できます。また、変数名が一致することを手動で確認することもできます。

移行されたモデルで変数名が一致する場合、tf.train.Checkpoint または tf.compat.v1.train.Saver のいずれかを直接使用してチェックポイントを読み込めます。どちらの API も Eager モードと Graph モードと互換性があるため、移行のどの段階でも使用できます。

注意: tf.train.Checkpoint を使用して TF1 チェックポイントを読み込むことはできますが、tf.compat.v1.Saver を使用して TF2 チェックポイントを読み込むには複雑な名前の照合が必要です。

以下は、異なるモデルで同じチェックポイントを使用する例です。 まず、TF1 チェックポイントを tf1.train.Saver で保存します。

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

以下の例では、tf.compat.v1.Saver を使用して、Eager モードでチェックポイントを読み込みます。

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
loaded values of [a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

次のスニペットは、TF2 API tf.train.Checkpoint を使用してチェックポイントを読み込みます。

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
Variable names: 
  a.name = a:0
  b.name = b:0
  c.name = scoped/c:0
  c_2.name = scoped/c:0
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/checkpoint/checkpoint.py:1473: NameBasedSaverStatus.__init__ (from tensorflow.python.checkpoint.checkpoint) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
loaded values of [a, b, c, c_2]:  [1.0, 2.0, 3.0, 3.0]

TF2 の変数名

  • 変数はすべて設定が可能な name 引数を持ちます。
  • また、Keras モデルは name 引数を取り、それらの変数のためのプレフィックスとして設定されます。
  • v1.name_scope 関数は、変数名のプレフィックスの設定に使用できます。これは tf.variable_scope とは大きく異なります。これは名前だけに影響するもので、変数と再利用の追跡はしません。

tf.compat.v1.keras.utils.track_tf1_style_variables デコレータは、tf.variable_scopetf.compat.v1.get_variable の命名と再利用のセマンティクスを変更せずに維持し、変数名と TF1 チェックポイントの互換性を維持するのに役立つ shim です。詳細については、モデルマッピングガイドを参照してください。

注意 1: shim を使用している場合は、TF2 API を使用してチェックポイントを読み込みます(事前トレーニング済みの TF1 チェックポイントを使用する場合でも)。

Keras のチェックポイントのセクションを参照してください。

注意 2: get_variable から tf.Variable に移行する場合:

shim でデコレートされたレイヤーまたはモジュールが、tf.compat.v1.get_variable の代わりに tf.Variable を使用するいくつかの変数(または Keras レイヤー/モデル)で構成されていて、プロパティとしてアタッチされる場合やオブジェクト指向の方法で追跡される場合、TF1.x グラフ/セッションと Eager execution 実行時では、変数の命名セマンティクスが異なる場合があります。

つまり、TF2 で実行すると、名前が期待どおりにならない可能性があります

警告: 名前ベースのチェックポイント内の複数の変数を同じ名前にマップする必要がある場合、問題が発生する可能性があります。tf.name_scope とレイヤー コンストラクタまたは tf.Variable name 引数を使用して変数名を調整することで、レイヤーと変数の名前を明示的に調整し、重複がないことを確認できるかもしれません。

割り当てマップの維持

割り当てマップは、一般に TF1 モデル間で重みを転送するために使用され、モデルの移行中に変数名が変更された場合にも使用できます。

これらのマップを使用すると tf.compat.v1.train.init_from_checkpointtf.compat.v1.train.Saver、および tf.train.load_checkpoint を使用して、変数またはスコープ名が変更されている可能性があるモデルに重みを読み込めます。

このセクションの例では、以前に保存したチェックポイントを使用します。

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

init_from checkpoint で読み込む

tf1.train.init_from_checkpoint は、割り当て演算を作成する代わりに変数イニシャライザに値を配置するため、グラフ/セッション内で呼び出す必要があります。

assignment_map 引数を使用して、変数を読み込む方法を構成します。ドキュメントから以下を実行します。

割り当てマップは、次の構文をサポートしています。

  • 'checkpoint_scope_name/': 'scope_name/' - テンソル名が一致する checkpoint_scope_name から最新の scope_name 内のすべての変数を読み込みます。
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - checkpoint_scope_name/some_other_variable から scope_name/variable_name 変数を初期化します。
  • 'scope_variable_name': variable - 指定された tf.Variable オブジェクトをチェックポイントからのテンソル 'scope_variable_name' で初期化します。
  • 'scope_variable_name': list(variable) - チェックポイントからテンソル 'scope_variable_name' を使用して、分割された変数のリストを初期化します。
  • '/': 'scope_name/' - 最新の scope_name 内のすべての変数をチェックポイントのルートから読み込みます(例: スコープなし)。
# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

tf1.train.Saver で読み込む

init_from_checkpoint とは異なり、tf.compat.v1.train.Saver は Graph モードと Eager モードの両方で実行できます。var_list 引数はオプションでディクショナリを受け入れますが、変数名を tf.Variable オブジェクトにマップする必要があります。

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
Restored [a, b, c]:  [1.0, 2.0, 3.0]

tf.train.load_checkpoint で読み込む

このオプションは、変数値を正確に制御する必要がある場合に適しています。繰り返しますが、これは Graph モードと Eager モードの両方で機能します。

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier writing a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF2 チェックポイントオブジェクトの維持

移行中に変数名とスコープ名が大幅に変更される可能性がある場合は、tf.train.Checkpoint と TF2 チェックポイントを使用してください。TF2 は、変数名の代わりにオブジェクト構造を使用します(詳細については、TF1 から TF2 への変更を参照してください)。

つまり、チェックポイントを保存または復元する tf.train.Checkpoint を作成するときは、同じ順序(リストの場合)とキーを使用するようにしてください。(Checkpoint イニシャライザへのディクショナリとキーワード引数)。以下にチェックポイントの互換性の例を示します。

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

以下のコードサンプルは、「同じ」tf.train.Checkpoint を使用して異なる名前の変数を読み込む方法を示しています。まず、TF2 チェックポイントを保存します。

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)
[a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf2-ckpt-1':
  (key='unscoped/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='unscoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='scoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\n\x08\x01\x12\x06scoped\n\x0c\x08\x02\x12\x08unscoped\n\x10\x08\x03\x12\x0csave_counter*\x02\x08\x01\n\x0b\n\x05\x08\x04\x12\x010*\x02\x08\x01\n\x12\n\x05\x08\x05\x12\x010\n\x05\x08\x06\x12\x011*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a#scoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%unscoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%unscoped/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

変数やスコープ名が変更しても tf.train.Checkpoint を引き続き使用できます。

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Eager モード:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Estimator の TF2 チェックポイント

上記のセクションでは、モデルの移行中にチェックポイントの互換性を維持する方法について説明しました。これらの概念は、Estimator モデルにも適用されますが、チェックポイントの保存/読み込み方法は少し異なります。Estimator モデルを移行して TF2 API を使用する場合、モデルがまだ Estimator を使用している間に、TF1 チェックポイントから TF2 チェックポイントに切り替えたい場合があります。このセクションでは、その方法を示します。

tf.estimator.EstimatorMonitoredSession には、scaffold と呼ばれる保存メカニズムがあります。これは、tf.compat.v1.train.Scaffold オブジェクトです。Scaffold には、TF1 または TF2 スタイルのチェックポイントを保存するための EstimatorMonitoredSession が含まれていることがあります。

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf1/model.ckpt-1':
  (key='var', shape=[], dtype=float32, value=2.0)
  (key='global_step', shape=[], dtype=int64, value=1)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='est-tf1', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: est-tf1
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 1 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf2/model.ckpt-1':
  (key='var_list/var/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=4.0)
  (key='step/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n\x1c\n\x08\x08\x01\x12\x04step\n\x0c\x08\x02\x12\x08var_list*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x0bglobal_step\x1a\x1fstep/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\r\n\x07\x08\x03\x12\x03var*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x03var\x1a'var_list/var/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

v の最終的な値は、est-tf1 からウォームスタートし、さらに 5 ステップのトレーニングを行った後、16 になるはずです。トレーニングステップの値は、warm_start チェックポイントから引き継がれません。

Keras のチェックポイントを設定する

Keras で構築されたモデルは、引き続き tf1.train.Savertf.train.Checkpoint を使用して既存の重みを読み込みます。モデルの移行が完了したら、特にトレーニング時に ModelCheckpoint コールバックを使用している場合は、model.save_weightsmodel.load_weights を使用するように切り替えます。

チェックポイントと Keras について知っておくべきこと:

初期化と構築

Keras のモデルとレイヤーは、作成を完了する前に 2 つのステップが必要があります。1 つ目は、Python オブジェクトの 初期化: layer = tf.keras.layers.Dense(x) です。2 番目は 構築ステップ layer.build(input_shape) で、ほとんどの重みが実際に作成されます。モデルを呼び出すか、単一の traineval、または predict ステップを実行してモデルを構築することもできます(初回のみ)。

model.load_weights(path).assert_consumed() でエラーが発生している場合は、モデル/レイヤーが構築されていない可能性があります。

Keras は TF2 チェックポイントを使用する

tf.train.Checkpoint(model).writemodel.save_weights と同等です。また、tf.train.Checkpoint(model).readmodel.load_weights と同等です。Checkpoint(model) != Checkpoint(model=model) であることに注意してください。

TF2 チェックポイントは Keras の build() ステップで機能する

tf.train.Checkpoint.restore には、遅延復元と呼ばれるメカニズムがあります。これにより、変数がまだ作成されていない場合、tf.Module と Keras オブジェクトが変数値を格納できるようになり、初期化されたモデルが重みを読み込んでから構築できるようになります。

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

このメカニズムのため、Keras モデルで TF2 チェックポイント読み込み API を使用することを強くお勧めします(既存の TF1 チェックポイントをモデルマッピング shim に復元する場合でも)。詳しくはチェックポイントガイドを参照してください。

コード スニペット

以下のスニペットは、チェックポイント保存 API における TF1/TF2 バージョンの互換性を示しています。

TF1 チェックポイントを TF2 に保存する

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

TF1 チェックポイントを TF2 に 読み込む

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt-saved-in-eager
Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF1 に TF2 チェックポイントを保存する

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

TF1 に TF2 チェックポイントを読み込む

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

チェックポイント変換

チェックポイントを読み込んで再保存することにより、TF1 と TF2 の間でチェックポイントを変換できます。また、代替手段として、tf.train.load_checkpoint を使用できます。以下のコードに示します。

TF1 チェックポイントを TF2 に変換する

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})  
  ckpt.restore(converted_ckpt_path)

    ```

    Args:
      checkpoint_path: Path to the TF1 checkpoint.
      output_prefix: Path prefix to the converted checkpoint.

    Returns:
      Path to the converted checkpoint.
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      vars[key] = tf.Variable(reader.get_tensor(key))
    return tf.train.Checkpoint(vars=vars).save(output_prefix)
  ```

スニペット `Save a TF1 checkpoint in TF2` に保存されているチェックポイントを変換します。


```python
# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.
print_checkpoint('tf1-ckpt-saved-in-eager')
converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 
                                     'converted-tf1-to-tf2')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})
ckpt.restore(converted_path).assert_consumed()
print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

[Converted]
Checkpoint at 'converted-tf1-to-tf2-1':
  (key='vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='vars/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='vars/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n \n\x08\x08\x01\x12\x04vars\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x0c\x08\x03\x12\x08scoped/c\n\x05\x08\x04\x12\x01b\n\x05\x08\x05\x12\x01a*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nK\x12E\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a)vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF2 チェックポイントを TF1 に変換する

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must be saved using a 
  `tf.train.Checkpoint(var_list={name: variable})`

  To load the converted checkpoint with `tf.compat.v1.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable}) 

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables]) 
  saver.restore(sess, output_path)

    ```
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      # Get the "name" from the 
      if key.startswith('var_list/'):
        var_name = key.split('/')[1]
        # TF2 checkpoint keys use '/', so if they appear in the user-defined name,
        # they are escaped to '.S'.
        var_name = var_name.replace('.S', '/')
        vars[var_name] = tf.Variable(reader.get_tensor(key))

    return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
  ```

スニペット `Save a TF2 checkpoint in TF1` に保存されているチェックポイントを変換します。


```python
# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.
print_checkpoint('tf2-ckpt-saved-in-session-1')
converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',
                                    'converted-tf2-to-tf1')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    saver = tf1.train.Saver([a, b, c])
    saver.restore(sess, converted_path)
    print("\nRestored [a, b, c]: ", sess.run([a, b, c]))
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.

[Converted]
Checkpoint at 'converted-tf2-to-tf1':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
INFO:tensorflow:Restoring parameters from converted-tf2-to-tf1

Restored [a, b, c]:  [1.0, 2.0, 3.0]

関連ガイド