บันทึกและโหลดโมเดล

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

สามารถบันทึกความคืบหน้าของแบบจำลองได้ในระหว่างและหลังการฝึก ซึ่งหมายความว่าโมเดลสามารถกลับมาทำงานต่อจากที่ค้างไว้และหลีกเลี่ยงการฝึกอบรมที่ยาวนาน การบันทึกยังหมายความว่าคุณสามารถแบ่งปันแบบจำลองของคุณและคนอื่นๆ สามารถสร้างงานของคุณขึ้นมาใหม่ได้ เมื่อเผยแพร่แบบจำลองและเทคนิคการวิจัย ผู้ปฏิบัติงานการเรียนรู้ของเครื่องส่วนใหญ่จะแบ่งปัน:

  • รหัสเพื่อสร้างแบบจำลองและ
  • ตุ้มน้ำหนักที่ฝึกหรือพารามิเตอร์สำหรับรุ่น

การแบ่งปันข้อมูลนี้จะช่วยให้ผู้อื่นเข้าใจวิธีการทำงานของแบบจำลองและลองใช้ข้อมูลใหม่ด้วยตนเอง

ตัวเลือก

มีหลายวิธีในการบันทึกโมเดล TensorFlow ขึ้นอยู่กับ API ที่คุณใช้ คู่มือนี้ใช้ tf.keras ซึ่งเป็น API ระดับสูงเพื่อสร้างและฝึกโมเดลใน TensorFlow สำหรับวิธีการอื่นๆ โปรดดูคู่มือ TensorFlow Save and Restore หรือ Saving inความกระตือรือร้น

ติดตั้ง

ติดตั้งและนำเข้า

ติดตั้งและนำเข้า TensorFlow และการพึ่งพา:

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.8.0-rc1

รับตัวอย่างชุดข้อมูล

ในการสาธิตวิธีบันทึกและโหลดตุ้มน้ำหนัก คุณจะต้องใช้ ชุดข้อมูล MNIST หากต้องการเร่งความเร็วการวิ่งเหล่านี้ ให้ใช้ 1,000 ตัวอย่างแรก:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

กำหนดรูปแบบ

เริ่มต้นด้วยการสร้างแบบจำลองลำดับอย่างง่าย:

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

บันทึกจุดตรวจระหว่างการฝึก

คุณสามารถใช้แบบจำลองที่ได้รับการฝึกอบรมโดยไม่ต้องฝึกใหม่ หรือรับการฝึกอบรมที่ค้างไว้ในกรณีที่กระบวนการฝึกอบรมถูกขัดจังหวะ การเรียกกลับ tf.keras.callbacks.ModelCheckpoint ช่วยให้คุณบันทึกโมเดลได้อย่างต่อเนื่องทั้ง ในระหว่าง และเมื่อ สิ้นสุด การฝึก

การใช้งานโทรกลับจุดตรวจ

สร้างการเรียกกลับ tf.keras.callbacks.ModelCheckpoint ที่ช่วยประหยัดน้ำหนักระหว่างการฝึกเท่านั้น:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750
Epoch 2/10
24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150
Epoch 3/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430
Epoch 4/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610
Epoch 5/10
24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580
Epoch 6/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580
Epoch 7/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590
Epoch 8/10
21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600
Epoch 9/10
22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650
Epoch 10/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770
<keras.callbacks.History at 0x7eff8d865390>

สิ่งนี้จะสร้างคอลเล็กชันไฟล์จุดตรวจสอบ TensorFlow หนึ่งชุดที่อัปเดตเมื่อสิ้นสุดแต่ละยุค:

os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']

ตราบใดที่ทั้งสองรุ่นมีสถาปัตยกรรมเดียวกัน คุณก็สามารถแบ่งน้ำหนักระหว่างกันได้ ดังนั้น เมื่อกู้คืนแบบจำลองจากเฉพาะน้ำหนัก ให้สร้างแบบจำลองที่มีสถาปัตยกรรมเดียวกันกับรุ่นดั้งเดิม แล้วตั้งค่าน้ำหนัก

ตอนนี้สร้างแบบจำลองใหม่ที่ยังไม่ผ่านการฝึกอบรมและประเมินในชุดทดสอบ โมเดลที่ไม่ได้รับการฝึกฝนจะแสดงที่ระดับโอกาส (ความแม่นยำประมาณ 10%):

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step
Untrained model, accuracy:  9.80%

จากนั้นโหลดน้ำหนักจากจุดตรวจและประเมินใหม่:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step
Restored model, accuracy: 87.70%

ตัวเลือกการโทรกลับจุดตรวจ

การโทรกลับมีตัวเลือกมากมายในการระบุชื่อที่ไม่ซ้ำสำหรับจุดตรวจและปรับความถี่ของจุดตรวจ

