वितरण रणनीति का उपयोग करके मॉडल को सहेजें और लोड करें

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

अवलोकन

प्रशिक्षण के दौरान किसी मॉडल को सहेजना और लोड करना आम बात है। केरस मॉडल को सहेजने और लोड करने के लिए एपीआई के दो सेट हैं: एक उच्च-स्तरीय एपीआई और एक निम्न-स्तरीय एपीआई। यह ट्यूटोरियल दर्शाता है कि tf.distribute.Strategy का उपयोग करते समय आप tf.distribute.Strategy API का उपयोग कैसे कर सकते हैं। सामान्य रूप से सहेजे गए मॉडल और क्रमांकन के बारे में जानने के लिए, कृपया सहेजे गए मॉडल मार्गदर्शिका और केरस मॉडल क्रमांकन मार्गदर्शिका पढ़ें। आइए एक साधारण उदाहरण से शुरू करें:

आयात निर्भरताएँ:

import tensorflow_datasets as tfds

import tensorflow as tf

tf.distribute.Strategy का उपयोग करके डेटा और मॉडल तैयार करें:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

मॉडल को प्रशिक्षित करें:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7f3b900396d0>

मॉडल को सहेजें और लोड करें

अब जबकि आपके पास काम करने के लिए एक सरल मॉडल है, तो आइए सेविंग/लोडिंग एपीआई पर एक नजर डालते हैं। एपीआई के दो सेट उपलब्ध हैं:

केरस एपीआई

केरस एपीआई के साथ एक मॉडल को सहेजने और लोड करने का एक उदाहरण यहां दिया गया है:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

tf.distribute.Strategy के बिना मॉडल को पुनर्स्थापित करें:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895
<keras.callbacks.History at 0x7f3b187b7150>

मॉडल को पुनर्स्थापित करने के बाद, आप उस पर प्रशिक्षण जारी रख सकते हैं, यहां तक ​​​​कि compile() को फिर से कॉल करने की आवश्यकता के बिना, क्योंकि यह पहले से ही सहेजने से पहले संकलित है। मॉडल TensorFlow के मानक SavedModel प्रोटो प्रारूप में सहेजा गया है। अधिक जानकारी के लिए, कृपया saved_model प्रारूप की मार्गदर्शिका देखें।

अब मॉडल को लोड करने और इसे tf.distribute.Strategy का उपयोग करके प्रशिक्षित करने के लिए:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860
Epoch 2/2
938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903

जैसा कि आप देख सकते हैं, लोडिंग tf.distribute.Strategy के साथ अपेक्षित रूप से काम करता है। जरूरी नहीं कि यहां इस्तेमाल की जाने वाली रणनीति बचत करने से पहले इस्तेमाल की गई रणनीति ही हो।

tf.saved_model APIs

अब निचले स्तर के एपीआई पर एक नजर डालते हैं। मॉडल को सहेजना केरस एपीआई के समान है:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

लोडिंग tf.saved_model.load() के साथ की जा सकती है। हालाँकि, चूंकि यह एक एपीआई है जो निचले स्तर पर है (और इसलिए इसमें उपयोग के मामलों की एक विस्तृत श्रृंखला है), यह एक केरस मॉडल नहीं लौटाता है। इसके बजाय, यह एक ऐसी वस्तु देता है जिसमें ऐसे कार्य होते हैं जिनका उपयोग अनुमान लगाने के लिए किया जा सकता है। उदाहरण के लिए:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

