
Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This guide demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 4s 4ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9386
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9787

<tensorflow.python.keras.callbacks.History at 0x7f7470042b38>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9838
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9884

<tensorflow.python.keras.callbacks.History at 0x7f74d333f780>

After restoring the model, you can continue training on it, even without having to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the saved_model format.
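
To get a feel for what that format looks like on disk, you can list the files in the saved directory. This is just a quick inspection sketch; the exact file layout can vary between TensorFlow versions, but it typically includes a saved_model.pb graph, a variables/ directory with the weights, and an assets/ directory.

import os

# Walk the SavedModel directory written by model.save() above and print its files.
for root, dirs, files in os.walk(keras_model_path):
  for name in files:
    print(os.path.join(root, name))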

Now, load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0530 - sparse_categorical_accuracy: 0.9844
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9882

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

The model can be loaded back with tf.saved_model.load(). However, since this is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions which can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
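
If you are not sure which signatures were exported, you can first list the keys available on the loaded object. A minimal sketch:

# Each exported signature is stored under a key in `loaded.signatures`.
# For a Keras model exported this way, this typically prints ['serving_default'].
print(list(loaded.signatures.keys()))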

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.46878400e-01, -2.84028575e-02,  4.34195548e-02,
         8.65758881e-02, -5.50181568e-02, -2.26117969e-02,
        -8.18806365e-02,  1.60868585e-01,  7.05277026e-02,
        -2.11526364e-01],
       [-2.04405725e-01, -2.38965377e-02,  1.06097549e-01,
         1.15776211e-02, -5.68305999e-02,  7.61558264e-02,
        -2.36685127e-02,  6.12710230e-02,  6.85455352e-02,
        -2.04084530e-01],
       [-1.70060426e-01,  6.82905912e-02, -2.54967008e-02,
         1.27377272e-01, -4.24135383e-03, -1.15118716e-02,
         1.65115029e-01,  1.64797649e-01,  8.41001868e-02,
        -2.60865986e-01],
       [-1.24608956e-01,  7.05861971e-02,  4.76837084e-02,
         9.51382518e-02, -1.36017501e-02,  9.53883678e-02,
        -2.60323286e-04,  1.26946449e-01, -9.98851806e-02,
         6.01550192e-02],
       [-8.42214674e-02, -4.93131615e-02, -5.85474074e-04,
        -3.79234888e-02, -6.78482801e-02,  9.56373289e-02,
         4.69041206e-02,  8.55031833e-02,  9.31831449e-02,
        -1.40825540e-01],
       [-1.46941900e-01,  1.22972876e-02,  5.79140112e-02,
        -7.50405565e-02,  6.13511279e-02,  1.14746153e-01,
         3.54535617e-02,  2.55915433e-01,  7.26796240e-02,
        -1.99857190e-01],
       [-2.07879156e-01,  1.83034241e-02,  1.57775074e-01,
         6.06807172e-02, -1.75382420e-02,  1.33817732e-01,
         1.36331618e-01,  2.02472329e-01,  3.72610986e-02,
        -1.31865010e-01],
       [-9.93705392e-02,  6.03869818e-02, -4.28698361e-02,
         6.31842762e-04,  8.84034038e-02,  6.72685653e-02,
        -2.09506359e-02,  1.97081745e-01,  7.39021823e-02,
        -1.64300233e-01],
       [-9.71228778e-02,  5.48233166e-02,  1.38393641e-02,
        -7.14895800e-02, -3.87909710e-02,  8.45830888e-04,
        -3.62640694e-02,  1.64835989e-01,  5.04231751e-02,
        -2.07461655e-01],
       [-2.92240772e-02,  1.45425312e-02,  5.74428178e-02,
        -1.34241190e-02, -1.80013701e-02,  7.78546855e-02,
        -8.48746449e-02,  9.98296142e-02,  6.38790280e-02,
        -5.32845445e-02],
       [-1.76605240e-01, -1.42511949e-01,  1.39559209e-01,
        -2.00123414e-02, -6.44349307e-02, -4.56911251e-02,
         2.01093405e-03,  1.59898788e-01,  1.95391588e-02,
        -1.61375850e-01],
       [-1.58091724e-01,  6.25609234e-03,  2.12391287e-01,
        -1.39106885e-01, -4.78955358e-02,  7.36434534e-02,
         7.29984716e-02,  2.28351891e-01,  1.23042218e-01,
        -2.22285807e-01],
       [-6.63312748e-02, -5.25613949e-02,  3.88407931e-02,
         4.74876724e-02, -3.56937200e-02,  1.11578718e-01,
        -8.47167745e-02,  1.54049486e-01,  8.42248723e-02,
        -9.11155120e-02],
       [-1.49975002e-01, -1.69416200e-02,  2.03275681e-03,
         3.08024809e-02, -1.28081590e-02,  1.18468963e-01,
        -7.31947795e-02,  2.10938901e-01,  5.79604283e-02,
        -1.06384277e-01],
       [-2.44300172e-01,  6.77020177e-02,  1.61827058e-02,
         9.77846682e-02, -2.14450657e-02,  8.76296014e-02,
         1.55660659e-02,  2.56645411e-01, -6.94077387e-02,
         1.82542913e-02],
       [-3.24441910e-01,  2.83106230e-02,  1.15296148e-01,
        -6.49778843e-02, -3.93164232e-02,  2.09751099e-01,
         1.58456087e-01,  2.03075439e-01,  1.45919517e-01,
        -8.07187557e-02],
       [-1.77742794e-01, -3.47406045e-02,  6.37909994e-02,
         5.72632812e-02, -1.67798519e-01, -9.77907851e-02,
        -6.33480251e-02,  5.98776974e-02, -1.48319647e-01,
        -3.26665044e-02],
       [-1.92516297e-02, -4.32192907e-02,  9.45950896e-02,
        -1.24730960e-01,  3.15439701e-03,  7.49434829e-02,
         1.42610222e-01,  1.64739519e-01,  1.35794416e-01,
        -2.33872890e-01],
       [-9.74408463e-02, -4.51198146e-02, -7.16688111e-02,
         1.52820855e-01,  3.08901221e-02, -8.07915181e-02,
        -8.59454572e-02,  1.73750147e-01, -4.14928459e-02,
        -1.02175683e-01],
       [-1.79451153e-01,  7.97335058e-02,  6.08496368e-02,
        -8.74251127e-05,  1.40254274e-02,  7.78948367e-02,
         1.22523680e-02,  1.38402849e-01, -2.44962424e-03,
        -8.56248587e-02],
       [-7.16196820e-02, -3.66464853e-02, -1.97902359e-02,
        -3.42466384e-02,  1.01994909e-02,  8.11903924e-02,
         1.02423221e-01,  8.15625191e-02,  9.28392410e-02,
        -1.61639646e-01],
       [-1.29672050e-01, -9.39578265e-02, -3.77402268e-02,
        -5.66408038e-03,  2.01772340e-02, -5.53961843e-04,
         1.12603299e-01,  1.18293904e-01,  7.59286210e-02,
        -1.05032220e-01],
       [ 3.13648432e-02,  2.04140544e-02,  8.68844241e-02,
         8.54840502e-03, -3.24598253e-02,  7.13473856e-02,
         1.01958007e-01,  1.58244759e-01,  4.33884151e-02,
        -1.56489074e-01],
       [-5.69176152e-02, -8.68148059e-02,  5.83150014e-02,
        -6.94776773e-02, -1.14257783e-01,  9.14709717e-02,
        -6.18093796e-02,  4.60445434e-02,  6.21100292e-02,
        -2.56335258e-01],
       [-1.00941956e-03, -9.87592638e-02,  1.59144640e-01,
         2.46649459e-02, -1.47723123e-01,  3.34706903e-03,
        -1.25270292e-01,  7.13937655e-02, -3.65925357e-02,
        -2.86379248e-01],
       [-2.52649784e-01, -1.80219673e-02,  1.53900415e-01,
        -7.60671049e-02, -4.30139415e-02,  6.14799336e-02,
         5.27559966e-02,  3.91793013e-01,  1.10363506e-01,
        -2.21582249e-01],
       [-1.04441456e-02, -5.70102595e-02, -5.45391962e-02,
        -6.66194037e-02,  3.30452994e-02,  4.31669690e-03,
        -1.39387622e-02,  1.50821537e-01,  7.82721266e-02,
        -1.13290384e-01],
       [-1.50469467e-01, -1.50829509e-01,  1.37116134e-01,
        -7.71817416e-02, -1.22132301e-01,  8.29393342e-02,
         7.44771212e-03,  1.10161960e-01,  5.23409843e-02,
        -1.67824954e-01],
       [-1.67705536e-01, -1.61053427e-02,  3.56741399e-02,
        -8.12948644e-02, -2.15860698e-02,  7.68682212e-02,
         3.90296578e-02,  8.14016312e-02,  1.20665669e-01,
        -5.40915243e-02],
       [-1.74987361e-01,  5.39990142e-03,  7.59589747e-02,
         1.13510445e-01, -3.19063663e-02, -5.98092973e-02,
        -4.05801088e-02,  2.37588376e-01, -6.73733801e-02,
        -1.72320567e-02],
       [-1.80301860e-01,  2.00746767e-02, -7.40496814e-03,
         8.36828053e-02,  9.17709470e-02,  1.46025598e-01,
        -2.91051138e-02,  2.14360297e-01, -3.91696244e-02,
        -1.15331344e-01],
       [-7.45102018e-02,  3.96583155e-02,  8.10021013e-02,
         1.56707764e-02, -2.35380158e-02,  1.56681970e-01,
        -1.12800300e-02,  3.64681214e-01,  1.12793013e-01,
        -9.20613408e-02],
       [-1.10700965e-01, -3.84411961e-03,  7.15886354e-02,
        -5.16710430e-03, -2.68637538e-02, -4.64520939e-02,
        -1.02423206e-01,  1.41418934e-01,  1.36580504e-02,
        -2.16841191e-01],
       [-1.03602912e-02, -1.36248600e-02, -8.44807327e-02,
        -3.93018406e-03,  6.54329583e-02, -1.54229663e-02,
        -9.10714716e-02,  1.13576502e-02,  6.24551401e-02,
        -1.10215969e-01],
       [-1.64637700e-01, -4.25843447e-02, -6.63272589e-02,
         1.01544857e-02,  9.00160298e-02,  1.41169682e-01,
         9.43019092e-02,  1.50300652e-01,  1.17022656e-01,
        -2.61101604e-01],
       [-2.96755701e-01,  1.48339659e-01,  5.29592186e-02,
         4.51779664e-02, -6.84008598e-02,  1.29287004e-01,
         1.34066977e-02,  1.68794006e-01, -1.53631158e-02,
        -1.40826374e-01],
       [-2.27824658e-01, -3.58637236e-02,  7.98013210e-02,
        -2.93148141e-02, -1.29889801e-01,  1.07304119e-02,
         6.16377033e-02,  2.38016129e-01,  1.68460131e-01,
        -2.78131723e-01],
       [-1.97686747e-01, -1.20533034e-01,  1.91476271e-02,
        -2.50333622e-02, -1.20231688e-01, -1.43363982e-01,
        -5.45644462e-02,  1.13663480e-01, -9.71207619e-02,
        -7.38224685e-02],
       [-1.21181801e-01, -9.18156952e-02,  1.72619522e-02,
         7.20846877e-02, -5.00237271e-02, -7.88232982e-02,
        -2.75398232e-02,  9.42765027e-02, -8.18064660e-02,
        -4.43772227e-02],
       [-2.12152809e-01, -1.05831539e-02,  1.12541884e-01,
         3.79703306e-02, -4.97136004e-02, -8.26531351e-02,
         4.28089425e-02,  2.72401571e-01, -9.41082910e-02,
        -8.25358368e-03],
       [-2.12490350e-01,  5.10787666e-02, -4.91231680e-03,
         1.71558380e-01,  8.33496898e-02,  8.03120583e-02,
         5.97136915e-02,  2.78716445e-01, -5.66011816e-02,
        -7.99765587e-02],
       [-2.45497763e-01, -5.21367639e-02,  1.77163050e-01,
         8.67958441e-02, -1.33168459e-01,  9.83412005e-03,
        -1.34591311e-01,  1.48744047e-01, -6.65533617e-02,
        -1.07505932e-01],
       [-1.36525869e-01, -5.12802340e-02,  2.54329219e-02,
         8.01228657e-02, -3.24120894e-02, -6.36913255e-03,
        -7.75915161e-02,  1.81387305e-01,  6.72850609e-02,
        -1.06104709e-01],
       [-8.19087848e-02, -6.67821616e-02,  1.09396182e-01,
        -8.99944529e-02, -1.08385280e-01,  6.29347712e-02,
         7.26154894e-02,  1.68957621e-01,  1.90485001e-01,
        -2.60798335e-01],
       [-1.76897705e-01,  4.90825251e-02,  2.94402167e-02,
        -2.41212249e-02,  3.94896790e-02,  1.18754521e-01,
         1.69773921e-02,  1.10196158e-01,  7.08303824e-02,
        -6.86142594e-02],
       [-1.29656106e-01, -8.14089552e-02,  1.14682741e-01,
        -1.32834181e-01, -1.49253279e-01, -2.83164792e-02,
         3.45680863e-04,  2.52322882e-01,  2.89388448e-02,
        -2.79281288e-01],
       [-1.10502213e-01,  1.07094124e-01,  3.24486196e-02,
         7.70951509e-02, -6.27939776e-02,  1.68845624e-01,
        -1.44310594e-01,  1.45337492e-01,  2.03377791e-02,
        -5.04231378e-02],
       [-2.66523331e-01, -7.49082193e-02,  1.91363335e-01,
        -6.39847219e-02, -1.04055285e-01,  8.31385702e-02,
         8.82939398e-02,  1.99207246e-01,  5.35239354e-02,
        -2.60884434e-01],
       [-1.35722771e-01,  3.94147262e-02, -6.39424995e-02,
         1.39283150e-01,  5.37211001e-02, -6.34303223e-03,
        -1.70467123e-01,  2.55692095e-01, -7.66103566e-02,
        -6.90388680e-02],
       [-1.07885860e-01,  2.30858717e-02,  8.21547359e-02,
        -3.12240291e-02, -9.89983678e-02,  7.22398609e-02,
        -4.08478230e-02,  8.69123414e-02,  4.48577479e-02,
        -6.41947538e-02],
       [-2.28321850e-02, -3.88411283e-02,  1.47033811e-01,
        -2.35385150e-01, -9.87000838e-02,  6.44287840e-02,
        -1.87633559e-02,  1.17905587e-01,  9.70625877e-02,
        -2.46781930e-01],
       [-8.77917856e-02, -1.64044406e-02,  7.53755122e-02,
        -8.24043527e-04, -7.77238905e-02,  1.16269790e-01,
        -1.00877963e-01,  8.79124254e-02,  3.39440927e-02,
        -5.94997481e-02],
       [-1.41677827e-01, -1.40151009e-02,  8.84927809e-04,
         1.03166051e-01, -1.66242346e-02,  2.62837298e-02,
        -1.33589238e-01,  1.65735006e-01,  3.65820900e-02,
        -1.46895535e-02],
       [-1.61557034e-01,  5.66626638e-02, -1.61597617e-02,
         2.58595943e-02,  3.39905620e-02,  1.01104185e-01,
        -3.71510983e-02,  1.20341092e-01,  3.26242894e-02,
        -4.07250933e-02],
       [-2.17516154e-01,  7.85727724e-02,  9.79433060e-02,
         6.97179586e-02,  4.95264679e-02,  1.92503840e-01,
        -4.96265218e-02,  1.99431688e-01, -5.32730669e-03,
        -2.50038877e-02],
       [-1.35356426e-01, -6.96291253e-02,  3.92658785e-02,
        -9.86322537e-02, -4.20986377e-02,  9.87840891e-02,
         9.67663303e-02,  1.76262826e-01,  9.44406465e-02,
        -2.23472387e-01],
       [-1.25066608e-01,  7.71146417e-02,  4.02672291e-02,
        -2.05352344e-02,  3.11498251e-02,  9.64582711e-02,
        -5.39951548e-02,  2.29750067e-01,  1.61451437e-02,
        -5.41997403e-02],
       [-1.93750665e-01, -3.56721133e-03, -1.50568932e-02,
         1.78796798e-02,  8.33508372e-03, -1.18013099e-02,
        -5.35021350e-02,  2.02244624e-01,  3.02494057e-02,
        -1.20312274e-01],
       [-2.62067527e-01,  2.36408859e-02,  5.58489896e-02,
         1.75756812e-01, -2.75299139e-02,  3.48872915e-02,
         5.41301072e-03,  3.15880209e-01, -5.74782193e-02,
         7.00992346e-03],
       [-2.76674211e-01, -2.08131559e-02, -1.26259401e-02,
         7.77718723e-02, -1.54706314e-01,  1.31996438e-01,
         2.20355690e-02,  5.61908968e-02,  3.73308063e-02,
        -1.17717944e-01],
       [-1.59806639e-01,  1.20503023e-01, -4.36934829e-03,
         1.16428092e-01,  5.47975339e-02,  1.25162587e-01,
         4.78192419e-02,  1.28253624e-01,  7.34245628e-02,
        -1.80039048e-01],
       [-2.67963678e-01,  6.00077920e-02,  1.13472804e-01,
         7.52071738e-02, -6.40357211e-02,  1.03171021e-01,
         1.48901194e-01,  1.97019696e-01,  3.76104042e-02,
        -1.68720663e-01],
       [-2.01240778e-01,  2.47026011e-02,  3.10055390e-02,
        -8.58910009e-03, -8.49897265e-02, -7.54948407e-02,
        -9.39515531e-02,  1.34306327e-01, -1.71037674e-01,
        -5.76597378e-02],
       [-5.20152375e-02,  6.59879148e-02, -3.30656916e-02,
         9.97125208e-02,  3.56362388e-02,  1.26982957e-01,
        -2.69417539e-02,  1.59046397e-01,  1.10872082e-01,
        -1.84650719e-01]], dtype=float32)>}
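
The output above contains the raw logits for each of the 10 classes. If you want class predictions instead, you can apply tf.argmax to the logits. A small sketch, using the same inference_func and predict_dataset as above:

for batch in predict_dataset.take(1):
  outputs = inference_func(batch)
  logits = list(outputs.values())[0]        # the single output tensor (keyed 'dense_3' above)
  predictions = tf.argmax(logits, axis=-1)  # predicted class index per example
  print(predictions)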

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed the loaded function into a bigger model? A common practice is to wrap the loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
938/938 [==============================] - 2s 3ms/step - loss: 0.1981 - sparse_categorical_accuracy: 0.9412
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0655 - sparse_categorical_accuracy: 0.9804

As you can see, hub.KerasLayer wraps the result loaded from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
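
For a typical transfer-learning setup you might freeze the wrapped SavedModel and train only a new head on top. The following is a sketch, not part of the original notebook; the build_transfer_model helper and the extra Dense head are illustrative assumptions:

def build_transfer_model(loaded):
  inputs = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Freeze the wrapped SavedModel so its weights are not updated during training.
  features = hub.KerasLayer(loaded, trainable=False)(inputs)
  # New, trainable classification head on top of the frozen features.
  outputs = tf.keras.layers.Dense(10)(features)
  return tf.keras.Model(inputs, outputs)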

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.
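
To illustrate the non-Keras case, here is a sketch (not part of the original notebook) of exporting a plain tf.Module with the lower-level API; the Adder class and the /tmp/module_save path are made up for this example:

class Adder(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  def add(self, x):
    return x + 1.0

module = Adder()
# A tf.Module is not a Keras model, so model.save() is not available here;
# tf.saved_model.save() is the way to export it.
tf.saved_model.save(module, "/tmp/module_save")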

For loading, which API you use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can only get a Keras model back if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving and loading from a local I/O device while running remotely, for example when using a Cloud TPU, you must use the option experimental_io_device to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check if model.inputs is None. If it is not None, you are all set. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when calling the model (model(inputs)).
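
A small sketch of that check, using a toy Sequential model made up for this example:

# Without an input shape, the inputs are not defined yet:
toy = tf.keras.Sequential([tf.keras.layers.Dense(3)])
print(toy.inputs)  # None

# With an explicit input shape, the model is built and its inputs are defined:
toy_with_inputs = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(4,))])
print(toy_with_inputs.inputs)  # not None, so the Keras saving APIs can be used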

Here is an example with a subclassed model:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

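One way to make the Keras save API usable for this subclassed model, sketched here as an assumption rather than something shown in the notebook above, is to call the model once on sample data so that its input shapes become defined before saving (the input size of 10 is arbitrary):

my_model(tf.ones((1, 10)))       # build the model by calling it once
my_model.save(keras_model_path)  # now the Keras SavedModel export should work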