Kerasモデルの保存と読み込み

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

はじめに

Keras モデルは以下の複数のコンポーネントで構成されています。

  • アーキテクチャー/構成(モデルに含まれるレイヤーとそれらの接続方法を指定する)
  • 重み値のセット(「モデルの状態」)
  • オプティマイザ(モデルのコンパイルで定義する)
  • 損失とメトリックのセット(モデルのコンパイルで定義するか、add_loss()またはadd_metric()を呼び出して定義する)

Keras API を使用すると、これらを一度にディスクに保存したり、一部のみを選択して保存できます。

  • すべてを TensorFlow SavedModel 形式(または古い Keras H5 形式)で1つのアーカイブに保存。これは標準的な方法です。
  • アーキテクチャ/構成のみを(通常、JSON ファイルとして)保存。
  • 重み値のみを保存。(通常、モデルのトレーニング時に使用)。

では、次にこれらのオプションの用途と機能をそれぞれ見ていきましょう。

保存と読み込みに関する簡単な説明

このガイドを読む時間が 10 秒しかない場合は、次のことを知っておく必要があります。

Keras モデルの保存

model = ...  # Get model (Sequential, Functional Model, or Model subclass) model.save('path/to/location')

モデルの再読み込み

from tensorflow import keras model = keras.models.load_model('path/to/location')

では、詳細を見てみましょう。

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras
2022-08-09 05:28:42.265258: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-09 05:28:42.817786: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 05:28:42.818021: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 05:28:42.818033: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

モデル全体の保存と読み込み

モデル全体を1つのアーティファクトとして保存できます。その場合、以下が含まれます。

  • モデルのアーキテクチャ/構成
  • モデルの重み値(トレーニング時に学習される)
  • モデルのコンパイル情報(compile()が呼び出された場合)
  • オプティマイザとその状態(存在する場合)。これは、中断した所からトレーニングを再開するために使用します。

API

モデル全体をディスクに保存するには {nbsp}TensorFlow SavedModel 形式古い Keras H5 形式の 2 つの形式を使用できます。推奨される形式は SavedModel です。これは、model.save()を使用する場合のデフォルトです。

次の方法で H5 形式に切り替えることができます。

  • save_format='h5'save()に渡す。
  • .h5または.kerasで終わるファイル名をsave()に渡す。

SavedModel 形式

例:

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 2ms/step - loss: 0.2990
INFO:tensorflow:Assets written to: my_model/assets
4/4 [==============================] - 0s 2ms/step
4/4 [==============================] - 0s 2ms/step
4/4 [==============================] - 0s 2ms/step - loss: 0.2952
<keras.callbacks.History at 0x7f1a9c2efe80>

SavedModel に含まれるもの

model.save('my_model')を呼び出すと、以下を含むmy_modelという名前のフォルダが作成されます。

ls my_model
assets  keras_metadata.pb  saved_model.pb  variables

モデルアーキテクチャとトレーニング構成(オプティマイザ、損失、メトリックを含む)は、saved_model.pbに格納されます。重みはvariables/ディレクトリに保存されます。

SavedModel 形式についての詳細は「SavedModel ガイド(ディスク上の SavedModel 形式」)をご覧ください。

SavedModel によるカスタムオブジェクトの処理

モデルとそのレイヤーを保存する場合、SavedModel 形式はクラス名、呼び出し関数、損失、および重み(および実装されている場合は構成)を保存します。呼び出し関数は、モデル/レイヤーの計算グラフを定義します。

モデル/レイヤーの構成がない場合、トレーニング、評価、および推論に使用できる元のモデルのようなモデルを作成するために呼び出し関数が使用されます。

しかしながら、カスタムモデルまたはレイヤークラスを作成する場合は、常にget_configおよびfrom_configメソッドを使用して定義することをお勧めします。これにより、必要に応じて後で計算を簡単に更新できます。詳細については「カスタムオブジェクト」をご覧ください。

以下は、**config メソッドを上書きせずに **SavedModel 形式からカスタムレイヤーを読み込んだ場合の例です。

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x


model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel

loaded = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded(input_arr), outputs)