लोड किए गए ऑब्जेक्ट में कई फ़ंक्शन हो सकते हैं, जिनमें से प्रत्येक एक कुंजी से जुड़ा होता है। सहेजे गए केरस मॉडल के साथ अनुमान फ़ंक्शन के लिए "serving_default" डिफ़ॉल्ट कुंजी है। इस फ़ंक्शन के साथ एक अनुमान लगाने के लिए:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.18789300e-01, -1.78404614e-01,  4.92432676e-02,
        -9.37875658e-02,  1.14302970e-01, -8.99422392e-02,
         9.47709680e-02, -7.75382966e-02,  4.04430032e-02,
         2.41404288e-02],
       [-2.35370561e-01, -3.39397341e-02,  2.73427293e-02,
        -1.08200148e-01,  5.10682352e-02,  1.36142194e-01,
         9.28785652e-02, -5.35808355e-02,  2.56292164e-01,
         1.05301209e-01],
       [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02,
        -1.99329913e-01, -7.45072216e-02,  2.42738128e-02,
         2.07733169e-01, -3.15396488e-03,  4.95976806e-02,
         2.14848563e-01],
       [-9.82482210e-02, -6.13910556e-02,  1.00815810e-01,
        -1.87558904e-01,  1.14685424e-01,  1.53835595e-01,
         1.85714245e-01, -8.74890238e-02,  1.07493028e-01,
         1.57510787e-02],
       [-8.56257528e-02,  3.23683321e-02, -3.66768315e-02,
        -1.47201523e-01, -5.31517603e-02,  1.52744055e-02,
         1.69184029e-01, -5.42814359e-02,  1.11524366e-01,
         5.65215349e-02],
       [-1.50604844e-01, -7.87255913e-03,  1.26651973e-01,
        -1.24476865e-01,  6.94983900e-02,  4.27672639e-03,
         1.86136231e-01, -4.54714149e-03,  9.12746191e-02,
         6.12779632e-02],
       [-2.79157639e-01, -4.61089313e-02,  2.51544192e-02,
        -1.79003477e-01,  3.83432880e-02,  2.05054253e-01,
        -8.25636461e-03, -8.25546682e-03,  2.41342247e-01,
         8.24805871e-02],
       [-1.42795354e-01,  6.54597580e-02,  2.05058958e-02,
        -1.28471941e-01,  1.10977650e-01,  4.51317504e-02,
         2.44124904e-01,  1.90523565e-02,  3.11958641e-02,
         6.49511665e-02],
       [-1.33037239e-01, -2.72594951e-02,  8.09026062e-02,
        -1.95883229e-01,  1.84634060e-01,  1.00822970e-01,
         4.40884084e-02, -6.43826872e-02,  1.47807434e-01,
        -1.92791894e-02],
       [-1.43770471e-01, -2.53150351e-02,  4.18904647e-02,
        -1.02573663e-01,  6.15917407e-02,  7.95702711e-02,
         9.27314460e-02, -4.31537181e-02,  4.59018350e-02,
         1.02965936e-01],
       [-1.90395206e-01,  2.93233991e-03,  1.48900077e-02,
        -1.15877971e-01,  1.06598288e-02,  1.40121073e-01,
         6.86443001e-02, -4.61921766e-02,  1.27470195e-01,
         6.73005953e-02],
       [-2.60747373e-01, -1.45188004e-01,  7.10044056e-04,
        -1.04602516e-01,  5.00324890e-02,  2.96664417e-01,
         8.57191086e-02,  6.65097907e-02,  1.31302923e-01,
        -1.84605196e-02],
       [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01,
        -1.34576231e-01, -8.19503814e-02,  1.30840242e-02,
         6.16783127e-02, -3.64837795e-02,  3.18005830e-02,
         1.98420882e-01],
       [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02,
        -6.30265176e-02,  8.36028308e-02,  2.96559408e-02,
         2.19864860e-01, -7.08417147e-02,  4.76131588e-02,
         1.15781695e-01],
       [-1.55139655e-01, -1.27863720e-01,  9.67459157e-02,
        -1.48635745e-01,  1.25129193e-01,  4.04443927e-02,
         2.94884086e-01, -7.66484886e-02,  1.18753463e-01,
         2.93397382e-02],
       [-1.59221828e-01, -9.30457860e-02,  9.18259323e-02,
        -1.72857821e-01,  8.09611157e-02,  1.11391053e-01,
         1.66679412e-01,  3.52456123e-02,  9.05358568e-02,
         9.89414975e-02],
       [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02,
        -9.73629057e-02,  1.36456266e-01,  1.30628154e-01,
         1.53577864e-01, -6.73157908e-03,  9.31103677e-02,
         1.50734074e-02],
       [-1.29348308e-01, -3.03804129e-03,  2.82487050e-02,
        -2.02886015e-01,  7.09105879e-02,  1.74542382e-01,
         2.57992335e-02, -1.63579211e-02,  2.30892301e-02,
         6.69767857e-02],
       [-1.56857669e-01,  5.46110943e-02, -5.93251809e-02,
        -1.04585059e-01,  2.61763521e-02,  1.43062070e-01,
         1.57771498e-01, -6.19823262e-02,  3.59585434e-02,
         6.62322640e-02],
       [-8.64257440e-02, -1.33483298e-03,  7.46414512e-02,
        -1.82848468e-01,  1.21074423e-01,  1.55276239e-01,
         1.46483868e-01, -6.22515939e-03,  1.91641584e-01,
        -9.95825827e-02],
       [-2.52117336e-01, -6.92471862e-02,  1.09911412e-01,
        -3.73112522e-02,  3.76211852e-03,  5.23591004e-02,
         9.16506499e-02,  6.80204183e-02, -4.27842364e-02,
         7.91264027e-02],
       [-2.11018056e-01,  5.97522780e-03,  8.47486481e-02,
        -7.27925971e-02,  9.36664082e-03,  1.62506998e-01,
         5.32426499e-02,  1.78599171e-02, -2.30420940e-02,
         4.07365486e-02],
       [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02,
        -1.64699793e-01,  8.35808069e-02,  7.68100768e-02,
        -7.14773983e-02, -3.43702435e-02,  9.47649628e-02,
         9.36352089e-02],
       [-1.20486066e-01,  3.77080180e-02,  1.14158325e-01,
        -6.50681928e-02,  1.03382617e-02,  1.17891498e-01,
         1.13154747e-01, -1.49052702e-02,  1.28893867e-01,
         1.12219512e-01],
       [-2.23867983e-01, -9.79400948e-02,  7.37103820e-02,
        -1.05197895e-02,  3.75595838e-02,  1.80490598e-01,
         6.83145374e-02, -3.09509300e-02,  1.42565176e-01,
         8.05927664e-02],
       [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02,
        -1.75458089e-01,  1.46448284e-01,  1.80426955e-01,
         1.52164772e-01, -2.57370695e-02,  1.26812875e-01,
         1.22049123e-01],
       [-9.45013613e-02,  5.85526973e-02,  1.47456676e-02,
        -4.40606587e-02,  4.86647561e-02,  6.28624633e-02,
         3.69989276e-02, -3.68277319e-02,  3.56127135e-02,
         3.10502797e-02],
       [-1.02712311e-01,  3.16979140e-02,  1.88253060e-01,
        -5.99608906e-02,  3.73450294e-02,  6.38176724e-02,
         1.12240583e-01,  2.42183693e-02,  1.45670772e-02,
        -9.52028483e-03],
       [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02,
        -1.01807326e-01, -7.93362781e-03, -7.22003728e-02,
         1.49934232e-01, -1.19943202e-01,  9.22369361e-02,
         1.46321565e-01],
       [-1.32534593e-01,  1.18380897e-02,  2.23980099e-03,
        -9.28303748e-02, -2.20538303e-02,  7.68908709e-02,
         5.29715866e-02, -3.43324393e-02, -1.27909705e-02,
        -7.04141408e-02],
       [-8.10261145e-02, -8.95578321e-03,  3.96864787e-02,
        -1.21861629e-01,  7.98310041e-02,  1.56087667e-01,
         9.11872089e-02, -2.29295418e-02,  5.64432219e-02,
        -3.55931222e-02],
       [-1.76416740e-01,  1.12043694e-02, -1.80068091e-02,
        -1.88012689e-01,  8.68914276e-02,  1.57958359e-01,
         5.77907935e-02, -2.12088451e-02,  5.33877537e-02,
         2.19271183e-02],
       [-2.70012528e-01, -1.26611829e-01,  3.10387388e-02,
        -7.24840909e-02,  1.03253610e-01,  8.91268626e-02,
         1.38662308e-01, -6.25240132e-02,  2.36210316e-01,
         1.40534222e-01],
       [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02,
        -2.01282576e-02,  5.43357767e-02,  7.14191943e-02,
         3.46604213e-02, -6.00920171e-02,  5.11362031e-02,
         3.58160883e-02],
       [-1.63262367e-01,  2.44849995e-02,  3.81964818e-02,
        -3.93010303e-02,  3.95263731e-03,  9.11088511e-02,
         3.88236046e-02,  1.33745335e-02,  1.00076631e-01,
         6.05135933e-02],
       [-3.01809371e-01, -1.58440098e-01,  4.65333983e-02,
        -1.63946241e-01, -6.42775744e-02,  3.93286347e-04,
         2.82839835e-01, -8.93663988e-02,  1.97781295e-01,
         2.87044942e-01],
       [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03,
        -1.01776704e-01,  1.43144801e-02,  1.82002857e-02,
         2.76539754e-02, -1.94141679e-02,  8.87098238e-02,
         6.60644472e-02],
       [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02,
        -4.82957587e-02,  1.28858402e-01,  1.30042464e-01,
         1.32807568e-01, -7.52742141e-02,  9.51702446e-02,
         3.10119465e-02],
       [-1.09407350e-01, -5.27948700e-03,  1.29588693e-03,
        -2.61662379e-02,  3.01920641e-02,  1.13487415e-01,
         8.23267922e-02,  1.92574020e-02,  2.31986474e-02,
         4.13139611e-02],
       [-2.12277412e-01, -1.35507256e-01,  4.22930568e-02,
        -1.34565741e-01,  1.17879853e-01,  1.30573064e-01,
         1.81054786e-01, -1.70722306e-01,  1.05854876e-01,
         7.36362934e-02],
       [-1.78249478e-01, -7.55607188e-02,  7.75147527e-02,
        -2.14659080e-01,  3.26948166e-02,  7.76198730e-02,
         1.08791113e-01, -2.38809325e-02,  1.79410487e-01,
         1.94452941e-01],
       [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02,
        -1.40473023e-02,  3.60646360e-02, -9.39090401e-02,
         1.83859855e-01, -1.09493822e-01, -3.09051797e-02,
         1.36017531e-01],
       [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02,
        -9.68495384e-02,  2.35293470e-02,  2.53665410e-02,
         1.79999322e-01, -7.10204691e-02, -7.29817525e-02,
         4.50368747e-02],
       [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03,
        -1.03088826e-01, -7.38603100e-02,  1.84892826e-02,
         9.76646394e-02, -3.29037756e-02, -1.77134499e-02,
         1.62288889e-01],
       [-6.78652674e-02, -1.08500615e-01,  5.66991530e-02,
        -9.52370912e-02,  5.28126955e-02,  1.05176866e-02,
         1.73085481e-01, -1.37753151e-02,  1.95556954e-02,
         1.38068855e-01],
       [-2.02808753e-01, -3.39423120e-02,  1.82233751e-03,
        -5.71424365e-02,  3.40205729e-02,  8.74454305e-02,
         8.47227685e-03, -2.52498202e-02,  4.66104299e-02,
         1.10718749e-01],
       [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01,
        -9.72513855e-02, -3.58061343e-02,  3.04423086e-02,
         5.70362583e-02, -4.03833576e-02, -4.28436548e-02,
         9.73245874e-02],
       [-2.06081957e-01, -1.71493232e-01,  2.52560824e-02,
        -1.55212343e-01, -4.33478206e-02,  2.34177694e-01,
         8.46128762e-02,  1.75322518e-02,  2.04347119e-01,
         1.54971585e-01],
       [-1.95310384e-01,  1.30968075e-02, -9.68117267e-03,
        -7.31432810e-02,  1.02618083e-01,  1.59629256e-01,
         1.66028887e-01, -7.12903216e-03,  1.78021699e-01,
        -2.17130631e-02],
       [-1.59163624e-01, -1.77137554e-05,  1.75410658e-02,
        -9.08103511e-02,  7.25786015e-02,  9.21041369e-02,
         1.24915361e-01, -6.55939505e-02, -1.13440230e-02,
         1.03661232e-01],
       [-1.93366870e-01, -4.36344892e-02,  1.37750164e-01,
        -1.91939399e-01, -1.50268525e-03,  8.03942382e-02,
         2.15812266e-01,  5.38492575e-02,  1.36685073e-01,
         2.22119391e-01],
       [-1.65946245e-01,  7.89588690e-03, -1.65037125e-01,
        -1.23690292e-01, -8.57629776e-02, -2.55736727e-02,
         1.67541012e-01, -6.63827211e-02,  2.98694819e-02,
         1.71927184e-01],
       [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02,
        -2.98387632e-02,  2.80477256e-02,  4.94132042e-02,
         4.89805043e-02,  1.96998678e-02, -4.14144360e-02,
        -5.05549274e-02],
       [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02,
        -3.78398523e-02,  7.60737807e-03, -2.70657167e-02,
         1.11277811e-01,  6.37479573e-02, -2.39458829e-02,
         1.22067556e-01],
       [-1.92323536e-01, -1.43002480e-01,  5.29062748e-03,
        -1.70663983e-01,  8.39572400e-03,  6.37906119e-02,
         1.24084033e-01,  6.02792688e-02,  7.18353763e-02,
         5.03963791e-03],
       [-1.70977920e-01,  1.04207098e-02,  1.18544906e-01,
        -4.29532528e-02, -3.53983864e-02,  1.80302024e-01,
         8.08775946e-02,  3.19045782e-02,  2.52931342e-02,
         1.29424319e-01],
       [-2.13301033e-01, -6.96119964e-02,  2.32847631e-02,
        -7.73920864e-02,  1.10387571e-01,  1.13307782e-01,
         1.41805351e-01, -5.19381016e-02,  1.15313083e-01,
         1.40049949e-01],
       [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03,
        -1.04376137e-01,  7.78115019e-02,  6.84583709e-02,
         2.51923770e-01, -1.05199262e-01,  1.64517179e-01,
         2.18875334e-01],
       [-2.60777414e-01, -8.93031508e-02,  1.27723843e-01,
        -1.97950065e-01,  1.19145498e-01,  7.30907321e-02,
         2.23771721e-01, -6.83849230e-02,  3.68930906e-01,
         1.86811388e-01],
       [-2.38028213e-01,  1.11199915e-03,  2.25015372e-01,
         8.22724327e-02, -1.14511400e-01,  1.57513067e-01,
         5.22858277e-02,  2.13724375e-03,  3.15639377e-02,
         2.08704025e-01],
       [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02,
        -1.44550815e-01,  2.09794566e-02,  1.47883072e-02,
         3.96856442e-02, -2.15019658e-03, -4.90810722e-02,
         1.34708211e-01],
       [-2.02591017e-01, -2.29728431e-01,  6.73423260e-02,
        -1.24901496e-01, -1.38434023e-02,  8.64367038e-02,
         1.22342721e-01,  1.67826824e-02,  1.65354639e-01,
         1.83434993e-01],
       [-2.25799978e-01, -1.02682747e-01,  9.48531851e-02,
        -9.38871950e-02,  1.03806734e-01,  2.04695478e-01,
         8.09893832e-02, -1.45416632e-02,  1.33486420e-01,
        -6.27665371e-02],
       [-1.19375348e-01,  2.23235339e-02,  1.04302749e-01,
        -1.11149743e-01,  6.12434298e-02,  6.89433664e-02,
         2.08741099e-01, -3.81497070e-02, -1.42122135e-02,
         7.65201449e-03]], dtype=float32)>}