ฝึกโมเดลใหม่และบันทึกจุดตรวจที่มีชื่อไม่ซ้ำกันทุกๆ ห้ายุค:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.callbacks.History at 0x7eff807703d0>

ตอนนี้ให้ดูที่จุดตรวจที่เกิดขึ้นและเลือกจุดล่าสุด:

os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.index',
 'checkpoint',
 'cp-0010.ckpt.index',
 'cp-0035.ckpt.data-00000-of-00001',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.data-00000-of-00001',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0040.ckpt.index',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0045.ckpt.index',
 'cp-0020.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0030.ckpt.index',
 'cp-0000.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.index',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

หากต้องการทดสอบ ให้รีเซ็ตโมเดลและโหลดจุดตรวจสอบล่าสุด:

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step
Restored model, accuracy: 87.70%
ตัวยึดตำแหน่ง22

ไฟล์เหล่านี้คืออะไร?

รหัสด้านบนเก็บน้ำหนักไว้ในคอลเลกชันของไฟล์ที่จัดรูปแบบ จุดตรวจสอบ ที่มีเฉพาะน้ำหนักที่ฝึกแล้วในรูปแบบไบนารี จุดตรวจประกอบด้วย:

  • ชาร์ดอย่างน้อยหนึ่งรายการที่มีตุ้มน้ำหนักของโมเดลของคุณ
  • ไฟล์ดัชนีที่ระบุว่าน้ำหนักใดถูกเก็บไว้ในชาร์ดใด

หากคุณกำลังฝึกโมเดลในเครื่องเดียว คุณจะมีชาร์ดหนึ่งส่วนที่มีส่วนต่อท้าย: .data-00000-of-00001

ลดน้ำหนักด้วยตนเอง

การบันทึกน้ำหนักด้วยตนเองด้วยเมธอด Model.save_weights โดยค่าเริ่มต้น tf.keras — และโดยเฉพาะ save_weights ใช้รูปแบบ จุดตรวจสอบ TensorFlow ที่มีนามสกุล .ckpt (การบันทึกใน HDF5 ด้วยนามสกุล .h5 จะครอบคลุมอยู่ในคู่มือ บันทึกและกำหนดรูปแบบอนุกรม ):

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step
Restored model, accuracy: 87.70%

บันทึกโมเดลทั้งหมด

เรียก model.save เพื่อบันทึกสถาปัตยกรรม น้ำหนัก และการกำหนดค่าการฝึกของโมเดลในไฟล์/โฟลเดอร์เดียว สิ่งนี้ทำให้คุณสามารถเอ็กซ์พอร์ตโมเดลเพื่อให้ใช้งานได้โดยไม่ต้องเข้าถึงโค้ด Python ดั้งเดิม* เนื่องจากสถานะเครื่องมือเพิ่มประสิทธิภาพถูกกู้คืน คุณจึงสามารถกลับมาฝึกต่อได้จากจุดที่ค้างไว้

โมเดลทั้งหมดสามารถบันทึกได้ในรูปแบบไฟล์ที่แตกต่างกันสองรูปแบบ ( SavedModel และ HDF5 ) รูปแบบ SavedModel เป็นรูปแบบไฟล์เริ่มต้นใน TF2.x อย่างไรก็ตาม สามารถบันทึกโมเดลต่างๆ ในรูปแบบ HDF5 ได้ รายละเอียดเพิ่มเติมเกี่ยวกับการบันทึกโมเดลทั้งหมดในรูปแบบไฟล์ทั้งสองมีอธิบายไว้ด้านล่าง

การบันทึกโมเดลที่ใช้งานได้อย่างสมบูรณ์มีประโยชน์มาก คุณสามารถโหลดโมเดลเหล่านี้ใน TensorFlow.js ( Saved Model , HDF5 ) จากนั้นฝึกและเรียกใช้ในเว็บเบราว์เซอร์ หรือแปลงให้ทำงานบนอุปกรณ์มือถือโดยใช้ TensorFlow Lite ( Saved Model , HDF5 )

*อ็อบเจ็กต์ที่กำหนดเอง (เช่น โมเดลหรือเลเยอร์ย่อย) ต้องให้ความสนใจเป็นพิเศษเมื่อทำการบันทึกและโหลด ดูส่วนการ บันทึกวัตถุที่กำหนดเอง ด้านล่าง

รูปแบบโมเดลที่บันทึกไว้

