모델 저장과 복원

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기

모델 진행 상황은 훈련 중 및 훈련 후에 저장할 수 있습니다. 즉, 모델이 중단된 위치에서 다시 시작하고 긴 훈련 시간을 피할 수 있습니다. 저장은 또한 모델을 공유할 수 있고 다른 사람들이 작업을 다시 만들 수 있음을 의미합니다. 연구 모델 및 기술을 게시할 때 대부분의 머신러닝 실무자는 다음을 공유합니다.

  • 모델을 만드는 코드
  • 모델의 훈련된 가중치 또는 파라미터

이런 데이터를 공유하면 다른 사람들이 모델의 작동 방식을 이해하고 새로운 데이터로 모델을 실험하는데 도움이 됩니다.

주의: TensorFlow 모델은 코드이며 신뢰할 수 없는 코드에 주의하는 것이 중요합니다. 자세한 내용은 TensorFlow 안전하게 사용하기를 참조하세요.

저장 방식

사용 중인 API에 따라 TensorFlow 모델을 저장하는 다양한 방법이 있습니다. 이 가이드에서는 TensorFlow에서 모델을 빌드하고 훈련하는 고급 API인 tf.keras를 사용합니다. 다른 접근 방식에 대해서는 SavedModel 형식 사용 가이드Keras 모델 저장 및 로드 가이드를 참조하세요.

설정

설치와 임포트

필요한 라이브러리를 설치하고 텐서플로를 임포트(import)합니다:

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2022-12-14 20:23:03.644578: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:23:03.644690: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:23:03.644701: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2.11.0

예제 데이터셋 받기

MNIST 데이터셋으로 모델을 훈련하여 가중치를 저장하는 예제를 만들어 보겠습니다. 모델 실행 속도를 빠르게 하기 위해 샘플에서 처음 1,000개만 사용겠습니다:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step

모델 정의

먼저 간단한 모델을 하나 만들어 보죠.

# Define a simple sequential model
def create_model():
  model = tf.keras.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

훈련하는 동안 체크포인트 저장하기

훈련된 모델을 다시 훈련할 필요 없이 사용하거나 훈련 과정이 중단된 경우 중단한 부분에서 훈련을 다시 시작할 수 있습니다. tf.keras.callbacks.ModelCheckpoint 콜백을 사용하면 훈련 도중 또는 훈련 종료 시 모델을 지속적으로 저장할 수 있습니다.

체크포인트 콜백 사용하기

훈련하는 동안 가중치를 저장하기 위해 ModelCheckpoint 콜백을 만들어 보죠:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
22/32 [===================>..........] - ETA: 0s - loss: 1.4028 - sparse_categorical_accuracy: 0.5909 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 2s 12ms/step - loss: 1.1664 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7275 - val_sparse_categorical_accuracy: 0.7790
Epoch 2/10
23/32 [====================>.........] - ETA: 0s - loss: 0.4540 - sparse_categorical_accuracy: 0.8723
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.4351 - sparse_categorical_accuracy: 0.8760 - val_loss: 0.5317 - val_sparse_categorical_accuracy: 0.8380
Epoch 3/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2708 - sparse_categorical_accuracy: 0.9284
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.2764 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4818 - val_sparse_categorical_accuracy: 0.8520
Epoch 4/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2057 - sparse_categorical_accuracy: 0.9505
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.2113 - sparse_categorical_accuracy: 0.9500 - val_loss: 0.4803 - val_sparse_categorical_accuracy: 0.8430
Epoch 5/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1490 - sparse_categorical_accuracy: 0.9674
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.1482 - sparse_categorical_accuracy: 0.9680 - val_loss: 0.4408 - val_sparse_categorical_accuracy: 0.8580
Epoch 6/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1195 - sparse_categorical_accuracy: 0.9796
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.1105 - sparse_categorical_accuracy: 0.9810 - val_loss: 0.4570 - val_sparse_categorical_accuracy: 0.8560
Epoch 7/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0786 - sparse_categorical_accuracy: 0.9905
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0857 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4329 - val_sparse_categorical_accuracy: 0.8720
Epoch 8/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0713 - sparse_categorical_accuracy: 0.9891
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0707 - sparse_categorical_accuracy: 0.9880 - val_loss: 0.4151 - val_sparse_categorical_accuracy: 0.8670
Epoch 9/10
23/32 [====================>.........] - ETA: 0s - loss: 0.0484 - sparse_categorical_accuracy: 0.9946
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0481 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4199 - val_sparse_categorical_accuracy: 0.8740
Epoch 10/10
22/32 [===================>..........] - ETA: 0s - loss: 0.0325 - sparse_categorical_accuracy: 1.0000
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0331 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4082 - val_sparse_categorical_accuracy: 0.8750
<keras.callbacks.History at 0x7fa5cc0b0a90>

