หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

บันทึกและโหลดโมเดลโดยใช้กลยุทธ์การกระจาย

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

ภาพรวม

เป็นเรื่องปกติที่จะบันทึกและโหลดโมเดลระหว่างการฝึกอบรม มี API สองชุดสำหรับบันทึกและโหลดโมเดล Keras ได้แก่ API ระดับสูงและ API ระดับต่ำ บทช่วยสอนนี้แสดงให้เห็นว่าคุณสามารถใช้ SavedModel API ได้อย่างไรเมื่อใช้ tf.distribute.Strategy หากต้องการเรียนรู้เกี่ยวกับ SavedModel และการทำให้เป็นอนุกรมโดยทั่วไปโปรดอ่าน คู่มือโมเดลที่บันทึกไว้ และ คู่มือการจัดลำดับโมเดล Keras เริ่มจากตัวอย่างง่ายๆ:

นำเข้าการอ้างอิง:

import tensorflow_datasets as tfds

import tensorflow as tf

เตรียมข้อมูลและโมเดลโดยใช้ tf.distribute.Strategy :

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

ฝึกโมเดล:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 4s 4ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9386
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9787

<tensorflow.python.keras.callbacks.History at 0x7f7470042b38>

บันทึกและโหลดโมเดล

ตอนนี้คุณมีโมเดลง่ายๆที่จะใช้งานได้แล้วมาดู API การบันทึก / การโหลดกัน มี API สองชุดที่พร้อมใช้งาน:

Keras API

นี่คือตัวอย่างของการบันทึกและโหลดโมเดลด้วย Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

กู้คืนโมเดลโดยไม่ใช้ tf.distribute.Strategy :

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9838
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9884

<tensorflow.python.keras.callbacks.History at 0x7f74d333f780>

หลังจากกู้คืนโมเดลแล้วคุณสามารถฝึกต่อได้โดยไม่จำเป็นต้องเรียก compile() อีกครั้งเนื่องจากคอมไพล์แล้วก่อนบันทึก โมเดลจะถูกบันทึกในรูปแบบโปรโต SavedModel มาตรฐานของ SavedModel สำหรับข้อมูลเพิ่มเติมโปรดดูที่ คู่มือเพื่อ saved_model รูปแบบ

ตอนนี้เพื่อโหลดโมเดลและฝึกโดยใช้ tf.distribute.Strategy :

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0530 - sparse_categorical_accuracy: 0.9844
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9882

อย่างที่คุณเห็นการโหลดทำงานตามที่คาดไว้กับ tf.distribute.Strategy กลยุทธ์ที่ใช้ในที่นี้ไม่จำเป็นต้องเป็นกลยุทธ์เดียวกับที่ใช้ก่อนการบันทึก

tf.saved_model API

ตอนนี้เรามาดู API ระดับล่างกัน การบันทึกโมเดลนั้นคล้ายกับ keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

การโหลดสามารถทำได้ด้วย tf.saved_model.load() อย่างไรก็ตามเนื่องจากเป็น API ที่อยู่ในระดับล่าง (และด้วยเหตุนี้จึงมีกรณีการใช้งานที่กว้างขึ้น) จึงไม่ส่งคืนโมเดล Keras แต่จะส่งคืนวัตถุที่มีฟังก์ชันที่สามารถใช้ในการอนุมานได้ ตัวอย่างเช่น:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

