
Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 10s 4ms/step - loss: 0.2033 - sparse_categorical_accuracy: 0.9408
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9812
<tensorflow.python.keras.callbacks.History at 0x7f07cc109790>
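The 938/938 counter in the log above is the number of batches per epoch. As a quick sanity check, here is a minimal sketch of that arithmetic in plain Python. It assumes the MNIST training split has 60,000 examples and, as the log shows, a single replica (one GPU). The helper names global_batch_size and steps_per_epoch are illustrative only, not TensorFlow APIs.

```python
import math


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    # Mirrors BATCH_SIZE = BATCH_SIZE_PER_REPLICA * num_replicas_in_sync in get_data().
    return per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # .batch() keeps the final partial batch (drop_remainder=False), so round up.
    return math.ceil(num_examples / batch_size)


MNIST_TRAIN_EXAMPLES = 60_000  # size of the MNIST "train" split

one_gpu = global_batch_size(64, 1)
print(one_gpu, steps_per_epoch(MNIST_TRAIN_EXAMPLES, one_gpu))  # 64 938

two_gpus = global_batch_size(64, 2)
print(two_gpus, steps_per_epoch(MNIST_TRAIN_EXAMPLES, two_gpus))  # 128 469
```

With two replicas, each step processes twice as many examples, so an epoch takes roughly half as many steps.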

Save and load the model

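Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level APIs (tf.saved_model.save and tf.saved_model.load).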

Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9896
<tensorflow.python.keras.callbacks.History at 0x7f08e8347e90>

After restoring the model, you can continue training it without calling compile() again, because the model was already compiled before it was saved. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the SavedModel format.

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9856
Epoch 2/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.
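The following is a minimal sketch of the same point that does not depend on a particular Keras version. It uses a plain tf.Module instead of a Keras model. The Scaler class and its variable w are hypothetical names chosen for the example. The variable is created under MirroredStrategy, which falls back to the CPU when no GPU is visible, and the SavedModel is then loaded under OneDeviceStrategy.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module: multiplies its input by a single variable."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


# Create the variable under one strategy...
save_strategy = tf.distribute.MirroredStrategy()
with save_strategy.scope():
    module = Scaler()

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# ...and restore it under a different one.
load_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with load_strategy.scope():
    restored = tf.saved_model.load(export_dir)

result = restored(tf.constant([1.0, 2.0])).numpy()
print(result.tolist())  # [2.0, 4.0]
```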

tf.saved_model APIs

Now let's look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, because it is a lower-level API (and therefore has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:
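You can also inspect the signature map directly. The following is a minimal sketch that assumes a hypothetical tf.Module named Doubler whose serving function returns a dictionary. For the Keras model in this tutorial, the output key is instead derived from the model's output layer (dense_3 in the printed result that follows). Here the signature function is called with a keyword argument named after the input TensorSpec.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module with a single serving function."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def serve(self, x):
        # Returning a dict makes "doubled" the output key of the signature.
        return {"doubled": 2.0 * x}


module = Doubler()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.serve})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))  # ['serving_default']

infer = loaded.signatures["serving_default"]
outputs = infer(x=tf.constant([1.0, 2.0, 3.0]))
print(outputs["doubled"].numpy().tolist())  # [2.0, 4.0, 6.0]
```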

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.87175977e-02, -1.61857940e-02,  4.99733090e-02,
         1.36069462e-01,  1.08132266e-01,  7.42500275e-02,
         7.73677081e-02, -4.71051671e-02,  6.41952083e-02,
         1.25855744e-01],
       [-6.54747933e-02,  2.14797929e-01, -8.99797529e-02,
        -7.83173367e-04,  8.92944187e-02,  1.73829526e-01,
         1.60048857e-01,  6.56097680e-02,  9.97588038e-03,
         1.09826915e-01],
       [-1.54433712e-01,  8.42373446e-03,  5.25225215e-02,
         4.81541678e-02,  2.23428048e-02,  7.69063532e-02,
        -3.99671011e-02, -4.56253998e-02,  5.28362878e-02,
         1.42020926e-01],
       [-9.43576694e-02, -9.26742554e-02,  7.42801651e-02,
         7.15272427e-02,  4.92905118e-02,  6.36474639e-02,
        -4.97807935e-03, -8.52353871e-02,  4.42789420e-02,
         1.56496197e-01],
       [-2.41690576e-02, -1.11253485e-01,  2.60174796e-02,
        -2.53655966e-02, -3.06888595e-02,  7.04614669e-02,
        -2.01677606e-02, -8.55238289e-02,  4.99680266e-02,
         1.76511839e-01],
       [-4.30272967e-02, -2.32194141e-02, -4.67026755e-02,
         5.49829267e-02,  3.00423354e-02,  1.36088327e-01,
         2.45837178e-02, -3.89227420e-02,  4.13372666e-02,
         7.04330802e-02],
       [-7.81589746e-03,  8.88681635e-02,  1.90377012e-02,
         2.47566961e-02,  6.73517808e-02,  9.59524438e-02,
         9.72996876e-02, -6.45323768e-02, -5.78478463e-02,
         1.67243421e-01],
       [ 2.59691626e-02,  2.44912747e-02,  2.06067227e-03,
         1.70634389e-02,  2.18513757e-02,  1.37382165e-01,
         2.51548998e-02,  9.26546752e-03,  9.15516615e-02,
         1.02677807e-01],
       [-1.27567410e-01,  2.68350840e-02,  1.16995871e-02,
         9.84568521e-02,  5.41514680e-02,  1.51651859e-01,
         4.64795753e-02, -2.91735604e-02,  1.25106841e-01,
         1.71144456e-01],
       [-8.05673450e-02,  2.94806808e-02, -1.10152550e-01,
        -7.07100704e-02, -1.24747120e-02,  1.40526205e-01,
        -1.23536028e-02, -5.82108870e-02,  7.77795464e-02,
        -2.34171711e-02],
       [ 6.20728806e-02,  2.07295828e-02, -6.43237084e-02,
        -2.05786452e-02,  4.67800628e-03,  8.61115605e-02,
         4.20134440e-02, -4.34638187e-02,  3.63051891e-03,
         7.10046589e-02],
       [-6.08086586e-04,  3.42310257e-02, -8.34690034e-02,
         9.08722728e-03,  1.05238423e-01,  1.47501156e-01,
         1.42161340e-01,  6.30808845e-02, -5.24729490e-04,
         1.36735335e-01],
       [-1.91496432e-01, -3.44831869e-03, -1.86069421e-02,
        -5.05818650e-02,  6.41704351e-03,  4.03927974e-02,
         2.21725628e-02, -1.62169725e-01,  1.76057145e-02,
         1.85961321e-01],
       [-2.49194503e-02, -2.50838101e-02, -3.03340796e-02,
         2.27186158e-02, -3.09922658e-02,  1.24549076e-01,
        -2.09298395e-02, -3.32995206e-02,  5.32348230e-02,
         1.50542200e-01],
       [-4.28871512e-02, -9.22830626e-02,  2.90450491e-02,
         1.12102680e-01,  3.04660778e-02,  7.70893693e-02,
         6.90545812e-02, -3.36272307e-02,  5.62672317e-02,
         2.48093754e-01],
       [-1.07569546e-02, -7.48349279e-02,  2.02673525e-02,
         8.59332830e-02,  1.17068969e-01,  7.48098046e-02,
         6.08244948e-02,  1.01280659e-02,  1.91551913e-02,
         2.11404741e-01],
       [ 7.62790143e-02,  2.29362398e-03, -5.91743030e-02,
         1.31748244e-02,  3.23122516e-02,  7.41088167e-02,
        -6.43679202e-02, -4.93042730e-03,  7.44977593e-02,
         1.33817106e-01],
       [-1.31239325e-01, -1.13535374e-02, -4.06663343e-02,
        -1.28620975e-02, -5.99850900e-04,  1.77248433e-01,
        -1.11874640e-02, -1.12869248e-01,  2.11616959e-02,
         1.17670909e-01],
       [-7.98342153e-02,  1.17193013e-02, -5.84214628e-02,
         4.45103496e-02, -2.35349871e-03,  9.78477970e-02,
        -1.31990746e-01, -9.62637737e-02, -4.68819216e-03,
         4.20698449e-02],
       [ 2.33993679e-02, -6.48612902e-02,  3.57650109e-02,
         4.70414013e-02, -9.87344980e-03,  8.05342346e-02,
         8.79075974e-02, -3.31879668e-02,  4.50132787e-03,
         1.80063233e-01],
       [-4.68377843e-02, -4.20477986e-03, -6.34167045e-02,
         1.02471206e-02,  7.62451440e-02,  4.69596758e-02,
         1.50972791e-02,  2.54708696e-02,  6.22914918e-02,
        -7.17046950e-03],
       [-7.52374083e-02, -5.75587116e-02,  5.98922297e-02,
        -2.73536984e-02, -9.00403410e-03,  1.24273732e-01,
         3.02318707e-02, -8.23312476e-02,  4.71908152e-02,
         6.14355356e-02],
       [-1.15195602e-01, -1.64229050e-02,  1.29237305e-03,
        -1.09596789e-01,  3.80243734e-02,  1.29902706e-01,
         2.34013163e-02, -1.05847850e-01,  5.61432168e-02,
         6.71504661e-02],
       [-6.54574633e-02,  1.53685883e-01, -9.75284129e-02,
         7.12753274e-03,  8.82450342e-02,  1.34998679e-01,
         1.54644206e-01,  6.79479688e-02, -1.72668919e-02,
        -9.18642431e-03],
       [-2.10184306e-01,  1.36045277e-01, -8.87301117e-02,
        -1.74187630e-01,  4.86502573e-02,  2.76032418e-01,
         6.05597720e-02, -7.86263272e-02,  1.03168473e-01,
        -1.37950957e-01],
       [-1.14111096e-01,  2.64648199e-02,  3.91238183e-02,
         1.27296187e-02,  1.35679528e-01,  1.77840948e-01,
        -1.35071129e-02,  2.69818418e-02,  3.42153087e-02,
         8.82407650e-02],
       [-3.01946551e-02, -2.40282975e-02, -6.38710558e-02,
        -3.98451164e-02, -2.92418208e-02,  1.03236027e-01,
        -6.42294213e-02, -6.60998672e-02,  8.01952481e-02,
         4.51336950e-02],
       [-7.84322768e-02,  3.11397053e-02, -5.89536875e-02,
         5.54301403e-03,  4.19510715e-02,  1.42857507e-01,
         3.19332965e-02,  8.02233815e-03,  4.79588211e-02,
         3.57780866e-02],
       [ 4.59235013e-02, -5.12123033e-02,  7.95385391e-02,
        -2.55374219e-02,  4.68068533e-02,  4.43704054e-02,
         7.18893707e-02,  1.64316930e-02,  1.46261796e-01,
         1.57745421e-01],
       [-7.41346031e-02, -2.22236216e-02,  5.15953489e-02,
         3.52352336e-02,  3.60111147e-02,  1.15427203e-01,
         2.57897153e-02, -1.20676845e-01,  5.01920506e-02,
         5.08756042e-02],
       [-8.63073617e-02,  3.17134708e-03,  7.40556046e-03,
        -4.21377048e-02, -1.40033364e-02,  1.12993240e-01,
         3.54410820e-02, -6.92673773e-03, -1.53519716e-02,
         1.26333967e-01],
       [-1.94949329e-01,  3.92537005e-02, -3.54437791e-02,
         2.31000036e-03,  3.79432067e-02,  1.76356524e-01,
        -2.87505463e-02, -1.43955618e-01,  5.22137731e-02,
         1.74010634e-01],
       [-7.06951618e-02,  4.39597517e-02, -6.54652417e-02,
         8.82245004e-02,  2.06859671e-02,  2.03700036e-01,
         4.97581288e-02, -9.99335721e-02,  1.11235633e-01,
         1.77167300e-02],
       [-3.53763551e-02,  1.12260692e-02, -1.11665837e-02,
         2.38924790e-02, -2.76037231e-02,  5.51242232e-02,
         1.64862610e-02, -7.48966262e-02,  8.12724680e-02,
         1.43957604e-03],
       [-6.88709617e-02,  8.31975043e-02, -1.29381031e-01,
         3.24329957e-02,  7.93049186e-02,  1.07140720e-01,
        -1.81627274e-03, -6.15758151e-02, -4.92713787e-03,
         1.36734053e-01],
       [-1.40681088e-01, -3.15807723e-02,  2.60454044e-02,
        -4.27230299e-02,  9.93238837e-02,  8.53562057e-02,
        -5.72501905e-02, -4.54951562e-02,  5.40010631e-02,
         1.60534859e-01],
       [-9.79049355e-02,  6.72584400e-03,  8.45115632e-03,
         6.96118921e-04,  6.83328062e-02,  7.55404532e-02,
         4.08902764e-02, -4.68667373e-02, -1.57676004e-02,
         1.63363129e-01],
       [ 8.52064192e-02, -4.55360785e-02, -1.45583451e-02,
         3.04546617e-02,  8.79652798e-04,  4.33428399e-02,
        -1.49330217e-02, -8.25205147e-02,  6.79182187e-02,
         1.38463050e-01],
       [-6.81671500e-03,  2.08939444e-02, -4.36149202e-02,
         4.71097752e-02, -4.12514098e-02,  7.67979622e-02,
        -3.79576795e-02, -5.68123832e-02,  8.52056891e-02,
         8.15685652e-03],
       [-3.65775973e-02, -3.06402445e-02,  6.65950254e-02,
         4.47403751e-02,  1.30895078e-01,  1.50094643e-01,
        -7.96291754e-02, -2.53046192e-02,  1.07343026e-01,
         1.23134412e-01],
       [-7.62542933e-02,  8.05391073e-02, -4.92790192e-02,
         9.18199718e-02, -2.56493650e-02,  1.11088946e-01,
         1.23991318e-01, -7.41875842e-02,  1.47743657e-01,
         1.12662911e-01],
       [-1.30924150e-01, -4.17556763e-02,  9.19794068e-02,
         1.42487824e-01,  5.94666004e-02,  8.11570883e-02,
         8.41576308e-02, -9.05135348e-02,  6.55059814e-02,
         1.40432447e-01],
       [-1.17066503e-02,  2.23872401e-02,  8.04524869e-02,
         4.16423976e-02,  2.44700648e-02,  1.15135796e-01,
        -4.52195890e-02, -3.85435522e-02,  5.72163910e-02,
        -1.70241687e-02],
       [-1.59098431e-01, -2.48743258e-02,  1.15704350e-03,
        -1.94135085e-02,  7.07440227e-02,  1.02726325e-01,
         7.48141706e-02, -4.92161885e-02, -4.82953712e-03,
         7.76504800e-02],
       [-5.23527563e-02, -6.95675313e-02,  6.23273998e-02,
         2.89142895e-02,  5.82050942e-02,  3.41961756e-02,
         5.93537465e-02, -4.88524139e-03,  6.03169389e-02,
         1.83362126e-01],
       [-1.74306586e-01,  8.50927830e-03,  5.59768602e-02,
         6.93600252e-03,  1.01455852e-01,  1.91212326e-01,
         1.01702012e-01, -4.06461619e-02,  7.65661225e-02,
         5.99906780e-02],
       [-1.77252740e-01, -8.87014568e-02,  2.68679634e-02,
        -9.68291797e-03,  3.35638374e-02,  5.05711734e-02,
        -8.12163353e-02, -1.41850561e-01,  1.26373529e-01,
         1.44208223e-01],
       [ 8.87322426e-03,  4.11978438e-02, -1.41141713e-02,
         4.23451513e-02,  1.00925714e-01,  1.82571277e-01,
         1.03720605e-01,  7.12942332e-03,  8.42842646e-03,
         1.58599243e-01],
       [-2.13406682e-02, -1.75910015e-02, -5.12900352e-02,
         2.48414166e-02,  1.27789266e-02,  1.57729238e-01,
        -9.69544426e-03, -8.99763852e-02,  1.03400812e-01,
         1.43311873e-01],
       [-8.07911009e-02, -7.29364436e-03,  8.38935375e-03,
        -2.07424574e-02,  5.76308332e-02,  1.60505801e-01,
         1.78787038e-02, -3.16766389e-02,  9.31937695e-02,
        -1.85308605e-02],
       [-9.09739435e-02,  9.12447274e-02, -2.94412468e-02,
        -4.54112887e-05,  5.61082549e-02,  1.87468201e-01,
         1.19883358e-01, -1.82710923e-02,  6.45193905e-02,
         2.49637775e-02],
       [-1.29598469e-01, -3.25308368e-03, -4.12912332e-02,
        -5.08361273e-02, -1.04823690e-02,  1.82079896e-02,
        -6.16563670e-02, -7.23700672e-02,  1.43699115e-02,
         1.14841349e-01],
       [-6.54844195e-02,  1.32181123e-02, -1.93353333e-02,
         3.21756229e-02,  1.93038825e-02,  6.58847988e-02,
        -8.00077915e-02, -8.77422094e-02,  6.08410016e-02,
         7.79580325e-02],
       [-1.41706705e-01, -2.03647800e-02, -7.27154315e-03,
         6.24460652e-02,  4.42662910e-02,  3.27514187e-02,
        -1.90984663e-02, -2.80730426e-02,  6.94331005e-02,
         1.08589470e-01],
       [-1.47680700e-01,  1.48287900e-02,  4.51873392e-02,
         9.84134525e-03, -3.12292501e-02,  7.53606334e-02,
        -3.28539237e-02, -8.99436623e-02,  9.34160203e-02,
         1.19600557e-01],
       [-1.06706396e-01, -7.91918635e-02, -3.81190814e-02,
         7.73421675e-03,  5.89197949e-02,  1.48285478e-01,
         7.33993948e-04, -7.37952963e-02,  7.03667626e-02,
         1.29369870e-02],
       [-1.18764386e-01,  6.67049065e-02, -1.84375793e-03,
         9.65175256e-02,  2.87358947e-02,  1.37036711e-01,
        -1.75398681e-02, -3.47253568e-02, -1.17472522e-02,
         1.64805949e-01],
       [-2.96032578e-02,  7.04797730e-02, -5.30908853e-02,
         3.15117575e-02,  1.10758804e-02,  1.52998209e-01,
        -4.00629006e-02, -8.53683874e-02,  8.19639489e-02,
         8.44380781e-02],
       [ 5.15967757e-02, -7.42607191e-03, -5.11791557e-03,
         6.67313561e-02,  5.96174449e-02, -7.76399858e-03,
         1.19835325e-01,  2.59960741e-02,  3.64723615e-02,
         2.43333012e-01],
       [-1.81520402e-01, -2.21860819e-02, -9.37249511e-04,
        -4.36494425e-02,  1.48944080e-01,  1.20623425e-01,
         9.27821100e-02,  8.10635090e-03,  1.05029367e-01,
         5.49212396e-02],
       [-1.62024647e-01, -7.55626708e-03,  8.90679806e-02,
         1.09557875e-01,  4.00451683e-02,  5.88794537e-02,
        -5.02957255e-02, -9.45830047e-02,  6.43425062e-02,
         6.01378679e-02],
       [ 2.44281888e-02, -2.44084001e-03,  7.41914511e-02,
         1.29617304e-01,  7.07155839e-03,  6.66829422e-02,
         2.52056420e-02, -9.07487422e-02,  5.63489906e-02,
         1.43779904e-01],
       [ 4.61201221e-02, -6.86585605e-02, -7.71630462e-03,
        -2.53528468e-02,  1.99609250e-02,  9.59918946e-02,
         1.04020424e-02, -5.57698309e-02,  2.24557221e-02,
         1.73324734e-01],
       [-8.96764472e-02, -2.28214562e-02,  1.30368788e-02,
         6.01692796e-02,  1.63236298e-02,  9.95141268e-02,
        -1.29274517e-01, -2.46923715e-02,  3.26206349e-02,
         8.57910439e-02]], dtype=float32)>}

You can also load the model and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
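The warning above recommends wrapping the distributed call in a tf.function. The following is a minimal sketch of that pattern. It assumes a hypothetical Doubler module saved to a temporary directory and a toy dataset of eight floats. The per-replica outputs are collected with strategy.experimental_local_results and concatenated. This should keep the batch order when a global batch is split across several replicas, but treat that as an assumption to verify on your own hardware.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module with a single serving function."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def serve(self, x):
        return {"doubled": 2.0 * x}


module = Doubler()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.serve})

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    loaded = tf.saved_model.load(export_dir)
    infer = loaded.signatures["serving_default"]

# Toy dataset: two global batches of four float32 values each.
dataset = tf.data.Dataset.range(8).map(lambda i: tf.cast(i, tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def distributed_infer(batch):
    # Run the loaded signature on every replica inside a traced function.
    per_replica = strategy.run(lambda b: infer(x=b)["doubled"], args=(batch,))
    # Collect each replica's slice and stitch the global batch back together.
    return tf.concat(strategy.experimental_local_results(per_replica), axis=0)


results = [distributed_infer(batch).numpy().tolist() for batch in dist_dataset]
print(results)  # [[0.0, 2.0, 4.0, 6.0], [8.0, 10.0, 12.0, 14.0]]
```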

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it in a larger model? A common practice is to wrap the loaded object in a Keras layer. Fortunately, TF Hub provides hub.KerasLayer for exactly this purpose, as shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.1955 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9815

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
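hub.KerasLayer is the convenient route. To show what continuing to train a loaded function amounts to at the lowest level, here is a minimal sketch that fine-tunes a restored variable with tf.GradientTape and plain gradient descent. This is not how hub.KerasLayer is implemented. The Linear module, the toy target y = 3x, and the learning rate are hypothetical choices for illustration.

```python
import tempfile

import tensorflow as tf


class Linear(tf.Module):
    """Hypothetical one-parameter model: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(0.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Linear(), export_dir)
restored = tf.saved_model.load(export_dir)

# Toy regression target: y = 3x.
x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x
learning_rate = 0.05

for _ in range(200):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(restored(x) - y))
    # Gradients flow through the restored tf.function to the restored variable.
    (grad,) = tape.gradient(loss, [restored.w])
    restored.w.assign_sub(learning_rate * grad)

print(round(float(restored.w.numpy()), 3))  # 3.0
```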

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use Keras's model.save() API. If what you are saving is not a Keras model, the lower-level API is your only choice.

For loading, the API you use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.
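The same point holds for non-Keras objects: tf.saved_model.load() reconstructs a generic trackable object, not your original Python class. The following is a minimal sketch that assumes a hypothetical Counter module. Only tf.function methods and tracked state such as variables survive the round trip; ordinary Python methods do not.

```python
import tempfile

import tensorflow as tf


class Counter(tf.Module):
    """Hypothetical module with one saved function and one plain method."""

    def __init__(self):
        super().__init__()
        self.count = tf.Variable(0)

    @tf.function(input_signature=[])
    def increment(self):
        self.count.assign_add(1)
        return self.count.read_value()

    def describe(self):
        # Plain Python method: not part of the SavedModel.
        return "a counter"


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Counter(), export_dir)
restored = tf.saved_model.load(export_dir)

print(isinstance(restored, Counter))      # False
print(hasattr(restored, "describe"))      # False
print(int(restored.increment().numpy()))  # 1
```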

You can mix and match the APIs. For example, you can save a Keras model with model.save and load it back as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving to or loading from a local I/O device while running remotely, for example on a Cloud TPU, you must use the experimental_io_device option to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
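The low-level APIs accept the same options: tf.saved_model.save() takes options=tf.saved_model.SaveOptions(...), and tf.saved_model.load() takes options=tf.saved_model.LoadOptions(...). The following is a minimal sketch that assumes a hypothetical Offset module. When run on a single machine, "/job:localhost" is simply the local host, so the sketch only shows where the options are passed.

```python
import tempfile

import tensorflow as tf


class Offset(tf.Module):
    """Hypothetical module: adds a constant offset to its input."""

    def __init__(self):
        super().__init__()
        self.b = tf.Variable(1.5)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x + self.b


export_dir = tempfile.mkdtemp()

# Route file I/O through the local host when writing the SavedModel.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Offset(), export_dir, options=save_options)

# Do the same when reading it back under a distribution strategy.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
    restored = tf.saved_model.load(export_dir, options=load_options)

print(restored(tf.constant([0.5, 1.0])).numpy().tolist())  # [2.0, 2.5]
```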

Caveats

A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In these cases, you should use the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are fine. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! my_model has no defined inputs yet.
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f07bc363410>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f076c5f2750>, because it is not built.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
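The same caveat has a lower-level counterpart. A tf.function without an input_signature is saved only for the input signatures it has already been traced with. The following is a minimal sketch that assumes a hypothetical AddOne module. Tracing it once with a scalar before saving gives the SavedModel a concrete function to restore, much like calling a Keras model once so that its inputs are defined.

```python
import tempfile

import tensorflow as tf


class AddOne(tf.Module):
    """Hypothetical module whose function has no input_signature."""

    @tf.function
    def __call__(self, x):
        return x + 1.0


module = AddOne()
# Trace once with a float32 scalar; only this trace is written to the SavedModel.
module(tf.constant(1.0))

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)
restored = tf.saved_model.load(export_dir)

# Matches the traced signature (float32 scalar), so this call works.
print(float(restored(tf.constant(5.0)).numpy()))  # 6.0
# Inputs that match no saved trace (for example a float32 vector) are
# expected to be rejected by the restored function.
```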