รูปแบบ SavedModel เป็นอีกวิธีหนึ่งในการทำให้โมเดลเป็นอนุกรม โมเดลที่บันทึกในรูปแบบนี้สามารถกู้คืนได้โดยใช้ tf.keras.models.load_model และเข้ากันได้กับ TensorFlow Serving คู่มือ SavedModel จะลงรายละเอียดเกี่ยวกับวิธีการให้บริการ/ตรวจสอบ SavedModel ส่วนด้านล่างแสดงขั้นตอนในการบันทึกและกู้คืนโมเดล

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630
2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
INFO:tensorflow:Assets written to: saved_model/my_model/assets

รูปแบบ SavedModel เป็นไดเร็กทอรีที่มีไบนารี protobuf และจุดตรวจสอบ TensorFlow ตรวจสอบไดเร็กทอรีโมเดลที่บันทึกไว้:

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  keras_metadata.pb  saved_model.pb  variables

โหลดโมเดล Keras ใหม่จากโมเดลที่บันทึกไว้:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

โมเดลที่กู้คืนถูกคอมไพล์ด้วยอาร์กิวเมนต์เดียวกันกับโมเดลดั้งเดิม ลองรันประเมินและคาดการณ์ด้วยโมเดลที่โหลด:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step
Restored model, accuracy: 84.30%
(1000, 10)
ตัวยึดตำแหน่ง32

รูปแบบ HDF5

Keras จัดเตรียมรูปแบบการบันทึกพื้นฐานโดยใช้มาตรฐาน HDF5

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690

ตอนนี้ สร้างโมเดลใหม่จากไฟล์นั้น:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
ตัวยึดตำแหน่ง36

ตรวจสอบความถูกต้อง:

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step
Restored model, accuracy: 86.20%

Keras บันทึกโมเดลด้วยการตรวจสอบสถาปัตยกรรมของพวกเขา เทคนิคนี้บันทึกทุกอย่าง:

  • ค่าน้ำหนัก
  • สถาปัตยกรรมของโมเดล
  • การกำหนดค่าการฝึกของโมเดล (สิ่งที่คุณส่งผ่านไปยัง .compile() )
  • เครื่องมือเพิ่มประสิทธิภาพและสถานะของอุปกรณ์ (หากมี) (ซึ่งจะทำให้คุณสามารถเริ่มการฝึกใหม่จากจุดที่ค้างไว้ได้)

Keras ไม่สามารถบันทึกเครื่องมือเพิ่มประสิทธิภาพ v1.x (จาก tf.compat.v1.train ) เนื่องจากไม่สามารถทำงานร่วมกับจุดตรวจได้ สำหรับตัวเพิ่มประสิทธิภาพ v1.x คุณต้องคอมไพล์โมเดลใหม่หลังจากโหลด—สูญเสียสถานะของตัวเพิ่มประสิทธิภาพ

กำลังบันทึกวัตถุที่กำหนดเอง

หากคุณกำลังใช้รูปแบบ SavedModel คุณสามารถข้ามส่วนนี้ได้ ความแตกต่างที่สำคัญระหว่าง HDF5 และ SavedModel คือ HDF5 ใช้การกำหนดค่าอ็อบเจ็กต์เพื่อบันทึกสถาปัตยกรรมโมเดล ในขณะที่ SavedModel จะบันทึกกราฟการดำเนินการ ดังนั้น SavedModels จึงสามารถบันทึกอ็อบเจ็กต์ที่กำหนดเองได้ เช่น โมเดลย่อยและเลเยอร์ที่กำหนดเองโดยไม่ต้องใช้โค้ดต้นฉบับ

ในการบันทึกออบเจ็กต์ที่กำหนดเองไปยัง HDF5 คุณต้องทำดังต่อไปนี้:

  1. กำหนดเมธอด get_config ในอ็อบเจ็กต์ของคุณ และเป็นทางเลือก from_config classmethod
    • get_config(self) ส่งคืนพจนานุกรมพารามิเตอร์ JSON-serializable ที่จำเป็นในการสร้างวัตถุขึ้นใหม่
    • from_config(cls, config) ใช้การกำหนดค่าที่ส่งคืนจาก get_config เพื่อสร้างวัตถุใหม่ โดยค่าเริ่มต้น ฟังก์ชันนี้จะใช้การกำหนดค่าเป็นค่าเริ่มต้น kwargs ( return cls(**config) )
  2. ส่งอ็อบเจ็กต์ไปยังอาร์กิวเมนต์ custom_objects เมื่อโหลดโมเดล อาร์กิวเมนต์ต้องเป็นพจนานุกรมที่จับคู่ชื่อคลาสสตริงกับคลาส Python เช่น tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

ดูบทแนะนำ การเขียนเลเยอร์และโมเดลตั้งแต่เริ่มต้น สำหรับตัวอย่างของอ็อบเจกต์ที่กำหนดเองและ get_config

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.