อ็อบเจ็กต์ที่โหลดอาจมีหลายฟังก์ชันซึ่งแต่ละฟังก์ชันเกี่ยวข้องกับคีย์ "serving_default" เป็นคีย์เริ่มต้นสำหรับฟังก์ชันการอนุมานด้วยโมเดล Keras ที่บันทึกไว้ ในการอนุมานด้วยฟังก์ชันนี้:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.46878400e-01, -2.84028575e-02,  4.34195548e-02,
         8.65758881e-02, -5.50181568e-02, -2.26117969e-02,
        -8.18806365e-02,  1.60868585e-01,  7.05277026e-02,
        -2.11526364e-01],
       [-2.04405725e-01, -2.38965377e-02,  1.06097549e-01,
         1.15776211e-02, -5.68305999e-02,  7.61558264e-02,
        -2.36685127e-02,  6.12710230e-02,  6.85455352e-02,
        -2.04084530e-01],
       [-1.70060426e-01,  6.82905912e-02, -2.54967008e-02,
         1.27377272e-01, -4.24135383e-03, -1.15118716e-02,
         1.65115029e-01,  1.64797649e-01,  8.41001868e-02,
        -2.60865986e-01],
       [-1.24608956e-01,  7.05861971e-02,  4.76837084e-02,
         9.51382518e-02, -1.36017501e-02,  9.53883678e-02,
        -2.60323286e-04,  1.26946449e-01, -9.98851806e-02,
         6.01550192e-02],
       [-8.42214674e-02, -4.93131615e-02, -5.85474074e-04,
        -3.79234888e-02, -6.78482801e-02,  9.56373289e-02,
         4.69041206e-02,  8.55031833e-02,  9.31831449e-02,
        -1.40825540e-01],
       [-1.46941900e-01,  1.22972876e-02,  5.79140112e-02,
        -7.50405565e-02,  6.13511279e-02,  1.14746153e-01,
         3.54535617e-02,  2.55915433e-01,  7.26796240e-02,
        -1.99857190e-01],
       [-2.07879156e-01,  1.83034241e-02,  1.57775074e-01,
         6.06807172e-02, -1.75382420e-02,  1.33817732e-01,
         1.36331618e-01,  2.02472329e-01,  3.72610986e-02,
        -1.31865010e-01],
       [-9.93705392e-02,  6.03869818e-02, -4.28698361e-02,
         6.31842762e-04,  8.84034038e-02,  6.72685653e-02,
        -2.09506359e-02,  1.97081745e-01,  7.39021823e-02,
        -1.64300233e-01],
       [-9.71228778e-02,  5.48233166e-02,  1.38393641e-02,
        -7.14895800e-02, -3.87909710e-02,  8.45830888e-04,
        -3.62640694e-02,  1.64835989e-01,  5.04231751e-02,
        -2.07461655e-01],
       [-2.92240772e-02,  1.45425312e-02,  5.74428178e-02,
        -1.34241190e-02, -1.80013701e-02,  7.78546855e-02,
        -8.48746449e-02,  9.98296142e-02,  6.38790280e-02,
        -5.32845445e-02],
       [-1.76605240e-01, -1.42511949e-01,  1.39559209e-01,
        -2.00123414e-02, -6.44349307e-02, -4.56911251e-02,
         2.01093405e-03,  1.59898788e-01,  1.95391588e-02,
        -1.61375850e-01],
       [-1.58091724e-01,  6.25609234e-03,  2.12391287e-01,
        -1.39106885e-01, -4.78955358e-02,  7.36434534e-02,
         7.29984716e-02,  2.28351891e-01,  1.23042218e-01,
        -2.22285807e-01],
       [-6.63312748e-02, -5.25613949e-02,  3.88407931e-02,
         4.74876724e-02, -3.56937200e-02,  1.11578718e-01,
        -8.47167745e-02,  1.54049486e-01,  8.42248723e-02,
        -9.11155120e-02],
       [-1.49975002e-01, -1.69416200e-02,  2.03275681e-03,
         3.08024809e-02, -1.28081590e-02,  1.18468963e-01,
        -7.31947795e-02,  2.10938901e-01,  5.79604283e-02,
        -1.06384277e-01],
       [-2.44300172e-01,  6.77020177e-02,  1.61827058e-02,
         9.77846682e-02, -2.14450657e-02,  8.76296014e-02,
         1.55660659e-02,  2.56645411e-01, -6.94077387e-02,
         1.82542913e-02],
       [-3.24441910e-01,  2.83106230e-02,  1.15296148e-01,
        -6.49778843e-02, -3.93164232e-02,  2.09751099e-01,
         1.58456087e-01,  2.03075439e-01,  1.45919517e-01,
        -8.07187557e-02],
       [-1.77742794e-01, -3.47406045e-02,  6.37909994e-02,
         5.72632812e-02, -1.67798519e-01, -9.77907851e-02,
        -6.33480251e-02,  5.98776974e-02, -1.48319647e-01,
        -3.26665044e-02],
       [-1.92516297e-02, -4.32192907e-02,  9.45950896e-02,
        -1.24730960e-01,  3.15439701e-03,  7.49434829e-02,
         1.42610222e-01,  1.64739519e-01,  1.35794416e-01,
        -2.33872890e-01],
       [-9.74408463e-02, -4.51198146e-02, -7.16688111e-02,
         1.52820855e-01,  3.08901221e-02, -8.07915181e-02,
        -8.59454572e-02,  1.73750147e-01, -4.14928459e-02,
        -1.02175683e-01],
       [-1.79451153e-01,  7.97335058e-02,  6.08496368e-02,
        -8.74251127e-05,  1.40254274e-02,  7.78948367e-02,
         1.22523680e-02,  1.38402849e-01, -2.44962424e-03,
        -8.56248587e-02],
       [-7.16196820e-02, -3.66464853e-02, -1.97902359e-02,
        -3.42466384e-02,  1.01994909e-02,  8.11903924e-02,
         1.02423221e-01,  8.15625191e-02,  9.28392410e-02,
        -1.61639646e-01],
       [-1.29672050e-01, -9.39578265e-02, -3.77402268e-02,
        -5.66408038e-03,  2.01772340e-02, -5.53961843e-04,
         1.12603299e-01,  1.18293904e-01,  7.59286210e-02,
        -1.05032220e-01],
       [ 3.13648432e-02,  2.04140544e-02,  8.68844241e-02,
         8.54840502e-03, -3.24598253e-02,  7.13473856e-02,
         1.01958007e-01,  1.58244759e-01,  4.33884151e-02,
        -1.56489074e-01],
       [-5.69176152e-02, -8.68148059e-02,  5.83150014e-02,
        -6.94776773e-02, -1.14257783e-01,  9.14709717e-02,
        -6.18093796e-02,  4.60445434e-02,  6.21100292e-02,
        -2.56335258e-01],
       [-1.00941956e-03, -9.87592638e-02,  1.59144640e-01,
         2.46649459e-02, -1.47723123e-01,  3.34706903e-03,
        -1.25270292e-01,  7.13937655e-02, -3.65925357e-02,
        -2.86379248e-01],
       [-2.52649784e-01, -1.80219673e-02,  1.53900415e-01,
        -7.60671049e-02, -4.30139415e-02,  6.14799336e-02,
         5.27559966e-02,  3.91793013e-01,  1.10363506e-01,
        -2.21582249e-01],
       [-1.04441456e-02, -5.70102595e-02, -5.45391962e-02,
        -6.66194037e-02,  3.30452994e-02,  4.31669690e-03,
        -1.39387622e-02,  1.50821537e-01,  7.82721266e-02,
        -1.13290384e-01],
       [-1.50469467e-01, -1.50829509e-01,  1.37116134e-01,
        -7.71817416e-02, -1.22132301e-01,  8.29393342e-02,
         7.44771212e-03,  1.10161960e-01,  5.23409843e-02,
        -1.67824954e-01],
       [-1.67705536e-01, -1.61053427e-02,  3.56741399e-02,
        -8.12948644e-02, -2.15860698e-02,  7.68682212e-02,
         3.90296578e-02,  8.14016312e-02,  1.20665669e-01,
        -5.40915243e-02],
       [-1.74987361e-01,  5.39990142e-03,  7.59589747e-02,
         1.13510445e-01, -3.19063663e-02, -5.98092973e-02,
        -4.05801088e-02,  2.37588376e-01, -6.73733801e-02,
        -1.72320567e-02],
       [-1.80301860e-01,  2.00746767e-02, -7.40496814e-03,
         8.36828053e-02,  9.17709470e-02,  1.46025598e-01,
        -2.91051138e-02,  2.14360297e-01, -3.91696244e-02,
        -1.15331344e-01],
       [-7.45102018e-02,  3.96583155e-02,  8.10021013e-02,
         1.56707764e-02, -2.35380158e-02,  1.56681970e-01,
        -1.12800300e-02,  3.64681214e-01,  1.12793013e-01,
        -9.20613408e-02],
       [-1.10700965e-01, -3.84411961e-03,  7.15886354e-02,
        -5.16710430e-03, -2.68637538e-02, -4.64520939e-02,
        -1.02423206e-01,  1.41418934e-01,  1.36580504e-02,
        -2.16841191e-01],
       [-1.03602912e-02, -1.36248600e-02, -8.44807327e-02,
        -3.93018406e-03,  6.54329583e-02, -1.54229663e-02,
        -9.10714716e-02,  1.13576502e-02,  6.24551401e-02,
        -1.10215969e-01],
       [-1.64637700e-01, -4.25843447e-02, -6.63272589e-02,
         1.01544857e-02,  9.00160298e-02,  1.41169682e-01,
         9.43019092e-02,  1.50300652e-01,  1.17022656e-01,
        -2.61101604e-01],
       [-2.96755701e-01,  1.48339659e-01,  5.29592186e-02,
         4.51779664e-02, -6.84008598e-02,  1.29287004e-01,
         1.34066977e-02,  1.68794006e-01, -1.53631158e-02,
        -1.40826374e-01],
       [-2.27824658e-01, -3.58637236e-02,  7.98013210e-02,
        -2.93148141e-02, -1.29889801e-01,  1.07304119e-02,
         6.16377033e-02,  2.38016129e-01,  1.68460131e-01,
        -2.78131723e-01],
       [-1.97686747e-01, -1.20533034e-01,  1.91476271e-02,
        -2.50333622e-02, -1.20231688e-01, -1.43363982e-01,
        -5.45644462e-02,  1.13663480e-01, -9.71207619e-02,
        -7.38224685e-02],
       [-1.21181801e-01, -9.18156952e-02,  1.72619522e-02,
         7.20846877e-02, -5.00237271e-02, -7.88232982e-02,
        -2.75398232e-02,  9.42765027e-02, -8.18064660e-02,
        -4.43772227e-02],
       [-2.12152809e-01, -1.05831539e-02,  1.12541884e-01,
         3.79703306e-02, -4.97136004e-02, -8.26531351e-02,
         4.28089425e-02,  2.72401571e-01, -9.41082910e-02,
        -8.25358368e-03],
       [-2.12490350e-01,  5.10787666e-02, -4.91231680e-03,
         1.71558380e-01,  8.33496898e-02,  8.03120583e-02,
         5.97136915e-02,  2.78716445e-01, -5.66011816e-02,
        -7.99765587e-02],
       [-2.45497763e-01, -5.21367639e-02,  1.77163050e-01,
         8.67958441e-02, -1.33168459e-01,  9.83412005e-03,
        -1.34591311e-01,  1.48744047e-01, -6.65533617e-02,
        -1.07505932e-01],
       [-1.36525869e-01, -5.12802340e-02,  2.54329219e-02,
         8.01228657e-02, -3.24120894e-02, -6.36913255e-03,
        -7.75915161e-02,  1.81387305e-01,  6.72850609e-02,
        -1.06104709e-01],
       [-8.19087848e-02, -6.67821616e-02,  1.09396182e-01,
        -8.99944529e-02, -1.08385280e-01,  6.29347712e-02,
         7.26154894e-02,  1.68957621e-01,  1.90485001e-01,
        -2.60798335e-01],
       [-1.76897705e-01,  4.90825251e-02,  2.94402167e-02,
        -2.41212249e-02,  3.94896790e-02,  1.18754521e-01,
         1.69773921e-02,  1.10196158e-01,  7.08303824e-02,
        -6.86142594e-02],
       [-1.29656106e-01, -8.14089552e-02,  1.14682741e-01,
        -1.32834181e-01, -1.49253279e-01, -2.83164792e-02,
         3.45680863e-04,  2.52322882e-01,  2.89388448e-02,
        -2.79281288e-01],
       [-1.10502213e-01,  1.07094124e-01,  3.24486196e-02,
         7.70951509e-02, -6.27939776e-02,  1.68845624e-01,
        -1.44310594e-01,  1.45337492e-01,  2.03377791e-02,
        -5.04231378e-02],
       [-2.66523331e-01, -7.49082193e-02,  1.91363335e-01,
        -6.39847219e-02, -1.04055285e-01,  8.31385702e-02,
         8.82939398e-02,  1.99207246e-01,  5.35239354e-02,
        -2.60884434e-01],
       [-1.35722771e-01,  3.94147262e-02, -6.39424995e-02,
         1.39283150e-01,  5.37211001e-02, -6.34303223e-03,
        -1.70467123e-01,  2.55692095e-01, -7.66103566e-02,
        -6.90388680e-02],
       [-1.07885860e-01,  2.30858717e-02,  8.21547359e-02,
        -3.12240291e-02, -9.89983678e-02,  7.22398609e-02,
        -4.08478230e-02,  8.69123414e-02,  4.48577479e-02,
        -6.41947538e-02],
       [-2.28321850e-02, -3.88411283e-02,  1.47033811e-01,
        -2.35385150e-01, -9.87000838e-02,  6.44287840e-02,
        -1.87633559e-02,  1.17905587e-01,  9.70625877e-02,
        -2.46781930e-01],
       [-8.77917856e-02, -1.64044406e-02,  7.53755122e-02,
        -8.24043527e-04, -7.77238905e-02,  1.16269790e-01,
        -1.00877963e-01,  8.79124254e-02,  3.39440927e-02,
        -5.94997481e-02],
       [-1.41677827e-01, -1.40151009e-02,  8.84927809e-04,
         1.03166051e-01, -1.66242346e-02,  2.62837298e-02,
        -1.33589238e-01,  1.65735006e-01,  3.65820900e-02,
        -1.46895535e-02],
       [-1.61557034e-01,  5.66626638e-02, -1.61597617e-02,
         2.58595943e-02,  3.39905620e-02,  1.01104185e-01,
        -3.71510983e-02,  1.20341092e-01,  3.26242894e-02,
        -4.07250933e-02],
       [-2.17516154e-01,  7.85727724e-02,  9.79433060e-02,
         6.97179586e-02,  4.95264679e-02,  1.92503840e-01,
        -4.96265218e-02,  1.99431688e-01, -5.32730669e-03,
        -2.50038877e-02],
       [-1.35356426e-01, -6.96291253e-02,  3.92658785e-02,
        -9.86322537e-02, -4.20986377e-02,  9.87840891e-02,
         9.67663303e-02,  1.76262826e-01,  9.44406465e-02,
        -2.23472387e-01],
       [-1.25066608e-01,  7.71146417e-02,  4.02672291e-02,
        -2.05352344e-02,  3.11498251e-02,  9.64582711e-02,
        -5.39951548e-02,  2.29750067e-01,  1.61451437e-02,
        -5.41997403e-02],
       [-1.93750665e-01, -3.56721133e-03, -1.50568932e-02,
         1.78796798e-02,  8.33508372e-03, -1.18013099e-02,
        -5.35021350e-02,  2.02244624e-01,  3.02494057e-02,
        -1.20312274e-01],
       [-2.62067527e-01,  2.36408859e-02,  5.58489896e-02,
         1.75756812e-01, -2.75299139e-02,  3.48872915e-02,
         5.41301072e-03,  3.15880209e-01, -5.74782193e-02,
         7.00992346e-03],
       [-2.76674211e-01, -2.08131559e-02, -1.26259401e-02,
         7.77718723e-02, -1.54706314e-01,  1.31996438e-01,
         2.20355690e-02,  5.61908968e-02,  3.73308063e-02,
        -1.17717944e-01],
       [-1.59806639e-01,  1.20503023e-01, -4.36934829e-03,
         1.16428092e-01,  5.47975339e-02,  1.25162587e-01,
         4.78192419e-02,  1.28253624e-01,  7.34245628e-02,
        -1.80039048e-01],
       [-2.67963678e-01,  6.00077920e-02,  1.13472804e-01,
         7.52071738e-02, -6.40357211e-02,  1.03171021e-01,
         1.48901194e-01,  1.97019696e-01,  3.76104042e-02,
        -1.68720663e-01],
       [-2.01240778e-01,  2.47026011e-02,  3.10055390e-02,
        -8.58910009e-03, -8.49897265e-02, -7.54948407e-02,
        -9.39515531e-02,  1.34306327e-01, -1.71037674e-01,
        -5.76597378e-02],
       [-5.20152375e-02,  6.59879148e-02, -3.30656916e-02,
         9.97125208e-02,  3.56362388e-02,  1.26982957e-01,
        -2.69417539e-02,  1.59046397e-01,  1.10872082e-01,
        -1.84650719e-01]], dtype=float32)>}