이 코드는 텐서플로 체크포인트 파일을 만들고 에포크가 종료될 때마다 업데이트합니다:

os.listdir(checkpoint_dir)
['cp.ckpt.index', 'checkpoint', 'cp.ckpt.data-00000-of-00001']

두 모델이 동일한 아키텍처를 공유하기만 한다면 두 모델 간에 가중치를 공유할 수 있습니다. 따라서 가중치 전용에서 모델을 복원할 때 원래 모델과 동일한 아키텍처로 모델을 만든 다음 가중치를 설정합니다.

이제 훈련되지 않은 새로운 모델을 다시 빌드하고 테스트 세트에서 평가합니다. 훈련되지 않은 모델은 확률 수준(~10% 정확도)에서 수행됩니다.

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4037 - sparse_categorical_accuracy: 0.0760 - 164ms/epoch - 5ms/step
Untrained model, accuracy:  7.60%

체크포인트에서 가중치를 로드하고 다시 평가해 보죠:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4082 - sparse_categorical_accuracy: 0.8750 - 75ms/epoch - 2ms/step
Restored model, accuracy: 87.50%

체크포인트 콜백 매개변수

이 콜백 함수는 몇 가지 매개변수를 제공합니다. 체크포인트 이름을 고유하게 만들거나 체크포인트 주기를 조정할 수 있습니다.

새로운 모델을 훈련하고 다섯 번의 에포크마다 고유한 이름으로 체크포인트를 저장해 보겠습니다:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.callbacks.History at 0x7fa5c0369fd0>

이제 결과로 나온 체크포인트를 검토하고 최신 체크포인트를 선택합니다.

os.listdir(checkpoint_dir)
['cp-0040.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.index',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0030.ckpt.index',
 'cp-0050.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0000.ckpt.index',
 'cp-0050.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0015.ckpt.index',
 'cp-0005.ckpt.data-00000-of-00001',
 'checkpoint',
 'cp-0010.ckpt.index',
 'cp-0045.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.data-00000-of-00001',
 'cp-0040.ckpt.index']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

참고: 기본 TensorFlow 형식은 가장 최근의 체크포인트 5개만 저장합니다.

테스트를 위해 모델을 재설정하고 최신 체크포인트를 로드합니다.

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4952 - sparse_categorical_accuracy: 0.8790 - 169ms/epoch - 5ms/step
Restored model, accuracy: 87.90%

이 파일들은 무엇인가요?

위의 코드는 이진 형식의 훈련된 가중치만 포함하는 체크포인트 형식의 파일 모음에 가중치를 저장합니다. 체크포인트에는 다음이 포함됩니다.

  • 모델의 가중치를 포함하는 하나 이상의 샤드
  • 어떤 가중치가 어떤 샤드에 저장되어 있는지 나타내는 인덱스 파일

단일 머신에서 모델을 훈련하는 경우 접미사가 .data-00000-of-00001인 하나의 샤드를 갖게 됩니다.

수동으로 가중치 저장하기

가중치를 수동으로 저장하려면 tf.keras.Model.save_weights를 사용합니다. 기본적으로 tf.keras, 그리고 특히 Model.save_weights 메서드는 .ckpt 확장자가 있는 TensorFlow 체크포인트 형식을 사용합니다. .h5 확장자를 사용하여 HDF5 형식으로 저장하려면 모델 저장 및 로드 가이드를 참조하세요.

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4952 - sparse_categorical_accuracy: 0.8790 - 166ms/epoch - 5ms/step
Restored model, accuracy: 87.90%

전체 모델 저장하기

tf.keras.Model.save를 호출하여 단일 file/folder에 모델의 아키텍처, 가중치 및 훈련 구성을 저장합니다. 이렇게 하면 원본 Python 코드에 액세스하지 않고도 사용할 수 있도록 모델을 내보낼 수 있습니다.* 옵티마이저 상태가 복구되므로 중단했던 지점부터 훈련을 재개할 수 있습니다.

전체 모델은 두 가지 다른 파일 형식(SavedModelHDF5)으로 저장할 수 있습니다. TensorFlow SavedModel 형식은 TF2.x의 기본 파일 형식입니다. 그러나 모델을 HDF5 형식으로 저장할 수 있습니다. 전체 모델을 두 가지 파일 형식으로 저장하는 방법에 대한 자세한 내용은 아래에 설명되어 있습니다.

완전한 기능의 모델을 저장하는 것은 매우 유용합니다. 즉, TensorFlow.js(저장된 모델, HDF5)에서 로드한 다음 웹 브라우저에서 훈련 및 실행하거나 TensorFlow Lite(저장된 모델, HDF5)를 사용하여 모바일 장치에서 실행되도록 변환할 수 있습니다.

*사용자 정의 객체(예: 하위 클래싱된 모델 또는 레이어)는 저장 및 로드할 때 특별한 주의가 필요합니다. 아래의 사용자 정의 객체 저장하기 섹션을 참조하세요.

SavedModel 포맷

SavedModel 형식은 모델을 직렬화하는 또 다른 방법입니다. 이 형식으로 저장된 모델은 tf.keras.models.load_model을 사용하여 복원할 수 있으며 TensorFlow Serving과 호환됩니다. SavedModel 가이드에 SavedModel을 제공/검사하는 방법이 자세히 설명되어 있습니다. 아래 섹션은 모델을 저장하고 복원하는 단계를 보여줍니다.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8
Epoch 1/5
32/32 [==============================] - 1s 2ms/step - loss: 1.1550 - sparse_categorical_accuracy: 0.6920
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4072 - sparse_categorical_accuracy: 0.8950
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2911 - sparse_categorical_accuracy: 0.9220
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1990 - sparse_categorical_accuracy: 0.9580
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1525 - sparse_categorical_accuracy: 0.9630
INFO:tensorflow:Assets written to: saved_model/my_model/assets

SavedModel 형식은 protobuf 바이너리와 TensorFlow 체크포인트를 포함하는 디렉토리입니다. 저장된 모델 디렉토리를 검사합니다.

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  fingerprint.pb  keras_metadata.pb  saved_model.pb  variables

저장된 모델로부터 새로운 케라스 모델을 로드합니다:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

복원된 모델은 원본 모델과 동일한 매개변수로 컴파일되어 있습니다. 이 모델을 평가하고 예측에 사용해 보죠:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4686 - sparse_categorical_accuracy: 0.8520 - 174ms/epoch - 5ms/step
32/32 [==============================] - 0s 1ms/step
(1000, 10)

HDF5 파일로 저장하기

케라스는 HDF5 표준을 따르는 기본 저장 포맷을 제공합니다.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 1s 3ms/step - loss: 1.1802 - sparse_categorical_accuracy: 0.6620
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4295 - sparse_categorical_accuracy: 0.8750
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2942 - sparse_categorical_accuracy: 0.9270
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2196 - sparse_categorical_accuracy: 0.9490
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1616 - sparse_categorical_accuracy: 0.9650

이제 이 파일로부터 모델을 다시 만들어 보죠:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

정확도를 확인해 보겠습니다:

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4311 - sparse_categorical_accuracy: 0.8640 - 171ms/epoch - 5ms/step
Restored model, accuracy: 86.40%

Keras는 아키텍처를 검사하여 모델을 저장합니다. 이 기술은 모든 내용을 저장합니다.

  • 가중치 값
  • 모델 구조
  • 모델의 훈련 구성(.compile() 메서드에 전달하는 내용)
  • 존재하는 옵티마이저와 그 상태(훈련을 중단한 곳에서 다시 시작할 수 있게 해줌)

체크포인트가 호환되지 않기 때문에 케라스는 v1.x 옵티마이저(tf.compat.v1.train)를 저장할 수 없습니다. v1.x 옵티마이저를 사용하려면 로드한 후에 모델을 다시 컴파일해야 합니다. 따라서 옵티마이저의 상태를 잃게 됩니다.

사용자 정의 객체

SavedModel 형식을 사용하는 경우, 이 섹션을 건너뛸 수 있습니다. HDF5와 SavedModel의 주요 차이점은 HDF5는 객체 구성을 사용하여 모델 아키텍처를 저장하는 반면, SavedModel은 실행 그래프를 저장한다는 것입니다. 따라서 SavedModel은 원본 코드 없이도 서브클래싱된 모델 및 사용자 지정 레이어와 같은 사용자 지정 객체를 저장할 수 있습니다.

사용자 정의 객체를 HDF5로 저장하려면 다음 과정을 따르세요:

  1. 이 객체에 get_config 메서드를 정의하고 선택적으로 from_config 클래스 메서드를 정의합니다.
    • get_config(self)는 객체를 다시 생성하기 위해 필요한 JSON 직렬화된 매개변수 딕셔너리를 반환합니다.
    • from_config(cls, config)get_config에서 반환된 설정을 사용해 새로운 객체를 만듭니다. 기본적으로 이 함수는 이 설정을 초기화 메서드의 매개변수로 사용합니다(return cls(**config)).
  2. 모델을 로드할 때 이 객체를 custom_objects 매개변수로 전달합니다. 문자열 클래스 이름과 파이썬 클래스를 매핑한 딕서너리를 매개변수로 제공해야 합니다. 예를 들면 tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

사용자 정의 객체 및 get_config의 예를 보려면 처음부터 레이어 및 모델 작성하기 튜토리얼을 참조하세요.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.