モデルの保存と復元

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

モデルの進行状況は、トレーニング中およびトレーニング後に保存できます。モデルが中断したところから再開できるので、長いトレーニング時間を回避できます。また、保存することによりモデルを共有したり、他の人による作業の再現が可能になります。研究モデルや手法を公開する場合、ほとんどの機械学習の実践者は次を共有します。

  • モデルを構築するプログラム
  • モデルのトレーニング済みモデルの重みやパラメータ

このデータを共有することで、他の人がモデルがどの様に動作するかを理解したり、新しいデータに試してみたりすることが容易になります。

注意: TensorFlow モデルはコードであり、信頼できないコードに注意する必要があります。詳細については、TensorFlow を安全に使用するをご覧ください。

オプション

TensorFlow モデルを保存するには、使用している API に応じて様々な方法があります。このガイドでは、TensorFlow でモデルのビルドとトレーニングを行う tf.keras という高レベル API を使用しています。このチュートリアルで使用されている新しい高レベル .keras 形式は、堅牢で効率的な名前ベースの保存方法を提供しており、通常、低レベルやレガシー形式よりも簡単にデバッグできるため、Keras オブジェクトの保存に推奨されています。より高度な保存またはシリアル化ワークフロー、特にカスタムオブジェクトが関わるワークフローについては、「Keras モデルを保存して読み込む」ガイドをご覧ください。他のアプローチについては、「SavedModel 形式の使用」ガイドをご覧ください。

セットアップ

インストールとインポート

TensorFlow をインストールし、依存関係インポートします。

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2024-01-11 21:23:26.230036: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 21:23:26.230083: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 21:23:26.231707: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2.15.0

サンプルデータセットの取得

ここでは、重みの保存と読み込みをデモするために、MNIST データセットを使います。デモの実行を速くするため、最初の 1,000 件のサンプルだけを使います。

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

モデルの定義

簡単なシーケンシャルモデルを構築することから始めます。

# Define a simple sequential model
def create_model():
  model = tf.keras.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407050 (1.55 MB)