คุณยังสามารถโหลดและทำการอนุมานในลักษณะกระจาย:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

การเรียกใช้ฟังก์ชันที่เรียกคืนเป็นเพียงการส่งต่อไปยังโมเดลที่บันทึกไว้ (คาดการณ์) จะเกิดอะไรขึ้นหากคุณไม่ต้องการฝึกฟังก์ชั่นที่โหลดต่อไป หรือฝังฟังก์ชั่นที่โหลดลงในโมเดลที่ใหญ่กว่า? แนวทางปฏิบัติทั่วไปคือการห่อวัตถุที่โหลดนี้เข้ากับเลเยอร์ Keras เพื่อให้ได้สิ่งนี้ โชคดีที่ TF Hub มี ฮับ KerasLayer สำหรับจุดประสงค์นี้แสดงไว้ที่นี่:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
938/938 [==============================] - 2s 3ms/step - loss: 0.1981 - sparse_categorical_accuracy: 0.9412
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0655 - sparse_categorical_accuracy: 0.9804

อย่างที่คุณเห็น hub.KerasLayer ห่อผลลัพธ์ที่โหลดกลับจาก tf.saved_model.load() ลงในเลเยอร์ Keras ที่สามารถใช้สร้างโมเดลอื่นได้ สิ่งนี้มีประโยชน์มากสำหรับการถ่ายทอดการเรียนรู้

ฉันควรใช้ API ใด

สำหรับการประหยัดหากคุณกำลังทำงานกับโมเดล Keras ขอแนะนำให้ใช้ Keras's model.save() API เกือบตลอดเวลา หากสิ่งที่คุณกำลังประหยัดไม่ใช่โมเดล Keras ดังนั้น API ระดับล่างคือทางเลือกเดียวของคุณ

สำหรับการโหลด API ใดที่คุณใช้ขึ้นอยู่กับสิ่งที่คุณต้องการได้รับจากโหลด API หากคุณไม่สามารถ (หรือไม่ต้องการ) รับโมเดล Keras ให้ใช้ tf.saved_model.load() มิฉะนั้นให้ใช้ tf.keras.models.load_model() โปรดทราบว่าคุณจะได้รับโมเดล Keras กลับคืนมาก็ต่อเมื่อคุณบันทึกโมเดล Keras ไว้

เป็นไปได้ที่จะผสมผสานและจับคู่ API คุณสามารถบันทึกโมเดล Keras ด้วย model.save และโหลดโมเดลที่ไม่ใช่ Keras ด้วย API ระดับต่ำ tf.saved_model.load

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

บันทึก / โหลดจากอุปกรณ์ภายในเครื่อง

เมื่อบันทึกและโหลดจากอุปกรณ์ io ในเครื่องในขณะที่เรียกใช้จากระยะไกลตัวอย่างเช่นการใช้ Cloud TPU ต้องใช้ตัวเลือก experimental_io_device เพื่อตั้งค่าอุปกรณ์ io เป็น localhost

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

ข้อควรระวัง

กรณีพิเศษคือเมื่อคุณมีโมเดล Keras ที่ไม่มีอินพุตที่กำหนดไว้อย่างดี ตัวอย่างเช่นสามารถสร้างแบบจำลองตามลำดับได้โดยไม่ต้องป้อนรูปร่างใด ๆ ( Sequential([Dense(3), ...] ) โมเดลคลาสย่อยยังไม่มีอินพุตที่กำหนดไว้อย่างชัดเจนหลังจากการเริ่มต้นในกรณีนี้คุณควรยึดติดกับ API ระดับล่างทั้งในการบันทึกและการโหลดมิฉะนั้นคุณจะได้รับข้อผิดพลาด

ในการตรวจสอบว่าโมเดลของคุณมีอินพุตที่กำหนดไว้อย่างดีหรือไม่เพียงแค่ตรวจสอบว่า model.inputs เป็น None ถ้าไม่ใช่ None คุณก็สบายดี รูปร่างอินพุตจะถูกกำหนดโดยอัตโนมัติเมื่อใช้โมเดลใน. .fit , .evaluate , .predict หรือเมื่อเรียกโมเดล ( model(inputs) )

นี่คือตัวอย่าง:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets