یک مدل را با استفاده از استراتژی توزیع ذخیره و بارگذاری کنید

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

بررسی اجمالی

ذخیره و بارگذاری یک مدل در طول آموزش معمول است. دو مجموعه API برای ذخیره و بارگذاری یک مدل keras وجود دارد: یک API سطح بالا و یک API سطح پایین. این آموزش نشان می دهد که چگونه می توانید از SavedModel API ها هنگام استفاده از tf.distribute.Strategy استفاده کنید. برای آشنایی با SavedModel و به طور کلی سریال سازی، لطفا راهنمای مدل ذخیره شده و راهنمای سریال سازی مدل Keras را مطالعه کنید. بیایید با یک مثال ساده شروع کنیم:

وابستگی های وارداتی:

import tensorflow_datasets as tfds

import tensorflow as tf

داده ها و مدل را با استفاده از tf.distribute.Strategy :

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

آموزش مدل:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7f3b900396d0>

مدل را ذخیره و بارگذاری کنید

اکنون که یک مدل ساده برای کار با آن دارید، بیایید نگاهی به API های ذخیره/بارگیری بیندازیم. دو مجموعه API موجود است:

API های Keras

در اینجا نمونه ای از ذخیره و بارگذاری یک مدل با API های Keras آورده شده است:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

بازیابی مدل بدون tf.distribute.Strategy :

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895
<keras.callbacks.History at 0x7f3b187b7150>

پس از بازیابی مدل، می‌توانید آموزش روی آن را ادامه دهید، حتی بدون نیاز به فراخوانی compile() ، زیرا قبلاً قبل از ذخیره کامپایل شده است. این مدل در قالب پروتوی استاندارد SavedModel ذخیره شده است. برای اطلاعات بیشتر، لطفاً به راهنمای قالب saved_model کنید.

اکنون برای بارگذاری مدل و آموزش آن با استفاده از tf.distribute.Strategy :

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860
Epoch 2/2
938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903

همانطور که می بینید، بارگیری با tf.distribute.Strategy همانطور که انتظار می رود کار می کند. استراتژی مورد استفاده در اینجا لازم نیست همان استراتژی مورد استفاده قبل از ذخیره باشد.

API های tf.saved_model

حال بیایید نگاهی به APIهای سطح پایین بیاندازیم. ذخیره مدل شبیه به keras API است:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

بارگذاری را می توان با tf.saved_model.load() انجام داد. با این حال، از آنجایی که یک API در سطح پایین‌تر است (و از این رو دارای طیف وسیع‌تری از موارد استفاده است)، مدل Keras را برمی‌گرداند. در عوض، یک شی را برمی گرداند که حاوی توابعی است که می توان از آنها برای استنتاج استفاده کرد. مثلا:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