print("Original model:", model)
print("Loaded model:", loaded)
INFO:tensorflow:Assets written to: my_model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Original model: <__main__.CustomModel object at 0x7f1ac4ab47f0>
Loaded model: <keras.saving.saved_model.load.CustomModel object at 0x7f1bada9aa00>

上記の例のように、ローダーは、元のモデルのように機能する新しいモデルクラスを動的に作成します。

Keras H5 形式

Keras は、モデルのアーキテクチャ、重み値、およびcompile()情報を含む1つの HDF5 ファイルの保存もサポートしています。これは、SavedModel に代わる軽量な形式です。

例:

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 2ms/step - loss: 0.6498
4/4 [==============================] - 0s 1ms/step
4/4 [==============================] - 0s 2ms/step
4/4 [==============================] - 0s 2ms/step - loss: 0.5706
<keras.callbacks.History at 0x7f1bac938ca0>

制限事項

SavedModel 形式と比較して、H5 ファイルに含まれないものが 2 つあります。

  • SavedModel とは異なり、model.add_loss()およびmodel.add_metric()を介して追加された外部損失およびメトリックは保存されません。モデルにそのような損失とメトリックがあり、トレーニングを再開する場合は、モデルを読み込んだ後、これらの損失を自分で追加する必要があります。これは、self.add_loss()およびself.add_metric()を介して内部レイヤーで作成された損失/メトリックには適用されないことに注意してください。レイヤーが読み込まれる限り、これらの損失とメトリックはレイヤーのcallメソッドの一部であるため保持されます。
  • カスタムレイヤーなどのカスタムオブジェクトの計算グラフは、保存されたファイルに含まれません。読み込む際に、Keras はモデルを再構築するためにこれらのオブジェクトの Python クラス/関数にアクセスする必要があります。詳細については、「カスタムオブジェクト」をご覧ください。

アーキテクチャの保存

モデルの構成(アーキテクチャ)は、モデルに含まれるレイヤー、およびこれらのレイヤーの接続方法を指定します*。モデルの構成がある場合、コンパイル情報なしで、重みが新しく初期化された状態でモデルを作成することができます。

*これは、サブクラス化されたモデルではなく、Functional または Sequential API を使用して定義されたモデルにのみ適用されることに注意してください。

Sequential モデルまたは Functional API モデルの構成

これらのタイプのモデルは、レイヤーの明示的なグラフです。それらの構成は常に構造化された形式で提供されます。

API

get_config()およびfrom_config()

config = model.get_config()を呼び出すと、モデルの構成を含むPython dictが返されます。その後、同じモデルをSequential.from_config(config)(
Sequentialモデルの場合)またはModel.from_config(config)(Functional API モデルの場合) で再度構築できます。

同じワークフローは、シリアル化可能なレイヤーでも使用できます。

レイヤーの例:

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
new_layer = keras.layers.Dense.from_config(layer_config)

Sequential モデルの例:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

Functional モデルの例:

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

to_json()およびtf.keras.models.model_from_json()

これは、get_config / from_configと似ていますが、モデルを JSON 文字列に変換します。この文字列は、元のモデルクラスなしで読み込めます。また、これはモデル固有であり、レイヤー向けではありません。

例:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

カスタムオブジェクト

モデルとレイヤー

サブクラス化されたモデルとレイヤーのアーキテクチャは、メソッド__init__およびcallで定義されています。それらは Python バイトコードと見なされ、JSON と互換性のある構成にシリアル化できません。pickleなどを使用してバイトコードのシリアル化を試すことができますが、これは安全ではなく、モデルを別のシステムに読み込むことはできません。

カスタム定義されたレイヤーのあるモデル、またはサブクラス化されたモデルを保存/読み込むには、get_configおよびfrom_config(オプション) メソッドを上書きする必要があります。さらに、Keras が認識できるように、カスタムオブジェクトの登録を使用する必要があります。

カスタム関数

カスタム定義関数 (アクティブ化の損失や初期化など) には、get_configメソッドは必要ありません。カスタムオブジェクトとして登録されている限り、関数名は読み込みに十分です。

TensorFlow グラフのみの読み込み

Keras により生成された TensorFlow グラフを以下のように読み込むことができます。 その場合、custom_objectsを提供する必要はありません。

