![]() | ![]() | ![]() | ![]() |
「TensorFlowモデルを保存する」というフレーズは、通常、次の2つのいずれかを意味します。
- チェックポイント、または
- SavedModel。
チェックポイントは、モデルで使用されるすべてのパラメーター( tf.Variable
オブジェクト)の正確な値をキャプチャします。チェックポイントには、モデルによって定義された計算の説明が含まれていないため、通常、保存されたパラメーター値を使用するソースコードが利用可能な場合にのみ役立ちます。
一方、SavedModel形式には、パラメーター値(チェックポイント)に加えて、モデルによって定義された計算のシリアル化された記述が含まれます。この形式のモデルは、モデルを作成したソースコードから独立しています。したがって、TensorFlow Serving、TensorFlow Lite、TensorFlow.js、または他のプログラミング言語(C、C ++、Java、Go、Rust、C#などのTensorFlow API)のプログラムを介したデプロイに適しています。
このガイドでは、チェックポイントを読み書きするためのAPIについて説明します。
セットアップ
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
tf.keras
トレーニングAPIからの保存
保存と復元については、 tf.keras
ガイドを参照してください。
tf.keras.Model.save_weights
はTensorFlowチェックポイントを保存します。
net.save_weights('easy_checkpoint')
チェックポイントの作成
TensorFlowモデルの永続的な状態は、 tf.Variable
オブジェクトに保存されます。これらは直接構築できますが、多くの場合、tf.keras.layers
やtf.keras.Model
などの高レベルAPIを介して作成されます。
変数を管理する最も簡単な方法は、変数をPythonオブジェクトにアタッチしてから、それらのオブジェクトを参照することです。
tf.train.Checkpoint
、 tf.keras.layers.Layer
、およびtf.keras.Model
サブクラスは、それらの属性に割り当てられた変数を自動的に追跡します。次の例では、単純な線形モデルを作成し、モデルのすべての変数の値を含むチェックポイントを書き込みます。
Model.save_weights
、モデルチェックポイントを簡単に保存できます。
手動チェックポイント
セットアップ
tf.train.Checkpoint
すべての機能をtf.train.Checkpoint
ために、おもちゃのデータセットと最適化の手順を定義します。
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
チェックポイントオブジェクトを作成します
tf.train.Checkpoint
オブジェクトを使用して、チェックポイントを手動で作成します。チェックポイントを設定するオブジェクトは、オブジェクトの属性として設定されます。
tf.train.CheckpointManager
は、複数のチェックポイントの管理にも役立ちます。
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
モデルのトレーニングとチェックポイント
次のトレーニングループは、モデルとオプティマイザーのインスタンスを作成し、それらをtf.train.Checkpoint
オブジェクトに収集します。データの各バッチでループ内のトレーニングステップを呼び出し、定期的にチェックポイントをディスクに書き込みます。
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 30.42 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 23.83 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 17.27 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 10.81 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 4.74
トレーニングを復元して続行する
最初のトレーニングサイクルの後、新しいモデルとマネージャーを渡すことができますが、中断したところからトレーニングを開始します。
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 0.85 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 0.87 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.71 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.46 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.21
tf.train.CheckpointManager
オブジェクトは、古いチェックポイントを削除します。上記では、最新の3つのチェックポイントのみを保持するように構成されています。
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
これらのパス(例: './tf_ckpts/ckpt-10'
)は、ディスク上のファイルではありません。代わりに、 index
ファイルと変数値を含む1つ以上のデータファイルのプレフィックスです。これらのプレフィックスは、 CheckpointManager
がその状態を保存する単一のcheckpoint
ファイル( './tf_ckpts/checkpoint'
)にグループ化されます。
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
力学の読み込み
TensorFlowは、ロードされているオブジェクトから開始して、名前付きエッジを持つ有向グラフをトラバースすることにより、変数をチェックポイント値に一致させます。エッジ名は通常、オブジェクトの属性名からself.l1 = tf.keras.layers.Dense(5)
れます。たとえば、 self.l1 = tf.keras.layers.Dense(5)
の"l1"
です。 tf.train.Checkpoint
は、 tf.train.Checkpoint(step=...)
の"step"
ように、キーワード引数名を使用しtf.train.Checkpoint(step=...)
。
上記の例の依存関係グラフは次のようになります。
オプティマイザーは赤、通常の変数は青、オプティマイザーのスロット変数はオレンジです。他のノード(たとえば、 tf.train.Checkpoint
表す)は黒で表示されます。
スロット変数はオプティマイザーの状態の一部ですが、特定の変数用に作成されます。たとえば、上記の'm'
エッジは、Adamオプティマイザが各変数に対して追跡する運動量に対応します。スロット変数は、変数とオプティマイザーの両方が保存される場合にのみチェックポイントに保存されます。つまり、破線のエッジです。
tf.train.Checkpoint
オブジェクトでrestore
を呼び出すと、要求された復元がキューに入れられ、 Checkpoint
オブジェクトから一致するパスがtf.train.Checkpoint
とすぐに変数値が復元されます。たとえば、ネットワークとレイヤーを介してモデルへの1つのパスを再構築することにより、上記で定義したモデルからバイアスのみをロードできます。
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.831489 3.7156947 2.5892444 3.8669944 4.749503 ]
これらの新しいオブジェクトの依存関係グラフは、上記で作成した大きなチェックポイントのはるかに小さなサブグラフです。これには、 tf.train.Checkpoint
がチェックポイントに番号をtf.train.Checkpoint
使用するバイアスと保存カウンターのみが含まれます。
restore
は、オプションのアサーションを持つステータスオブジェクトを返します。新しいCheckpoint
作成されたすべてのオブジェクトが復元されたため、 status.assert_existing_objects_matched
がパスします。
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>
チェックポイントには、レイヤーのカーネルやオプティマイザーの変数など、一致していないオブジェクトが多数あります。 status.assert_consumed
は、チェックポイントとプログラムがstatus.assert_consumed
一致する場合にのみ渡され、ここで例外をスローします。
復元の遅延
TensorFlowのLayer
オブジェクトは、入力シェイプが利用可能な場合、最初の呼び出しまで変数の作成を遅らせる可能性があります。たとえば、 Dense
レイヤーのカーネルの形状は、レイヤーの入力形状と出力形状の両方に依存するため、コンストラクター引数として必要な出力形状は、変数を単独で作成するのに十分な情報ではありません。 Layer
呼び出すと変数の値も読み取られるため、変数の作成から最初の使用までの間に復元を行う必要があります。
このイディオムをサポートするために、 tf.train.Checkpoint
は、一致する変数がまだない復元をキューにtf.train.Checkpoint
ます。
delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5719748 4.6099544 4.931875 4.836442 4.8496275]]
チェックポイントを手動で検査する
tf.train.load_checkpoint
は、 CheckpointReader
ポイントの内容への低レベルのアクセスを提供するCheckpointReader
を返します。これには、各変数のキーから、チェックポイント内の各変数の形状とdtypeへのマッピングが含まれています。上に表示されたグラフのように、変数のキーはそのオブジェクトパスです。
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
したがって、 net.l1.kernel
の値に関心がある場合は、次のコードで値を取得できます。
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
また、変数の値を検査できるget_tensor
メソッドも提供します。
reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]], dtype=float32)
リストと辞書の追跡
self.l1 = tf.keras.layers.Dense(5)
ような直接の属性割り当てと同様に、リストと辞書を属性に割り当てると、その内容が追跡されます。
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
リストや辞書のラッパーオブジェクトに気付くかもしれません。これらのラッパーは、基になるデータ構造のチェックポイント可能なバージョンです。属性ベースのロードと同様に、これらのラッパーは、変数がコンテナーに追加されるとすぐに変数の値を復元します。
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
同じ追跡がtf.keras.Model
サブクラスに自動的に適用され、たとえばレイヤーのリストを追跡するために使用できます。
概要
TensorFlowオブジェクトは、使用する変数の値を保存および復元するための簡単な自動メカニズムを提供します。