Save and load a model using a distribution strategy

Overview

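It's common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API (tf.keras.Model.save and tf.keras.models.load_model) and a low-level API (tf.saved_model.save and tf.saved_model.load). This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example: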

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
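Two things happen implicitly in the cell above: the global batch size is the per-replica batch size multiplied by num_replicas_in_sync, and any variable created inside strategy.scope() is managed by the strategy. The following minimal sketch checks both on whatever devices are visible; the variable v is purely illustrative.

```python
import tensorflow as tf

# MirroredStrategy uses all visible GPUs, or falls back to the CPU if there are none.
strategy = tf.distribute.MirroredStrategy()

# Each replica processes BATCH_SIZE_PER_REPLICA examples, so the global batch
# size grows with the number of replicas.
BATCH_SIZE_PER_REPLICA = 64
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Variables created inside the scope are created (and mirrored) by the strategy.
with strategy.scope():
    v = tf.Variable(1.0)

print(global_batch_size, strategy.extended.variable_created_in_scope(v))
```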

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2021-10-26 01:26:36.109959: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
938/938 [==============================] - 13s 3ms/step - loss: 0.2015 - sparse_categorical_accuracy: 0.9410
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7fa92037bc90>

Save and load the model

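Now that you have a simple model to work with, let's look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level tf.saved_model APIs (tf.saved_model.save and tf.saved_model.load).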

The Keras APIs

Here's an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2021-10-26 01:26:52.520058: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0356 - sparse_categorical_accuracy: 0.9890
<keras.callbacks.History at 0x7fa8dc6d6690>

After restoring the model, you can continue training it without even calling compile() again, because it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the saved_model format.
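A SavedModel is a directory rather than a single file. To see its layout without depending on a particular Keras version, this sketch saves a tiny tf.Module (the class Scale is a hypothetical stand-in, not part of the tutorial) and lists what was written. Expect at least saved_model.pb (the serialized graph and signatures) and a variables/ directory (the weights); other entries such as assets/ may or may not appear depending on your TensorFlow version.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(Scale(), export_dir)

# saved_model.pb holds the graph; variables/ holds the checkpointed weights.
contents = sorted(os.listdir(export_dir))
print(contents)
```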

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
2021-10-26 01:26:57.965185: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2021-10-26 01:26:58.004038: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9846
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0345 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.
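The same independence holds for the low-level APIs. In this sketch the variables are created under a MirroredStrategy and restored under a OneDeviceStrategy pinned to the CPU. Scale is again a hypothetical stand-in for a real model, chosen so the example does not depend on a specific Keras version.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


# Create the variables under one strategy...
with tf.distribute.MirroredStrategy().scope():
    module = Scale()

export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(module, export_dir)

# ...and restore them under a different one.
with tf.distribute.OneDeviceStrategy("/cpu:0").scope():
    restored = tf.saved_model.load(export_dir)

result = restored(tf.constant([1.0, 2.0]))
print(result.numpy())  # [2. 4.]
```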

tf.saved_model APIs

Now let's look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it's a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.11688858, -0.05038287, -0.2585946 ,  0.04893515,  0.27253783,
         0.1022947 , -0.06840641, -0.33529347, -0.07071295,  0.06517357],
       [ 0.10904782, -0.23611397, -0.16135186,  0.10045648,  0.26082516,
        -0.02260189,  0.0424989 , -0.09468129,  0.05540806,  0.10558474],
       [-0.0491788 , -0.04070761, -0.23004392,  0.17719601,  0.20461476,
        -0.05333536, -0.02240408, -0.21509385, -0.05161493,  0.12337525],
       [ 0.00487803, -0.05770147, -0.23551641,  0.05988425,  0.15881103,
        -0.05608599, -0.04135028, -0.3390705 , -0.07579579, -0.08983649],
       [-0.04663972, -0.13439807, -0.19048163,  0.13628994,  0.05608338,
        -0.06012772, -0.03063064, -0.32014394, -0.16421723,  0.08930477],
       [ 0.02328245,  0.05272574, -0.34110764,  0.12926938,  0.33982378,
         0.12486804, -0.04870659, -0.45755434, -0.05433567,  0.14137071],
       [ 0.06421333, -0.20211999, -0.14309192,  0.00360708,  0.23210834,
         0.00101324, -0.01692696, -0.15713055,  0.00623474, -0.02222142],
       [ 0.08059486,  0.0456247 , -0.15926833,  0.05546484,  0.09179395,
         0.06136999, -0.07209414, -0.2553306 , -0.04975087,  0.06797761],
       [ 0.05864911, -0.10561213, -0.23619679,  0.11069187,  0.13890924,
         0.04969782, -0.05587994, -0.26131746, -0.0363602 ,  0.02788973],
       [ 0.0296779 ,  0.06670297, -0.12159262,  0.06834705,  0.19103828,
         0.14597046,  0.00285575, -0.19362533, -0.06905006,  0.097047  ],
       [ 0.05100356, -0.03875454, -0.31727186,  0.01787528,  0.20725562,
        -0.01677462, -0.00129463, -0.17944467,  0.05812614,  0.04979762],
       [-0.03301986, -0.10880841, -0.21802825,  0.0578297 ,  0.41345048,
         0.10376748,  0.03452782, -0.27389282, -0.06923576,  0.14353925],
       [-0.02203556, -0.08816119, -0.15965816,  0.07572726,  0.018046  ,
        -0.10299203,  0.01126328, -0.21401492, -0.17861444,  0.05669294],
       [-0.0245089 , -0.03849422, -0.2968499 ,  0.23396973,  0.22189453,
         0.00512835, -0.00468208, -0.29407114, -0.14926936, -0.02818882],
       [-0.02376807, -0.05931192, -0.31774518,  0.15711312,  0.31248903,
        -0.04320139, -0.08301807, -0.4610513 , -0.10252888, -0.03784092],
       [-0.03953424, -0.08268867, -0.3604463 ,  0.14048189,  0.33057037,
         0.01373108, -0.12093162, -0.38173944,  0.01771745, -0.07451382],
       [-0.05658644,  0.0519563 , -0.20794927,  0.10203589,  0.2135886 ,
         0.14241108, -0.04007911, -0.26177728, -0.08082938,  0.00216334],
       [-0.06207625, -0.01838757, -0.21708131,  0.10756977,  0.25599915,
         0.03101911,  0.05593228, -0.25550944, -0.11642678,  0.09014311],
       [ 0.05197014,  0.03873106, -0.1469059 ,  0.08044868,  0.12293777,
        -0.00388163,  0.00324975, -0.08145286, -0.12639561, -0.03596487],
       [-0.10676757, -0.05767517, -0.20481907,  0.14739943,  0.17379019,
        -0.08260865, -0.09114882, -0.38688654, -0.1448748 ,  0.03397277],
       [-0.03770879,  0.04663504, -0.30894646,  0.05933709,  0.09536786,
         0.1006383 ,  0.00984312, -0.3204393 , -0.01170056, -0.03391666],
       [ 0.0231554 ,  0.12106506, -0.255493  ,  0.04387057,  0.12491666,
         0.03297757, -0.03934925, -0.17047551,  0.00603533,  0.02295396],
       [-0.0137163 , -0.08226999, -0.3219023 ,  0.1111999 ,  0.15005693,
        -0.10358538, -0.04351711, -0.24015021, -0.08079101,  0.01281704],
       [ 0.08698535, -0.17155564, -0.19832517, -0.0417797 ,  0.24460419,
        -0.00698967,  0.08663791, -0.20004068,  0.02847612,  0.12739052],
       [ 0.0248102 , -0.07629397, -0.10130948,  0.00225735,  0.14270194,
         0.01750292,  0.03144339, -0.1429488 , -0.02819812,  0.24307509],
       [-0.06557162, -0.06485987, -0.36512223,  0.18774748,  0.25643086,
         0.0340823 , -0.01398754, -0.19010906, -0.07261477,  0.05117159],
       [ 0.04187369,  0.0132397 , -0.16233045,  0.10300563,  0.06598518,
         0.05728842, -0.02450454, -0.22889516, -0.03530695,  0.08300389],
       [ 0.15359762, -0.06493542, -0.22839671,  0.05915322,  0.26544052,
         0.15312935, -0.05132065, -0.34682024, -0.0181414 ,  0.08866596],
       [-0.06705338, -0.05590982, -0.21037713,  0.05252159,  0.22411834,
         0.06072947, -0.01180699, -0.31283215, -0.06644081, -0.02687445],
       [-0.01673558, -0.04322004, -0.22221681,  0.11640421,  0.27585298,
        -0.00789917, -0.03705985, -0.12847525, -0.14132528, -0.01258589],
       [ 0.05363014, -0.11879475, -0.08204994,  0.16474688,  0.09248446,
        -0.09719495, -0.07723137, -0.23136492, -0.05618468,  0.10164495],
       [-0.02539362, -0.14454898, -0.32296312,  0.2053542 ,  0.18563472,
        -0.0445538 , -0.13633929, -0.12712947, -0.06732591,  0.05459897],
       [-0.02403368, -0.09293792, -0.22012895,  0.09356467,  0.3415923 ,
        -0.09844425, -0.04539915, -0.28688133, -0.14435257,  0.05483858],
       [ 0.03492264,  0.04167182, -0.08564096,  0.01466741,  0.14968738,
         0.01946784, -0.04962645, -0.09357765, -0.03180797,  0.03431095],
       [ 0.04553585, -0.06386177, -0.159064  ,  0.09195592,  0.20032357,
         0.05248308,  0.05274323, -0.09328806, -0.02849531,  0.10636853],
       [-0.08788846, -0.05706687, -0.27519208,  0.12941426,  0.1730625 ,
         0.00562337,  0.03862702, -0.3364083 ,  0.01087172,  0.03377784],
       [-0.08110045, -0.06666276, -0.34764278,  0.25369477,  0.26242447,
         0.03672977,  0.07488421, -0.11382174,  0.03446682,  0.20799701],
       [-0.02429771, -0.0130821 , -0.28549588,  0.09956603,  0.19093114,
         0.09172641, -0.01084431, -0.26826024, -0.09550276, -0.09001306],
       [-0.0405377 ,  0.02302578, -0.16092977,  0.12650998,  0.10584372,
         0.0598565 ,  0.0370068 , -0.13375495, -0.05769489,  0.04597083],
       [-0.08379065, -0.12666067, -0.23740488,  0.08539408,  0.19100066,
        -0.19001569, -0.03504099, -0.2954648 , -0.00778607, -0.10035929],
       [-0.06841633, -0.02935523, -0.27325606,  0.07019119,  0.13153824,
         0.03444952, -0.07040955, -0.16061744, -0.05776489, -0.02386798],
       [ 0.02282005, -0.03760834, -0.17803052,  0.09008945,  0.15709753,
        -0.02815568, -0.01385967, -0.2636196 , -0.06011615, -0.04417434],
       [ 0.05103182, -0.0073192 , -0.2492007 ,  0.09097242,  0.2589297 ,
         0.03582668, -0.05287637, -0.1023304 , -0.10472505, -0.02360192],
       [-0.04446318, -0.00104156, -0.22680247,  0.0975772 ,  0.25874364,
         0.07281871,  0.14879908, -0.21233654, -0.11104408,  0.1596871 ],
       [-0.16542982, -0.02617702, -0.2530758 ,  0.09354755,  0.19404459,
         0.0228528 , -0.03458656, -0.3274249 , -0.08492248,  0.07104953],
       [-0.04432368, -0.01551367, -0.30958706,  0.08279304,  0.15877493,
         0.14097705,  0.0056034 , -0.2121813 , -0.10417398,  0.13372038],
       [ 0.00872401,  0.02290398, -0.18306321,  0.11926699,  0.0969364 ,
        -0.04007095,  0.01660407, -0.28434896, -0.15929542,  0.01083255],
       [ 0.07433248, -0.14991361, -0.2220522 ,  0.00625274,  0.39078072,
         0.03646233,  0.10941336, -0.20384778, -0.02929106,  0.03544597],
       [-0.00069001, -0.0680518 , -0.11302898,  0.11793397,  0.11893341,
        -0.05947986, -0.02543334, -0.24527295, -0.09240474, -0.00762735],
       [ 0.01683525,  0.03738175, -0.18935157,  0.07978748,  0.23876491,
         0.15589894, -0.00638897, -0.25770593, -0.11232982, -0.0446422 ],
       [-0.01690136, -0.19515185, -0.2338915 , -0.00964288,  0.17318843,
        -0.02175554,  0.07482283, -0.19234088, -0.0229656 ,  0.11406161],
       [-0.00661898,  0.00870193, -0.11167589,  0.15103012,  0.06432639,
        -0.12180559,  0.04999296, -0.2667799 , -0.17659347, -0.04285187],
       [-0.01717829,  0.02375691, -0.14970137,  0.1191919 ,  0.10172842,
        -0.07352136,  0.02696884, -0.11598936, -0.1331213 , -0.00928868],
       [-0.05850236,  0.03356444, -0.24372646,  0.14034908,  0.22228894,
         0.04799255, -0.01023421, -0.23915118, -0.07773915,  0.01665494],
       [-0.04828071, -0.00198432, -0.21945187,  0.14940068,  0.26243302,
         0.04732714, -0.03919668, -0.3767312 , -0.04807761,  0.04837478],
       [ 0.08090632,  0.02816604, -0.31061617,  0.04813545,  0.17886776,
         0.10947818,  0.0324835 , -0.22861008, -0.01619428, -0.00963937],
       [ 0.01237603, -0.07633115, -0.20681188,  0.08626392,  0.16251579,
         0.05692254,  0.00641025, -0.027444  ,  0.05301347,  0.00296039],
       [-0.03114549, -0.03946134, -0.20575103,  0.158873  ,  0.19106835,
        -0.00628418, -0.06812906, -0.29752672, -0.12863883,  0.00519179],
       [-0.02839492,  0.00197193, -0.38123846,  0.12928526,  0.4360217 ,
         0.06745887, -0.01924693, -0.3610945 ,  0.02880143,  0.00938179],
       [-0.10277586,  0.01430387, -0.24793717, -0.02120358,  0.20257095,
         0.10856566,  0.08017994, -0.21743834,  0.02736677,  0.01270235],
       [ 0.00209297, -0.04658009, -0.10872659,  0.00873713,  0.12002683,
        -0.01763269,  0.00062436, -0.07574805,  0.00423002,  0.09696378],
       [-0.0030484 ,  0.00373926, -0.20884912,  0.03331832,  0.37477142,
         0.14008212,  0.031428  , -0.40348598, -0.02555457,  0.05203115],
       [ 0.06917666, -0.07515088, -0.15344585,  0.08451273,  0.16555418,
        -0.00663652, -0.03506049, -0.19360425, -0.01485892, -0.1411201 ],
       [ 0.08957651, -0.0336723 , -0.16066113,  0.09386282,  0.21388392,
        -0.01653587, -0.02893457, -0.04395334, -0.03723653,  0.07710503]],
      dtype=float32)>}
2021-10-26 01:27:16.715879: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
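The output above is a dictionary keyed by output name (for the Keras model, the name of its last layer, dense_3). For a non-Keras object you choose the signature keys yourself when saving. This sketch, again with the hypothetical Scale module, exports its __call__ under the "serving_default" key, then inspects and calls the signature. Signature functions accept keyword arguments named after the traced function's parameters (here x) and return a dict of outputs whose key TensorFlow generates for an unnamed single output.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


module = Scale()
export_dir = os.path.join(tempfile.mkdtemp(), "scale")
# Register __call__ under the conventional "serving_default" signature key.
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.__call__})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))  # ['serving_default']

serving_fn = loaded.signatures["serving_default"]
# Call with a keyword argument; the result is a dict of named output tensors.
outputs = serving_fn(x=tf.constant([3.0]))
print(outputs)
```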

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:16.888897: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
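The warning above recommends wrapping strategy.run inside a tf.function. This minimal sketch does that for distributed inference. It uses the hypothetical Scale module instead of the Keras model so that it is self-contained, and strategy.experimental_local_results to join the per-replica outputs back into a single batch.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(Scale(), export_dir)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    restored = tf.saved_model.load(export_dir)

# Eight float inputs in global batches of four.
dataset = tf.data.Dataset.range(8).map(lambda i: tf.cast(i, tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def distributed_predict(batch):
    # Run the restored function once per replica, inside a graph.
    per_replica = strategy.run(restored, args=(batch,))
    # Unpack this worker's per-replica results and join them along the batch axis.
    return tf.concat(strategy.experimental_local_results(per_replica), axis=0)


predictions = tf.concat([distributed_predict(b) for b in dist_dataset], axis=0)
print(predictions.numpy())
```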

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:18.637232: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.2057 - sparse_categorical_accuracy: 0.9392
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9802

As you can see, hub.KerasLayer wraps the object returned by tf.saved_model.load() in a Keras layer that can be used to build another model. This is very useful for transfer learning.
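If you'd rather not depend on TF Hub, a different technique also works: the restored variables are ordinary trainable variables, so you can backpropagate through the loaded function with tf.GradientTape and update them directly. The sketch below performs one plain gradient-descent step on the hypothetical Scale module toward the target f(2.0) == 10.0; it is an illustration of the idea, not a replacement for hub.KerasLayer's Keras integration.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(Scale(), export_dir)
loaded = tf.saved_model.load(export_dir)

learning_rate = 0.01
with tf.GradientTape() as tape:
    # Squared error between the loaded function's output and the target value.
    loss = tf.reduce_sum((10.0 - loaded(tf.constant([2.0]))) ** 2)

# The restored weight is a trainable variable, so gradients flow through
# the imported function to it.
grad = tape.gradient(loss, loaded.w)
loaded.w.assign_sub(learning_rate * grad)
print(loss.numpy(), grad.numpy(), loaded.w.numpy())
```

With w = 2.0 the loss is (10 - 4)^2 = 36, the gradient is -2 * 6 * 2 = -24, and one step moves w to 2.24.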

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use Keras's model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving to and loading from a local IO device while running remotely, for example when using a Cloud TPU, you must use the experimental_io_device option to set the IO device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
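The same options are accepted by the low-level APIs: tf.saved_model.save and tf.saved_model.load both take an options argument. This sketch mirrors the Keras-based snippet above with the hypothetical Scale module; run locally, '/job:localhost' simply refers to the current process, so it behaves like a normal save and load.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), "scale")

# Route file writes through the local host's IO device.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Scale(), export_dir, options=save_options)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Route file reads through the local host's IO device as well.
    load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
    restored = tf.saved_model.load(export_dir, options=load_options)

result = restored(tf.constant([1.0, 2.0]))
print(result.numpy())  # [2. 4.]
```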

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all good. Input shapes are automatically defined when the model is used in .fit, .evaluate, or .predict, or when you call the model (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa4f68ee5d0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.Dense object at 0x7fa4f68ee490>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
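The "because it is not built" warnings above come from saving the subclassed model before it has ever seen data. A closely related check to model.inputs is the model's built attribute. This sketch, a simplified copy of the SubclassedModel above, shows that calling the model once on a concrete batch (an arbitrary zeros tensor of shape (1, 3) here) builds it and creates its weights. Whether that alone is enough for model.save to succeed may depend on your Keras version.

```python
import tensorflow as tf


class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self._dense_layer = tf.keras.layers.Dense(5, name="output_layer")

    def call(self, inputs):
        return self._dense_layer(inputs)


my_model = SubclassedModel()
built_before = my_model.built  # False: no input shape is known yet

# Calling the model on concrete data fixes its input shape and creates the weights.
_ = my_model(tf.zeros((1, 3)))
built_after = my_model.built  # True

print(built_before, built_after, len(my_model.weights))
```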