شیء بارگذاری شده ممکن است دارای چندین عملکرد باشد که هر کدام با یک کلید مرتبط هستند. "serving_default" کلید پیش‌فرض برای تابع استنتاج با یک مدل Keras ذخیره شده است. برای استنتاج با این تابع:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.18789300e-01, -1.78404614e-01,  4.92432676e-02,
        -9.37875658e-02,  1.14302970e-01, -8.99422392e-02,
         9.47709680e-02, -7.75382966e-02,  4.04430032e-02,
         2.41404288e-02],
       [-2.35370561e-01, -3.39397341e-02,  2.73427293e-02,
        -1.08200148e-01,  5.10682352e-02,  1.36142194e-01,
         9.28785652e-02, -5.35808355e-02,  2.56292164e-01,
         1.05301209e-01],
       [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02,
        -1.99329913e-01, -7.45072216e-02,  2.42738128e-02,
         2.07733169e-01, -3.15396488e-03,  4.95976806e-02,
         2.14848563e-01],
       [-9.82482210e-02, -6.13910556e-02,  1.00815810e-01,
        -1.87558904e-01,  1.14685424e-01,  1.53835595e-01,
         1.85714245e-01, -8.74890238e-02,  1.07493028e-01,
         1.57510787e-02],
       [-8.56257528e-02,  3.23683321e-02, -3.66768315e-02,
        -1.47201523e-01, -5.31517603e-02,  1.52744055e-02,
         1.69184029e-01, -5.42814359e-02,  1.11524366e-01,
         5.65215349e-02],
       [-1.50604844e-01, -7.87255913e-03,  1.26651973e-01,
        -1.24476865e-01,  6.94983900e-02,  4.27672639e-03,
         1.86136231e-01, -4.54714149e-03,  9.12746191e-02,
         6.12779632e-02],
       [-2.79157639e-01, -4.61089313e-02,  2.51544192e-02,
        -1.79003477e-01,  3.83432880e-02,  2.05054253e-01,
        -8.25636461e-03, -8.25546682e-03,  2.41342247e-01,
         8.24805871e-02],
       [-1.42795354e-01,  6.54597580e-02,  2.05058958e-02,
        -1.28471941e-01,  1.10977650e-01,  4.51317504e-02,
         2.44124904e-01,  1.90523565e-02,  3.11958641e-02,
         6.49511665e-02],
       [-1.33037239e-01, -2.72594951e-02,  8.09026062e-02,
        -1.95883229e-01,  1.84634060e-01,  1.00822970e-01,
         4.40884084e-02, -6.43826872e-02,  1.47807434e-01,
        -1.92791894e-02],
       [-1.43770471e-01, -2.53150351e-02,  4.18904647e-02,
        -1.02573663e-01,  6.15917407e-02,  7.95702711e-02,
         9.27314460e-02, -4.31537181e-02,  4.59018350e-02,
         1.02965936e-01],
       [-1.90395206e-01,  2.93233991e-03,  1.48900077e-02,
        -1.15877971e-01,  1.06598288e-02,  1.40121073e-01,
         6.86443001e-02, -4.61921766e-02,  1.27470195e-01,
         6.73005953e-02],
       [-2.60747373e-01, -1.45188004e-01,  7.10044056e-04,
        -1.04602516e-01,  5.00324890e-02,  2.96664417e-01,
         8.57191086e-02,  6.65097907e-02,  1.31302923e-01,
        -1.84605196e-02],
       [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01,
        -1.34576231e-01, -8.19503814e-02,  1.30840242e-02,
         6.16783127e-02, -3.64837795e-02,  3.18005830e-02,
         1.98420882e-01],
       [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02,
        -6.30265176e-02,  8.36028308e-02,  2.96559408e-02,
         2.19864860e-01, -7.08417147e-02,  4.76131588e-02,
         1.15781695e-01],
       [-1.55139655e-01, -1.27863720e-01,  9.67459157e-02,
        -1.48635745e-01,  1.25129193e-01,  4.04443927e-02,
         2.94884086e-01, -7.66484886e-02,  1.18753463e-01,
         2.93397382e-02],
       [-1.59221828e-01, -9.30457860e-02,  9.18259323e-02,
        -1.72857821e-01,  8.09611157e-02,  1.11391053e-01,
         1.66679412e-01,  3.52456123e-02,  9.05358568e-02,
         9.89414975e-02],
       [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02,
        -9.73629057e-02,  1.36456266e-01,  1.30628154e-01,
         1.53577864e-01, -6.73157908e-03,  9.31103677e-02,
         1.50734074e-02],
       [-1.29348308e-01, -3.03804129e-03,  2.82487050e-02,
        -2.02886015e-01,  7.09105879e-02,  1.74542382e-01,
         2.57992335e-02, -1.63579211e-02,  2.30892301e-02,
         6.69767857e-02],
       [-1.56857669e-01,  5.46110943e-02, -5.93251809e-02,
        -1.04585059e-01,  2.61763521e-02,  1.43062070e-01,
         1.57771498e-01, -6.19823262e-02,  3.59585434e-02,
         6.62322640e-02],
       [-8.64257440e-02, -1.33483298e-03,  7.46414512e-02,
        -1.82848468e-01,  1.21074423e-01,  1.55276239e-01,
         1.46483868e-01, -6.22515939e-03,  1.91641584e-01,
        -9.95825827e-02],
       [-2.52117336e-01, -6.92471862e-02,  1.09911412e-01,
        -3.73112522e-02,  3.76211852e-03,  5.23591004e-02,
         9.16506499e-02,  6.80204183e-02, -4.27842364e-02,
         7.91264027e-02],
       [-2.11018056e-01,  5.97522780e-03,  8.47486481e-02,
        -7.27925971e-02,  9.36664082e-03,  1.62506998e-01,
         5.32426499e-02,  1.78599171e-02, -2.30420940e-02,
         4.07365486e-02],
       [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02,
        -1.64699793e-01,  8.35808069e-02,  7.68100768e-02,
        -7.14773983e-02, -3.43702435e-02,  9.47649628e-02,
         9.36352089e-02],
       [-1.20486066e-01,  3.77080180e-02,  1.14158325e-01,
        -6.50681928e-02,  1.03382617e-02,  1.17891498e-01,
         1.13154747e-01, -1.49052702e-02,  1.28893867e-01,
         1.12219512e-01],
       [-2.23867983e-01, -9.79400948e-02,  7.37103820e-02,
        -1.05197895e-02,  3.75595838e-02,  1.80490598e-01,
         6.83145374e-02, -3.09509300e-02,  1.42565176e-01,
         8.05927664e-02],
       [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02,
        -1.75458089e-01,  1.46448284e-01,  1.80426955e-01,
         1.52164772e-01, -2.57370695e-02,  1.26812875e-01,
         1.22049123e-01],
       [-9.45013613e-02,  5.85526973e-02,  1.47456676e-02,
        -4.40606587e-02,  4.86647561e-02,  6.28624633e-02,
         3.69989276e-02, -3.68277319e-02,  3.56127135e-02,
         3.10502797e-02],
       [-1.02712311e-01,  3.16979140e-02,  1.88253060e-01,
        -5.99608906e-02,  3.73450294e-02,  6.38176724e-02,
         1.12240583e-01,  2.42183693e-02,  1.45670772e-02,
        -9.52028483e-03],
       [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02,
        -1.01807326e-01, -7.93362781e-03, -7.22003728e-02,
         1.49934232e-01, -1.19943202e-01,  9.22369361e-02,
         1.46321565e-01],
       [-1.32534593e-01,  1.18380897e-02,  2.23980099e-03,
        -9.28303748e-02, -2.20538303e-02,  7.68908709e-02,
         5.29715866e-02, -3.43324393e-02, -1.27909705e-02,
        -7.04141408e-02],
       [-8.10261145e-02, -8.95578321e-03,  3.96864787e-02,
        -1.21861629e-01,  7.98310041e-02,  1.56087667e-01,
         9.11872089e-02, -2.29295418e-02,  5.64432219e-02,
        -3.55931222e-02],
       [-1.76416740e-01,  1.12043694e-02, -1.80068091e-02,
        -1.88012689e-01,  8.68914276e-02,  1.57958359e-01,
         5.77907935e-02, -2.12088451e-02,  5.33877537e-02,
         2.19271183e-02],
       [-2.70012528e-01, -1.26611829e-01,  3.10387388e-02,
        -7.24840909e-02,  1.03253610e-01,  8.91268626e-02,
         1.38662308e-01, -6.25240132e-02,  2.36210316e-01,
         1.40534222e-01],
       [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02,
        -2.01282576e-02,  5.43357767e-02,  7.14191943e-02,
         3.46604213e-02, -6.00920171e-02,  5.11362031e-02,
         3.58160883e-02],
       [-1.63262367e-01,  2.44849995e-02,  3.81964818e-02,
        -3.93010303e-02,  3.95263731e-03,  9.11088511e-02,
         3.88236046e-02,  1.33745335e-02,  1.00076631e-01,
         6.05135933e-02],
       [-3.01809371e-01, -1.58440098e-01,  4.65333983e-02,
        -1.63946241e-01, -6.42775744e-02,  3.93286347e-04,
         2.82839835e-01, -8.93663988e-02,  1.97781295e-01,
         2.87044942e-01],
       [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03,
        -1.01776704e-01,  1.43144801e-02,  1.82002857e-02,
         2.76539754e-02, -1.94141679e-02,  8.87098238e-02,
         6.60644472e-02],
       [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02,
        -4.82957587e-02,  1.28858402e-01,  1.30042464e-01,
         1.32807568e-01, -7.52742141e-02,  9.51702446e-02,
         3.10119465e-02],
       [-1.09407350e-01, -5.27948700e-03,  1.29588693e-03,
        -2.61662379e-02,  3.01920641e-02,  1.13487415e-01,
         8.23267922e-02,  1.92574020e-02,  2.31986474e-02,
         4.13139611e-02],
       [-2.12277412e-01, -1.35507256e-01,  4.22930568e-02,
        -1.34565741e-01,  1.17879853e-01,  1.30573064e-01,
         1.81054786e-01, -1.70722306e-01,  1.05854876e-01,
         7.36362934e-02],
       [-1.78249478e-01, -7.55607188e-02,  7.75147527e-02,
        -2.14659080e-01,  3.26948166e-02,  7.76198730e-02,
         1.08791113e-01, -2.38809325e-02,  1.79410487e-01,
         1.94452941e-01],
       [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02,
        -1.40473023e-02,  3.60646360e-02, -9.39090401e-02,
         1.83859855e-01, -1.09493822e-01, -3.09051797e-02,
         1.36017531e-01],
       [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02,
        -9.68495384e-02,  2.35293470e-02,  2.53665410e-02,
         1.79999322e-01, -7.10204691e-02, -7.29817525e-02,
         4.50368747e-02],
       [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03,
        -1.03088826e-01, -7.38603100e-02,  1.84892826e-02,
         9.76646394e-02, -3.29037756e-02, -1.77134499e-02,
         1.62288889e-01],
       [-6.78652674e-02, -1.08500615e-01,  5.66991530e-02,
        -9.52370912e-02,  5.28126955e-02,  1.05176866e-02,
         1.73085481e-01, -1.37753151e-02,  1.95556954e-02,
         1.38068855e-01],
       [-2.02808753e-01, -3.39423120e-02,  1.82233751e-03,
        -5.71424365e-02,  3.40205729e-02,  8.74454305e-02,
         8.47227685e-03, -2.52498202e-02,  4.66104299e-02,
         1.10718749e-01],
       [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01,
        -9.72513855e-02, -3.58061343e-02,  3.04423086e-02,
         5.70362583e-02, -4.03833576e-02, -4.28436548e-02,
         9.73245874e-02],
       [-2.06081957e-01, -1.71493232e-01,  2.52560824e-02,
        -1.55212343e-01, -4.33478206e-02,  2.34177694e-01,
         8.46128762e-02,  1.75322518e-02,  2.04347119e-01,
         1.54971585e-01],
       [-1.95310384e-01,  1.30968075e-02, -9.68117267e-03,
        -7.31432810e-02,  1.02618083e-01,  1.59629256e-01,
         1.66028887e-01, -7.12903216e-03,  1.78021699e-01,
        -2.17130631e-02],
       [-1.59163624e-01, -1.77137554e-05,  1.75410658e-02,
        -9.08103511e-02,  7.25786015e-02,  9.21041369e-02,
         1.24915361e-01, -6.55939505e-02, -1.13440230e-02,
         1.03661232e-01],
       [-1.93366870e-01, -4.36344892e-02,  1.37750164e-01,
        -1.91939399e-01, -1.50268525e-03,  8.03942382e-02,
         2.15812266e-01,  5.38492575e-02,  1.36685073e-01,
         2.22119391e-01],
       [-1.65946245e-01,  7.89588690e-03, -1.65037125e-01,
        -1.23690292e-01, -8.57629776e-02, -2.55736727e-02,
         1.67541012e-01, -6.63827211e-02,  2.98694819e-02,
         1.71927184e-01],
       [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02,
        -2.98387632e-02,  2.80477256e-02,  4.94132042e-02,
         4.89805043e-02,  1.96998678e-02, -4.14144360e-02,
        -5.05549274e-02],
       [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02,
        -3.78398523e-02,  7.60737807e-03, -2.70657167e-02,
         1.11277811e-01,  6.37479573e-02, -2.39458829e-02,
         1.22067556e-01],
       [-1.92323536e-01, -1.43002480e-01,  5.29062748e-03,
        -1.70663983e-01,  8.39572400e-03,  6.37906119e-02,
         1.24084033e-01,  6.02792688e-02,  7.18353763e-02,
         5.03963791e-03],
       [-1.70977920e-01,  1.04207098e-02,  1.18544906e-01,
        -4.29532528e-02, -3.53983864e-02,  1.80302024e-01,
         8.08775946e-02,  3.19045782e-02,  2.52931342e-02,
         1.29424319e-01],
       [-2.13301033e-01, -6.96119964e-02,  2.32847631e-02,
        -7.73920864e-02,  1.10387571e-01,  1.13307782e-01,
         1.41805351e-01, -5.19381016e-02,  1.15313083e-01,
         1.40049949e-01],
       [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03,
        -1.04376137e-01,  7.78115019e-02,  6.84583709e-02,
         2.51923770e-01, -1.05199262e-01,  1.64517179e-01,
         2.18875334e-01],
       [-2.60777414e-01, -8.93031508e-02,  1.27723843e-01,
        -1.97950065e-01,  1.19145498e-01,  7.30907321e-02,
         2.23771721e-01, -6.83849230e-02,  3.68930906e-01,
         1.86811388e-01],
       [-2.38028213e-01,  1.11199915e-03,  2.25015372e-01,
         8.22724327e-02, -1.14511400e-01,  1.57513067e-01,
         5.22858277e-02,  2.13724375e-03,  3.15639377e-02,
         2.08704025e-01],
       [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02,
        -1.44550815e-01,  2.09794566e-02,  1.47883072e-02,
         3.96856442e-02, -2.15019658e-03, -4.90810722e-02,
         1.34708211e-01],
       [-2.02591017e-01, -2.29728431e-01,  6.73423260e-02,
        -1.24901496e-01, -1.38434023e-02,  8.64367038e-02,
         1.22342721e-01,  1.67826824e-02,  1.65354639e-01,
         1.83434993e-01],
       [-2.25799978e-01, -1.02682747e-01,  9.48531851e-02,
        -9.38871950e-02,  1.03806734e-01,  2.04695478e-01,
         8.09893832e-02, -1.45416632e-02,  1.33486420e-01,
        -6.27665371e-02],
       [-1.19375348e-01,  2.23235339e-02,  1.04302749e-01,
        -1.11149743e-01,  6.12434298e-02,  6.89433664e-02,
         2.08741099e-01, -3.81497070e-02, -1.42122135e-02,
         7.65201449e-03]], dtype=float32)>}