2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

आप वितरित तरीके से अनुमान भी लोड और कर सकते हैं:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

पुनर्स्थापित फ़ंक्शन को कॉल करना सहेजे गए मॉडल (भविष्यवाणी) पर केवल एक आगे का पास है। क्या होगा यदि आप लोड किए गए फ़ंक्शन का प्रशिक्षण जारी रखना चाहते हैं? या लोड किए गए फ़ंक्शन को एक बड़े मॉडल में एम्बेड करें? इसे प्राप्त करने के लिए इस भरी हुई वस्तु को केरस परत में लपेटना एक आम बात है। सौभाग्य से, TF हब में इस उद्देश्य के लिए हब.केरसलेयर है, जो यहाँ दिखाया गया है:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442
Epoch 2/2
938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813

जैसा कि आप देख सकते हैं, hub.KerasLayer tf.saved_model.load() से वापस लोड किए गए परिणाम को एक Keras परत में लपेटता है जिसका उपयोग किसी अन्य मॉडल के निर्माण के लिए किया जा सकता है। यह ट्रांसफर लर्निंग के लिए बहुत उपयोगी है।

मुझे किस एपीआई का उपयोग करना चाहिए?

बचत के लिए, यदि आप केरस मॉडल के साथ काम कर रहे हैं, तो लगभग हमेशा केरस के मॉडल.सेव model.save() एपीआई का उपयोग करने की सिफारिश की जाती है। यदि आप जो सहेज रहे हैं वह केरस मॉडल नहीं है, तो निचले स्तर का एपीआई आपकी एकमात्र पसंद है।