Trainable params: 407050 (1.55 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

トレーニング中にチェックポイントを保存する

再トレーニングせずにトレーニング済みモデルを使用したり、トレーニングプロセスを中断したところから再開することもできます。tf.keras.callbacks.ModelCheckpoint コールバックを使用すると、トレーニング中でもトレーニングの終了時でもモデルを継続的に保存できます。

チェックポイントコールバックの使い方

トレーニング中にのみ重みを保存する tf.keras.callbacks.ModelCheckpoint コールバックを作成します。

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705008212.295762  622332 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
22/32 [===================>..........] - ETA: 0s - loss: 1.3640 - sparse_categorical_accuracy: 0.6065 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 2s 13ms/step - loss: 1.1532 - sparse_categorical_accuracy: 0.6700 - val_loss: 0.7411 - val_sparse_categorical_accuracy: 0.7790
Epoch 2/10
22/32 [===================>..........] - ETA: 0s - loss: 0.4534 - sparse_categorical_accuracy: 0.8551
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 7ms/step - loss: 0.4309 - sparse_categorical_accuracy: 0.8700 - val_loss: 0.5356 - val_sparse_categorical_accuracy: 0.8340
Epoch 3/10
23/32 [====================>.........] - ETA: 0s - loss: 0.2729 - sparse_categorical_accuracy: 0.9307
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.2916 - sparse_categorical_accuracy: 0.9260 - val_loss: 0.5225 - val_sparse_categorical_accuracy: 0.8390
Epoch 4/10
22/32 [===================>..........] - ETA: 0s - loss: 0.2042 - sparse_categorical_accuracy: 0.9460
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 7ms/step - loss: 0.2157 - sparse_categorical_accuracy: 0.9470 - val_loss: 0.4547 - val_sparse_categorical_accuracy: 0.8530
Epoch 5/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1520 - sparse_categorical_accuracy: 0.9728
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.1518 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.4463 - val_sparse_categorical_accuracy: 0.8570
Epoch 6/10
24/32 [=====================>........] - ETA: 0s - loss: 0.1357 - sparse_categorical_accuracy: 0.9714
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.1311 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.4503 - val_sparse_categorical_accuracy: 0.8550
Epoch 7/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0959 - sparse_categorical_accuracy: 0.9837
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0916 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4363 - val_sparse_categorical_accuracy: 0.8600
Epoch 8/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0650 - sparse_categorical_accuracy: 0.9905
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0649 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4235 - val_sparse_categorical_accuracy: 0.8680
Epoch 9/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0536 - sparse_categorical_accuracy: 0.9973
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0523 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.4353 - val_sparse_categorical_accuracy: 0.8580
Epoch 10/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0402 - sparse_categorical_accuracy: 1.0000
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4228 - val_sparse_categorical_accuracy: 0.8650
<keras.src.callbacks.History at 0x7f0b7c0c5700>

この結果、エポックごとに更新される一連のTensorFlowチェックポイントファイルが作成されます。

os.listdir(checkpoint_dir)
['cp.ckpt.data-00000-of-00001', 'checkpoint', 'cp.ckpt.index']

2 つのモデルが同じアーキテクチャを共有している限り、それらの間で重みを共有できます。したがって、重みのみからモデルを復元する場合は、元のモデルと同じアーキテクチャでモデルを作成してから、その重みを設定します。

次に、トレーニングされていない新しいモデルを再構築し、テストセットで評価します。トレーニングされていないモデルは、偶然誤差(10% 以下の正解率)で実行されます。

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.3377 - sparse_categorical_accuracy: 0.1260 - 179ms/epoch - 6ms/step
Untrained model, accuracy: 12.60%

次に、チェックポイントから重みをロードし、再び評価します。

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4228 - sparse_categorical_accuracy: 0.8650 - 90ms/epoch - 3ms/step
Restored model, accuracy: 86.50%

チェックポイントコールバックのオプション

このコールバックには、チェックポイントに一意な名前をつけたり、チェックポイントの頻度を調整するためのオプションがあります。

新しいモデルをトレーニングし、5 エポックごとに一意な名前のチェックポイントを保存します。

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Calculate the number of batches per epoch
import math
n_batches = len(train_images) / batch_size
n_batches = math.ceil(n_batches)    # round up the number of batches to the nearest whole integer

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*n_batches)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.src.callbacks.History at 0x7f0b7416df70>

次に、できあがったチェックポイントをレビューし、最新のものを選択します。

os.listdir(checkpoint_dir)
['cp-0040.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0040.ckpt.index',
 'cp-0005.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0000.ckpt.index',
 'cp-0030.ckpt.index',
 'cp-0050.ckpt.index',
 'cp-0020.ckpt.data-00000-of-00001',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0010.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0045.ckpt.index',
 'cp-0050.ckpt.data-00000-of-00001',
 'checkpoint',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.index',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

注意: デフォルトの TensorFlow 形式では、最新の 5 つのチェックポイントのみが保存されます。

テストのため、モデルをリセットし最新のチェックポイントを読み込みます。

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4859 - sparse_categorical_accuracy: 0.8750 - 179ms/epoch - 6ms/step
Restored model, accuracy: 87.50%

これらのファイルは何?

上記のコードは、バイナリ形式でトレーニングされた重みのみを含む checkpoint 形式のファイルのコレクションに重みを格納します。チェックポイントには、次のものが含まれます。

  • 1 つ以上のモデルの重みのシャード。
  • どの重みがどのシャードに格納されているかを示すインデックスファイル。

一台のマシンでモデルをトレーニングしている場合は、接尾辞が .data-00000-of-00001 のシャードが 1 つあります。

手動で重みを保存する

tf.keras.Model.save_weights を使用して、手動で重みを保存します。デフォルトでは、tf.keras、特に Model.save_weights メソッドは、.ckpt 拡張子を持つ TensorFlow Checkpoint 形式を使用します。.h5 拡張して HDF5 形式として保存するには、モデルを保存して読み込むガイドをご覧ください。

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4859 - sparse_categorical_accuracy: 0.8750 - 181ms/epoch - 6ms/step
Restored model, accuracy: 87.50%

モデル全体を保存する

tf.keras.Model.save を呼び出して、単一の model.keras zip アーカイブに、モデルのアーキテクチャ、重み、およびトレーニング構成を保存します。

モデル全体の保存は、3 つの異なる形式(新しい .keras 形式と 2 つのレガシー形式: SavedModelHDF5)で行えます。path/to/model.keras として保存すると、自動的に最新の形式で保存されます。

注意: Keras オブジェクトについては、新しい高レベルの .keras 形式を使用することが推奨されています。よりリッチで名前ベースの保存と再読み込みを行えるため、デバッグしやすいのが特徴です。既存のコードについては、低レベルの SavedModel 形式とレガシーの H5 形式が引き続きサポートされています。

次の方法で、SavedModel 形式に切り替えることができます。

  • save_format='tf'save() に渡す
  • 拡張子なしでファイル名を渡す

次の方法で H5 形式に切り替えることができます。

  • save_format='h5'save() に渡す
  • .h5 で終わるファイル名を渡す

Saving a fully-functional model is very useful—you can load them in TensorFlow.js (Saved Model, HDF5) and then train and run them in web browsers, or convert them to run on mobile devices using TensorFlow Lite (Saved Model, HDF5)

*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the Saving custom objects section below.

新しい高レベルの .keras 形式

新しい Keras v3 保存形式は .keras 拡張を使用し、名前ベースの保存を実装するよりシンプルで効率的な形式であるため、Python の観点から、読み込んだものが実際に保存したものであることが保証されます。これにより、デバッグをはるかに容易に行えるため、Keras に推奨される形式となっています。

以下のセクションは、.keras 形式でモデルを保存し、復元する方法を説明しています。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a `.keras` zip archive.
model.save('my_model.keras')
Epoch 1/5
32/32 [==============================] - 1s 3ms/step - loss: 1.1083 - sparse_categorical_accuracy: 0.6940
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4247 - sparse_categorical_accuracy: 0.8680
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9220
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2004 - sparse_categorical_accuracy: 0.9550
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1497 - sparse_categorical_accuracy: 0.9710

.keras zip アーカイブからフレッシュな Keras モデルを再読み込みします。

new_model = tf.keras.models.load_model('my_model.keras')

# Show the model architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407050 (1.55 MB)
Trainable params: 407050 (1.55 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

読み込まれたモデルで評価と予測を実行してみましょう。

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4726 - sparse_categorical_accuracy: 0.8500 - 180ms/epoch - 6ms/step
32/32 [==============================] - 0s 1ms/step
(1000, 10)

SavedModel 形式

SavedModel 形式は、モデルをシリアル化するもう 1 つの方法です。この形式で保存されたモデルは、tf.keras.models.load_model を使用して復元でき、TensorFlow Serving と互換性があります。SavedModel をサービングおよび検査する方法についての詳細は、SavedModel ガイドを参照してください。以下のセクションでは、モデルを保存および復元する手順を示します。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5
32/32 [==============================] - 1s 3ms/step - loss: 1.1850 - sparse_categorical_accuracy: 0.6680
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4369 - sparse_categorical_accuracy: 0.8670
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2892 - sparse_categorical_accuracy: 0.9270
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2175 - sparse_categorical_accuracy: 0.9430
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1508 - sparse_categorical_accuracy: 0.9690
INFO:tensorflow:Assets written to: saved_model/my_model/assets
INFO:tensorflow:Assets written to: saved_model/my_model/assets

SavedModel 形式は、protobuf バイナリと TensorFlow チェックポイントを含むディレクトリです。保存されたモデルディレクトリを調べます。

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  fingerprint.pb  keras_metadata.pb  saved_model.pb  variables

保存したモデルから新しい Keras モデルを再度読み込みます。

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407050 (1.55 MB)
Trainable params: 407050 (1.55 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

復元されたモデルは、元のモデルと同じ引数でコンパイルされます。読み込まれたモデルで評価と予測を実行してみてください。

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4212 - sparse_categorical_accuracy: 0.8590 - 189ms/epoch - 6ms/step
32/32 [==============================] - 0s 1ms/step
(1000, 10)

HDF5 形式

Keras には、HDF5 標準を使用した基本的なレガシーの高レベル保存形式が備わっています。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 1s 2ms/step - loss: 1.1673 - sparse_categorical_accuracy: 0.6750
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4260 - sparse_categorical_accuracy: 0.8790
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2811 - sparse_categorical_accuracy: 0.9260
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2110 - sparse_categorical_accuracy: 0.9510
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1555 - sparse_categorical_accuracy: 0.9670
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(

保存したファイルを使ってモデルを再作成します。

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_14 (Dense)            (None, 512)               401920    
                                                                 
 dropout_7 (Dropout)         (None, 512)               0         
                                                                 
 dense_15 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407050 (1.55 MB)
Trainable params: 407050 (1.55 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

正解率を検査します。

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4251 - sparse_categorical_accuracy: 0.8540 - 186ms/epoch - 6ms/step
Restored model, accuracy: 85.40%

Keras は、アーキテクチャを検査することでモデルを保存します。この手法ではすべてが保存されます。

  • 重みの値
  • モデルのアーキテクチャ
  • モデルのトレーニング構成(.compile() メソッドに渡すもの)
  • あれば、オプティマイザとその状態(中断した所からトレーニングを再開するため)

Keras は v1.x (tf.compat.v1.train にあります) のオプティマイザを保存できません。これらはチェックポイントと互換性がないためです。v1.x のオプティマイザでは、オプティマイザの状態を読み込ませてモデルを再度コンパイルする必要があります。

カスタムオブジェクトの保存

SavedModel 形式を使用している場合は、このセクションを省略できます。高レベルの .keras/HDF5 形式と低レベルの SavedModel 形式の違いは、.keras/HDF5 形式はオブジェクト構成を使用してモデルアーキテクチャを保存するのに対し、SavedModel は実行グラフを保存するという点です。したがって、SavedModel は、元のコードがなくても、サブクラス化されたモデルやカスタムレイヤーなどのカスタムオブジェクトを保存することができます。ただしこれにより、低レベルの SavedModels のデバッグはより困難であるため、名前ベースで、Keras ネイティブであるという特性を備えた高レベルの .keras 形式を代わりに使用することをお勧めします。

カスタムオブジェクトを .keras と HDF5 に保存するには、以下を実行します。

  1. オブジェクトで get_config メソッドを定義し、オプションで from_config クラスメソッドを定義します。
    • get_config(self) は、オブジェクトの再作成に必要なパラメータの JSON シリアル化可能なディクショナリを返します。
    • from_config(cls, config){/code0 }は、<code data-md-type="codespan">get_config から返された構成を使用して新しいオブジェクトを作成します。デフォルトでは、この関数は構成を初期化 kwargs (return cls(**config)) として使用します。
  2. 以下のいずれかの方法で、カスタムオブジェクトをモデルに渡します。
    • @tf.keras.utils.register_keras_serializable デコレータを使ってカスタムオブジェクトを登録します。(推奨)
    • モデルを読み込むときに、オブジェクトを直接 custom_objects 引数に渡します。引数は、文字列クラス名を Python クラスにマッピングするディクショナリである必要があります。(例: tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
    • tf.keras.utils.custom_object_scope を使用します。custom_objects ディクショナリ引数にオブジェクトを含め、範囲内に tf.keras.models.load_model(path) 呼び出しを配置します。

カスタムオブジェクトと get_config の例については、レイヤーとモデルを最初から作成するチュートリアルをご覧ください。

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.