2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

شما همچنین می توانید بارگیری و استنتاج را به صورت توزیع شده انجام دهید:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

فراخوانی تابع بازیابی شده فقط یک پاس رو به جلو در مدل ذخیره شده (پیش بینی) است. اگر بخواهید به آموزش عملکرد بارگذاری شده ادامه دهید، چه؟ یا تابع بارگذاری شده را در یک مدل بزرگتر جاسازی کنید؟ یک روش معمول این است که برای رسیدن به این هدف، این شیء بارگذاری شده را روی یک لایه Keras قرار دهید. خوشبختانه، TF Hub دارای hub.KerasLayer برای این منظور است که در اینجا نشان داده شده است:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442
Epoch 2/2
938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813

همانطور که می بینید، hub.KerasLayer نتیجه بارگذاری شده از tf.saved_model.load() را در یک لایه Keras قرار می دهد که می تواند برای ساخت مدل دیگری استفاده شود. این برای یادگیری انتقال بسیار مفید است.

از کدام API استفاده کنم؟

برای ذخیره، اگر با یک مدل keras کار می کنید، تقریبا همیشه توصیه می شود از API model.save() Keras استفاده کنید. اگر چیزی که ذخیره می‌کنید مدل Keras نیست، API سطح پایین‌تر تنها انتخاب شماست.