model.save("my_model")
tensorflow_graph = tf.saved_model.load("my_model")
x = np.random.uniform(size=(4, 32)).astype(np.float32)
predicted = tensorflow_graph(x).numpy()
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: my_model/assets

この方法にはいくつかの欠点があることに注意してください。

  • 再作成できないモデルをプロダクションにロールアウトしないように、履歴を追跡するために使用されたカスタムオブジェクトに常にアクセスできる必要があります。
  • tf.saved_model.loadにより返されるオブジェクトは、Keras モデルではないので、簡単には使えません。たとえば、.predict().fit()へのアクセスはありません。

この方法は推奨されていませんが、カスタムオブジェクトのコードを紛失した場合やtf.keras.models.load_model()でモデルを読み込む際に問題が発生した場合などに役に立ちます。

詳細は、tf.saved_model.loadに関するページをご覧ください。

構成メソッドの定義

仕様:

  • get_configは、Keras のアーキテクチャおよびモデルを保存する API と互換性があるように、JSON シリアル化可能なディクショナリを返す必要があります。
  • from_config(config) (classmethod) は、構成から作成された新しいレイヤーまたはモデルオブジェクトを返します。デフォルトの実装は cls(**config)を返します。

例:

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")

    def call(self, inputs, training=False):
        if training:
            return inputs * self.var
        else:
            return inputs

    def get_config(self):
        return {"a": self.var.numpy()}

    # There's actually no need to define `from_config` here, since returning
    # `cls(**config)` is the default behavior.
    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = CustomLayer(5)
layer.var.assign(2)

serialized_layer = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
    serialized_layer, custom_objects={"CustomLayer": CustomLayer}
)

カスタムオブジェクトの登録

Keras は構成を生成したクラスについての情報を保持します。上記の例では、tf.keras.layers.serializeはシリアル化された形態のカスタムレイヤーを生成します。

{'class_name': 'CustomLayer', 'config': {'a': 2} }

Keras は、すべての組み込みのレイヤー、モデル、オプティマイザ、およびメトリッククラスのマスターリストを保持し、from_configを呼び出すための正しいクラスを見つけるために使用されます。クラスが見つからない場合は、エラー(Value Error: Unknown layer)が発生します。このリストにカスタムクラスを登録する方法は、いくつかあります。

  1. 読み込み関数でcustom_objects引数を設定する。(上記の「config メソッドの定義」セクションの例をご覧ください)
  2. tf.keras.utils.custom_object_scopeまたはtf.keras.utils.CustomObjectScope
  3. tf.keras.utils.register_keras_serializable

カスタムレイヤーと関数の例

class CustomLayer(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(CustomLayer, self).get_config()
        config.update({"units": self.units})
        return config


def custom_activation(x):
    return tf.nn.tanh(x) ** 2


# Make a model with the CustomLayer and custom_activation
inputs = keras.Input((32,))
x = CustomLayer(32)(inputs)
outputs = keras.layers.Activation(custom_activation)(x)
model = keras.Model(inputs, outputs)

# Retrieve the config
config = model.get_config()

# At loading time, register the custom objects with a `custom_object_scope`:
custom_objects = {"CustomLayer": CustomLayer, "custom_activation": custom_activation}
with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.Model.from_config(config)

メモリ内でモデルのクローンを作成する

また、tf.keras.models.clone_model()を通じて、メモリ内でモデルのクローンを作成できます。これは、構成を取得し、その構成からモデルを再作成する方法と同じです (したがって、コンパイル情報やレイヤーの重み値は保持されません)。

例:

with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.models.clone_model(model)

モデルの重み値のみを保存および読み込む

モデルの重みのみを保存および読み込むように選択できます。これは次の場合に役立ちます。

  • 推論のためのモデルだけが必要とされる場合。この場合、トレーニングを再開する必要がないため、コンパイル情報やオプティマイザの状態は必要ありません。
  • 転移学習を行う場合。以前のモデルの状態を再利用して新しいモデルをトレーニングするため、以前のモデルのコンパイル情報は必要ありません。

インメモリの重みの移動のための API

異なるオブジェクト間で重みをコピーするにはget_weightsおよびset_weightsを使用します。

以下に例を示します。

インメモリで、1 つのレイヤーから別のレイヤーに重みを転送する

def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 2 to layer 1
layer_2.set_weights(layer_1.get_weights())

インメモリで 1 つのモデルから互換性のあるアーキテクチャを備えた別のモデルに重みを転送する

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super(SubclassedModel, self).__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(tf.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

ステートレスレイヤーの場合

ステートレスレイヤーは重みの順序や数を変更しないため、ステートレスレイヤーが余分にある場合や不足している場合でも、モデルのアーキテクチャは互換性があります。

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

重みをディスクに保存して再度読み込むための API

以下の形式でmodel.save_weightsを呼び出すことにより、重みをディスクに保存できます。

  • TensorFlow Checkpoint
  • HDF5

model.save_weightsのデフォルトの形式は TensorFlow Checkpoint です。保存形式を指定する方法は 2 つあります。

  1. save_format引数:値をsave_format = "tf"またはsave_format = "h5"に設定する。
  2. path引数:パスが.h5または.hdf5で終わる場合、HDF5 形式が使用されます。save_formatが設定されていない限り、他のサフィックスでは、TensorFlow Checkpoint になります。

また、オプションとしてインメモリの numpy 配列として重みを取得することもできます。各 API には、以下の長所と短所があります。

TF Checkpoint 形式

例:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("ckpt")
load_status = sequential_model.load_weights("ckpt")

# `assert_consumed` can be used as validation that all variable values have been
# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other
# methods in the Status object.
load_status.assert_consumed()
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f1bad8eb8b0>

形式の詳細

TensorFlow Checkpoint 形式は、オブジェクト属性名を使用して重みを保存および復元します。 たとえば、tf.keras.layers.Denseレイヤーを見てみましょう。このレイヤーには、2 つの重み、dense.kerneldense.biasがあります。レイヤーがtf形式で保存されると、結果のチェックポイントには、キー「kernel」「bias」およびそれらに対応する重み値が含まれます。 詳細につきましては、TF Checkpoint ガイドの「読み込みの仕組み」をご覧ください。

属性/グラフのエッジは、変数名ではなく、親オブジェクトで使用される名前で命名されていることに注意してください。以下の例のCustomLayerでは、変数CustomLayer.varは、"var_a"ではなく、"var"をキーの一部として保存されます。

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")


layer = CustomLayer(5)
layer_ckpt = tf.train.Checkpoint(layer=layer).save("custom_layer")

ckpt_reader = tf.train.load_checkpoint(layer_ckpt)

ckpt_reader.get_variable_to_dtype_map()
{'save_counter/.ATTRIBUTES/VARIABLE_VALUE': tf.int64,
 'layer/var/.ATTRIBUTES/VARIABLE_VALUE': tf.int32,
 '_CHECKPOINTABLE_OBJECT_GRAPH': tf.string}

転移学習の例

基本的に、2 つのモデルが同じアーキテクチャを持っている限り、同じチェックポイントを共有できます。

例:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Extract a portion of the functional model defined in the Setup section.
# The following lines produce a new model that excludes the final output
# layer of the functional model.
pretrained = keras.Model(
    functional_model.inputs, functional_model.layers[-1].input, name="pretrained_model"
)
# Randomly assign "trained" weights.
for w in pretrained.weights:
    w.assign(tf.random.normal(w.shape))
pretrained.save_weights("pretrained_ckpt")
pretrained.summary()

# Assume this is a separate program where only 'pretrained_ckpt' exists.
# Create a new functional model with a different output dimension.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(5, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="new_model")

# Load the weights from pretrained_ckpt into model.
model.load_weights("pretrained_ckpt")

# Check that all of the pretrained weights have been loaded.
for a, b in zip(pretrained.weights, model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

print("\n", "-" * 50)
model.summary()

# Example 2: Sequential model
# Recreate the pretrained model, and load the saved weights.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
pretrained_model = keras.Model(inputs=inputs, outputs=x, name="pretrained")

# Sequential example:
model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name="predictions")])
model.summary()

pretrained_model.load_weights("pretrained_ckpt")

# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,
# but will *not* work as expected. If you inspect the weights, you'll see that
# none of the weights will have loaded. `pretrained_model.load_weights()` is the
# correct method to call.
Model: "pretrained_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_1 (Dense)             (None, 64)                50240     
                                                                 
 dense_2 (Dense)             (None, 64)                4160      
                                                                 
=================================================================
Total params: 54,400
Trainable params: 54,400
Non-trainable params: 0
_________________________________________________________________

 --------------------------------------------------
Model: "new_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_1 (Dense)             (None, 64)                50240     
                                                                 
 dense_2 (Dense)             (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 5)                 325       
                                                                 
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 pretrained (Functional)     (None, 64)                54400     
                                                                 
 predictions (Dense)         (None, 5)                 325       
                                                                 
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f1bad8c4340>

通常、モデルの作成には同じ API を使用することをお勧めします。Sequential と Functional、またはFunctional とサブクラス化などの間で切り替える場合は、常に事前トレーニング済みモデルを再構築し、事前トレーニング済みの重みをそのモデルに読み込みます。

モデルのアーキテクチャがまったく異なる場合は、どうすれば重みを保存して異なるモデルに読み込むことができるのでしょうか?tf.train.Checkpointを使用すると、正確なレイヤー/変数を保存および復元することができます。

例:

# Create a subclassed model that essentially uses functional_model's first
# and last layers.
# First, save the weights of functional_model's first and last dense layers.
first_dense = functional_model.layers[1]
last_dense = functional_model.layers[-1]
ckpt_path = tf.train.Checkpoint(
    dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias
).save("ckpt")

# Define the subclassed model.
class ContrivedModel(keras.Model):
    def __init__(self):
        super(ContrivedModel, self).__init__()
        self.first_dense = keras.layers.Dense(64)
        self.kernel = self.add_variable("kernel", shape=(64, 10))
        self.bias = self.add_variable("bias", shape=(10,))

    def call(self, inputs):
        x = self.first_dense(inputs)
        return tf.matmul(x, self.kernel) + self.bias


model = ContrivedModel()
# Call model on inputs to create the variables of the dense layer.
_ = model(tf.ones((1, 784)))

# Create a Checkpoint with the same structure as before, and load the weights.
tf.train.Checkpoint(
    dense=model.first_dense, kernel=model.kernel, bias=model.bias
).restore(ckpt_path).assert_consumed()
/tmpfs/tmp/ipykernel_42242/1562824211.py:15: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  self.kernel = self.add_variable("kernel", shape=(64, 10))
/tmpfs/tmp/ipykernel_42242/1562824211.py:16: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  self.bias = self.add_variable("bias", shape=(10,))
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f1bad8cdc70>

HDF5 形式

HDF5 形式には、レイヤー名でグループ化された重みが含まれています。重みは、トレーニング可能な重みのリストをトレーニング不可能な重みのリストに連結することによって並べられたリストです(layer.weightsと同じ)。 したがって、チェックポイントに保存されているものと同じレイヤーとトレーニング可能な状態がある場合、モデルは HDF 5 チェックポイントを使用できます。

例:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("weights.h5")
sequential_model.load_weights("weights.h5")

ネストされたレイヤーがモデルに含まれている場合、layer.trainableを変更すると、layer.weightsの順序が異なる場合があることに注意してください。

class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super(NestedDenseLayer, self).__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)
variables: ['nested/dense_1/kernel:0', 'nested/dense_1/bias:0', 'nested/dense_2/kernel:0', 'nested/dense_2/bias:0']

Changing trainable status of one of the nested layers...

variables: ['nested/dense_2/kernel:0', 'nested/dense_2/bias:0', 'nested/dense_1/kernel:0', 'nested/dense_1/bias:0']
variable ordering changed: True

転移学習の例

HDF5 から事前トレーニングされた重みを読み込む場合は、元のチェックポイントモデルに重みを読み込んでから、目的の重み/レイヤーを新しいモデルに抽出することをお勧めします。

例:

def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained_weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained_weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_1 (Dense)             (None, 64)                50240     
                                                                 
 dense_2 (Dense)             (None, 64)                4160      
                                                                 
 dense_3 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________