लोडिंग के लिए, आप किस एपीआई का उपयोग करते हैं यह इस बात पर निर्भर करता है कि आप लोडिंग एपीआई से क्या प्राप्त करना चाहते हैं। यदि आप केरस मॉडल प्राप्त नहीं कर सकते (या नहीं करना चाहते), तो tf.saved_model.load() का उपयोग करें। अन्यथा, tf.keras.models.load_model() का उपयोग करें। ध्यान दें कि आप केरस मॉडल को तभी वापस पा सकते हैं जब आपने केरस मॉडल को सहेजा हो।

एपीआई को मिक्स एंड मैच करना संभव है। आप एक केरस मॉडल को model.save के साथ सहेज सकते हैं, और निम्न-स्तरीय एपीआई, tf.saved_model.load के साथ एक गैर-केरस मॉडल लोड कर सकते हैं।

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
प्लेसहोल्डर22

स्थानीय डिवाइस से सहेजना/लोड करना

दूरस्थ रूप से चलते समय स्थानीय आईओ डिवाइस से सहेजते और लोड करते समय, उदाहरण के लिए क्लाउड टीपीयू का उपयोग करते हुए, आईओ डिवाइस को लोकलहोस्ट पर सेट करने के लिए विकल्प experimental_io_device का उपयोग किया जाना चाहिए।

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

चेतावनियां

एक विशेष मामला तब होता है जब आपके पास केरस मॉडल होता है जिसमें अच्छी तरह से परिभाषित इनपुट नहीं होते हैं। उदाहरण के लिए, एक अनुक्रमिक मॉडल बिना किसी इनपुट आकार के बनाया जा सकता है ( Sequential([Dense(3), ...] )। उप-वर्गीकृत मॉडल में आरंभीकरण के बाद भी अच्छी तरह से परिभाषित इनपुट नहीं होते हैं। इस मामले में, आपको साथ रहना चाहिए बचत और लोडिंग दोनों पर निचले स्तर के एपीआई, अन्यथा आपको एक त्रुटि मिलेगी।

यह जांचने के लिए कि क्या आपके मॉडल में अच्छी तरह से परिभाषित इनपुट हैं, बस जांचें कि क्या model.inputs None है। यदि यह None है, तो आप सभी अच्छे हैं। इनपुट आकार स्वचालित रूप से परिभाषित होते हैं जब मॉडल का उपयोग .fit , .evaluate , .predict , या मॉडल ( model(inputs) ) को कॉल करते समय किया जाता है।

यहाँ एक उदाहरण है:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
प्लेसहोल्डर26