برای بارگذاری، اینکه از کدام API استفاده می کنید بستگی به آنچه می خواهید از API بارگیری دریافت کنید، دارد. اگر نمی توانید (یا نمی خواهید) مدل Keras را دریافت کنید، از tf.saved_model.load() استفاده کنید. در غیر این صورت، از tf.keras.models.load_model() استفاده کنید. توجه داشته باشید که تنها در صورتی می توانید مدل کراس را پس بگیرید که مدل کراس را ذخیره کرده باشید.

امکان ترکیب و تطبیق APIها وجود دارد. می‌توانید یک مدل Keras را با model.save ذخیره کنید و یک مدل غیر Keras را با API سطح پایین tf.saved_model.load کنید.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

ذخیره / بارگیری از دستگاه محلی

هنگام ذخیره و بارگیری از یک دستگاه io محلی در حین اجرای از راه دور، به عنوان مثال با استفاده از یک TPU ابری، باید از گزینه experimental_io_device برای تنظیم دستگاه io به localhost استفاده شود.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

هشدارها

یک مورد خاص زمانی است که شما یک مدل Keras دارید که ورودی های مشخصی ندارد. به عنوان مثال، یک مدل ترتیبی را می توان بدون هیچ شکل ورودی ایجاد کرد ( Sequential([Dense(3), ...] ). API های سطح پایین تر هم در ذخیره و هم در بارگذاری، در غیر این صورت با خطا مواجه خواهید شد.

برای بررسی اینکه آیا مدل شما دارای ورودی های کاملاً تعریف شده است، فقط بررسی کنید که model.inputs هیچ است یا None . اگر None نباشد، همه شما خوب هستید. هنگامی که مدل در .fit، .fit ، .evaluate یا هنگام فراخوانی مدل ( model(inputs) ) استفاده می شود، اشکال ورودی به طور خودکار تعریف می .predict .

به عنوان مثال:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets