分散ストラテジーを使ってモデルを保存して読み込む

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

トレーニング中にモデルを保存して読み込むことは一般的な作業です。Keras モデルの保存と読み込みに使用する API には、高レベル API と低レベル API の 2 つがあります。このチュートリアルでは、tf.distribute.Strategy を使用してる場合に SavedModel API を使用する方法を実演しています。SavedModel とシリアル化に関する一般的な情報は、SavedModel ガイドKeras モデルのシリアル化ガイドをご覧ください。では、簡単な例から始めましょう。

依存関係をインポートします。

import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()
2022-08-08 21:34:59.203095: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-08 21:34:59.963442: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:34:59.963707: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:34:59.963721: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

tf.distribute.Strategy を使ってデータとモデルを準備します。

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

モデルをトレーニングします。

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-08-08 21:35:06.800058: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
235/235 [==============================] - 10s 8ms/step - loss: 0.3460 - accuracy: 0.9015
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.1058 - accuracy: 0.9693
<keras.callbacks.History at 0x7fb62a68bdf0>

モデルを保存して読み込む

作業に使用する単純なモデルの準備ができたので、保存と読み込みの API を見てみましょう。利用できる API には次の 2 種類があります。

Keras API

次に、Keras API を使ってモデルを保存して読み込む例を示します。

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

tf.distribute.Strategy を使用せずにモデルを復元します。

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 1s 4ms/step - loss: 0.0733 - accuracy: 0.9788
Epoch 2/2
235/235 [==============================] - 1s 4ms/step - loss: 0.0535 - accuracy: 0.9841
<keras.callbacks.History at 0x7fb628401d30>

モデルを復元したら、compile() をもう一度呼び出すことなくそのトレーニングを続行できます。保存前にコンパイル済みであるからです。モデルは、TensorFlow の標準的な SavedModel プロと形式で保存されています。その他の詳細は、saved_model 形式ガイドをご覧ください。

tf.distribute.Strategy を使用して、モデルを読み込んでトレーニングします。

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-08-08 21:35:23.095706: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-08-08 21:35:23.161033: W tensorflow/core/framework/dataset.cc:769] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
235/235 [==============================] - 3s 12ms/step - loss: 0.0733 - accuracy: 0.9791
Epoch 2/2
235/235 [==============================] - 3s 12ms/step - loss: 0.0543 - accuracy: 0.9844

ご覧の通り、tf.distribute.Strategy を使って期待通りに読み込まれました。ここで使用されるストラテジーは、保存前に使用したストラテジーと同じものである必要はありません。

tf.saved_model API

では、低レベル API を見てみましょう。モデルの保存は Keras API に類似しています。

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

読み込みは tf.saved_model.load() で行えますが、より低いレベルにある API(したがって広範なユースケースのある API)であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。次に例を示します。

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

読み込まれたオブジェクトには複数の関数が含まれ、それぞれにキーが関連付けられている可能性があります。"serving_default" は、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論を実行するには、次のようにします。

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[-0.07606123, -0.1056785 ,  0.25221872, ...,  0.03360194,
         0.19555938,  0.13527356],
       [ 0.03998773, -0.12563412,  0.0774684 , ...,  0.1610758 ,
         0.20585346,  0.34802556],
       [-0.05313269, -0.16544825,  0.08238681, ...,  0.13444236,
         0.06308331,  0.11146139],
       ...,
       [-0.00776181, -0.05475638,  0.10547257, ...,  0.3012747 ,
         0.09161591,  0.20536867],
       [ 0.04294423, -0.1845659 ,  0.12285194, ...,  0.300538  ,
         0.02631462,  0.36374742],
       [ 0.02375309, -0.17935638,  0.1645618 , ...,  0.2643962 ,
         0.08826756,  0.1868835 ]], dtype=float32)>}
2022-08-08 21:35:30.404755: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

また、分散方法で読み込んで推論を実行することもできます。

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2022-08-08 21:35:30.620059: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

復元された関数の呼び出しは、保存されたモデル(predict)に対するフォワードパスです。読み込まれた関数をトレーニングし続ける場合はどうでしょうか。または読み込まれた関数をより大きなモデルに埋め込むには?一般的には、この読み込まれたオブジェクトを Keras レイヤーにラップして行うことができます。幸いにも、TF Hub には、次に示すとおり、この目的に使用できる hub.KerasLayer が用意されています。

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
2022-08-08 21:35:32.571835: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
235/235 [==============================] - 6s 7ms/step - loss: 0.3170 - accuracy: 0.9140
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0968 - accuracy: 0.9725

ご覧の通り、hub.KerasLayertf.saved_model.load() から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。学習を転送する場合に非常に便利な手法です。

どの API を使用すべきですか?

保存の場合は、Keras モデルを使用しているのであれば、Keras の model.save() API をほぼ必ず使用することが推奨されます。保存しているものが Keras モデルでなければ、低レベル API しか使用できません。

読み込みの場合は、読み込み API から得ようとしているものによって選択肢がきまs理ます。Keras モデルを使用できない(または使用を希望しない)のであれば、tf.saved_model.load() を使用し、そうでなければ、tf.keras.models.load_model() を使用します。Keras モデルを保存した場合にのみ Keras モデルを読み込めることに注意してください。

API を混在させることも可能です。model.save で Keras モデルを保存し、低レベルの tf.saved_model.load API を使用して、非 Keras モデルを読み込むことができます。

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

ローカルデバイスで保存または読み込む

クラウド TPU を使用するなど、リモートで実行中にローカルの IO デバイスに保存したり、そこから読み込んだりする場合、experimental_io_device オプションを使用して、IO デバイスを localhost に設定する必要があります。

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

警告

特殊なケースは、入力が十分に定義されていない Keras モデルがある場合です。たとえば、Seeuqntial モデルは、入力形状(Sequential([Dense(3), ...])を使用せずに作成できます。Subclassed モデルにも、初期化後は十分に定義された入力がありません。この場合、保存と読み込みの両方に低レベル API を使用する必要があります。そうしない場合はエラーが発生します。

モデルの入力が十分に定義されたものであるかを確認するには、model.inputsNone であるかどうかを確認します。None でなければ問題ありません。入力形状は、モデルが .fit.evaluate.predict で使用されている場合、またはモデルを呼び出す場合(model(inputs))に自動的に定義されます。

次に例を示します。

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fb52802bcd0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fb52802bcd0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7fb52806a880>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7fb52806a880>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets