Помогают защитить Большой Барьерный Риф с TensorFlow на Kaggle Присоединяйтесь вызов

Сохраните и загрузите модель, используя стратегию распространения

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Обзор

Обычно во время обучения модель сохраняют и загружают. Существует два набора API для сохранения и загрузки модели keras: API высокого уровня и API низкого уровня. В этом учебнике показано , как можно использовать SavedModel API , при использовании tf.distribute.Strategy . Чтобы узнать о SavedModel и сериализации в целом, пожалуйста , прочитайте сохраненную модель руководства , и модель сериализации руководство Keras . Начнем с простого примера:

Зависимости импорта:

import tensorflow_datasets as tfds

import tensorflow as tf

Подготовка данных и моделирования с помощью tf.distribute.Strategy :

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Обучите модель:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2021-10-26 01:26:36.109959: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 13s 3ms/step - loss: 0.2015 - sparse_categorical_accuracy: 0.9410
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7fa92037bc90>

Сохраните и загрузите модель

Теперь, когда у вас есть простая модель для работы, давайте взглянем на API сохранения / загрузки. Доступны два набора API:

API Keras

Вот пример сохранения и загрузки модели с помощью API Keras:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2021-10-26 01:26:52.520058: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Восстановление модели без tf.distribute.Strategy :

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0356 - sparse_categorical_accuracy: 0.9890
<keras.callbacks.History at 0x7fa8dc6d6690>

После восстановления модели, вы можете продолжить обучение на нем, даже без необходимости вызова compile() еще раз, так как он уже скомпилированные перед сохранением. Модель сохраняется в стандарте TensorFlow в SavedModel формате прото. Для получения более подробной информации, пожалуйста , обратитесь к руководству к saved_model формата .

Теперь , чтобы загрузить модель и обучать его с помощью tf.distribute.Strategy :

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
2021-10-26 01:26:57.965185: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2021-10-26 01:26:58.004038: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9846
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0345 - sparse_categorical_accuracy: 0.9898

Как вы можете видеть, погрузочные работы , как ожидается , с tf.distribute.Strategy . Используемая здесь стратегия не обязательно должна совпадать со стратегией, использованной до сохранения.

В tf.saved_model APIs

Теперь давайте посмотрим на API нижнего уровня. Сохранение модели аналогично keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Загрузка может быть сделано с tf.saved_model.load() . Однако, поскольку это API нижнего уровня (и, следовательно, имеет более широкий диапазон вариантов использования), он не возвращает модель Keras. Вместо этого он возвращает объект, содержащий функции, которые можно использовать для вывода. Например:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

Загруженный объект может содержать несколько функций, каждая из которых связана с ключом. "serving_default" является ключом по умолчанию для функции логического вывода с сохраненной моделью Keras. Чтобы сделать вывод с помощью этой функции:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.11688858, -0.05038287, -0.2585946 ,  0.04893515,  0.27253783,
         0.1022947 , -0.06840641, -0.33529347, -0.07071295,  0.06517357],
       [ 0.10904782, -0.23611397, -0.16135186,  0.10045648,  0.26082516,
        -0.02260189,  0.0424989 , -0.09468129,  0.05540806,  0.10558474],
       [-0.0491788 , -0.04070761, -0.23004392,  0.17719601,  0.20461476,
        -0.05333536, -0.02240408, -0.21509385, -0.05161493,  0.12337525],
       [ 0.00487803, -0.05770147, -0.23551641,  0.05988425,  0.15881103,
        -0.05608599, -0.04135028, -0.3390705 , -0.07579579, -0.08983649],
       [-0.04663972, -0.13439807, -0.19048163,  0.13628994,  0.05608338,
        -0.06012772, -0.03063064, -0.32014394, -0.16421723,  0.08930477],
       [ 0.02328245,  0.05272574, -0.34110764,  0.12926938,  0.33982378,
         0.12486804, -0.04870659, -0.45755434, -0.05433567,  0.14137071],
       [ 0.06421333, -0.20211999, -0.14309192,  0.00360708,  0.23210834,
         0.00101324, -0.01692696, -0.15713055,  0.00623474, -0.02222142],
       [ 0.08059486,  0.0456247 , -0.15926833,  0.05546484,  0.09179395,
         0.06136999, -0.07209414, -0.2553306 , -0.04975087,  0.06797761],
       [ 0.05864911, -0.10561213, -0.23619679,  0.11069187,  0.13890924,
         0.04969782, -0.05587994, -0.26131746, -0.0363602 ,  0.02788973],
       [ 0.0296779 ,  0.06670297, -0.12159262,  0.06834705,  0.19103828,
         0.14597046,  0.00285575, -0.19362533, -0.06905006,  0.097047  ],
       [ 0.05100356, -0.03875454, -0.31727186,  0.01787528,  0.20725562,
        -0.01677462, -0.00129463, -0.17944467,  0.05812614,  0.04979762],
       [-0.03301986, -0.10880841, -0.21802825,  0.0578297 ,  0.41345048,
         0.10376748,  0.03452782, -0.27389282, -0.06923576,  0.14353925],
       [-0.02203556, -0.08816119, -0.15965816,  0.07572726,  0.018046  ,
        -0.10299203,  0.01126328, -0.21401492, -0.17861444,  0.05669294],
       [-0.0245089 , -0.03849422, -0.2968499 ,  0.23396973,  0.22189453,
         0.00512835, -0.00468208, -0.29407114, -0.14926936, -0.02818882],
       [-0.02376807, -0.05931192, -0.31774518,  0.15711312,  0.31248903,
        -0.04320139, -0.08301807, -0.4610513 , -0.10252888, -0.03784092],
       [-0.03953424, -0.08268867, -0.3604463 ,  0.14048189,  0.33057037,
         0.01373108, -0.12093162, -0.38173944,  0.01771745, -0.07451382],
       [-0.05658644,  0.0519563 , -0.20794927,  0.10203589,  0.2135886 ,
         0.14241108, -0.04007911, -0.26177728, -0.08082938,  0.00216334],
       [-0.06207625, -0.01838757, -0.21708131,  0.10756977,  0.25599915,
         0.03101911,  0.05593228, -0.25550944, -0.11642678,  0.09014311],
       [ 0.05197014,  0.03873106, -0.1469059 ,  0.08044868,  0.12293777,
        -0.00388163,  0.00324975, -0.08145286, -0.12639561, -0.03596487],
       [-0.10676757, -0.05767517, -0.20481907,  0.14739943,  0.17379019,
        -0.08260865, -0.09114882, -0.38688654, -0.1448748 ,  0.03397277],
       [-0.03770879,  0.04663504, -0.30894646,  0.05933709,  0.09536786,
         0.1006383 ,  0.00984312, -0.3204393 , -0.01170056, -0.03391666],
       [ 0.0231554 ,  0.12106506, -0.255493  ,  0.04387057,  0.12491666,
         0.03297757, -0.03934925, -0.17047551,  0.00603533,  0.02295396],
       [-0.0137163 , -0.08226999, -0.3219023 ,  0.1111999 ,  0.15005693,
        -0.10358538, -0.04351711, -0.24015021, -0.08079101,  0.01281704],
       [ 0.08698535, -0.17155564, -0.19832517, -0.0417797 ,  0.24460419,
        -0.00698967,  0.08663791, -0.20004068,  0.02847612,  0.12739052],
       [ 0.0248102 , -0.07629397, -0.10130948,  0.00225735,  0.14270194,
         0.01750292,  0.03144339, -0.1429488 , -0.02819812,  0.24307509],
       [-0.06557162, -0.06485987, -0.36512223,  0.18774748,  0.25643086,
         0.0340823 , -0.01398754, -0.19010906, -0.07261477,  0.05117159],
       [ 0.04187369,  0.0132397 , -0.16233045,  0.10300563,  0.06598518,
         0.05728842, -0.02450454, -0.22889516, -0.03530695,  0.08300389],
       [ 0.15359762, -0.06493542, -0.22839671,  0.05915322,  0.26544052,
         0.15312935, -0.05132065, -0.34682024, -0.0181414 ,  0.08866596],
       [-0.06705338, -0.05590982, -0.21037713,  0.05252159,  0.22411834,
         0.06072947, -0.01180699, -0.31283215, -0.06644081, -0.02687445],
       [-0.01673558, -0.04322004, -0.22221681,  0.11640421,  0.27585298,
        -0.00789917, -0.03705985, -0.12847525, -0.14132528, -0.01258589],
       [ 0.05363014, -0.11879475, -0.08204994,  0.16474688,  0.09248446,
        -0.09719495, -0.07723137, -0.23136492, -0.05618468,  0.10164495],
       [-0.02539362, -0.14454898, -0.32296312,  0.2053542 ,  0.18563472,
        -0.0445538 , -0.13633929, -0.12712947, -0.06732591,  0.05459897],
       [-0.02403368, -0.09293792, -0.22012895,  0.09356467,  0.3415923 ,
        -0.09844425, -0.04539915, -0.28688133, -0.14435257,  0.05483858],
       [ 0.03492264,  0.04167182, -0.08564096,  0.01466741,  0.14968738,
         0.01946784, -0.04962645, -0.09357765, -0.03180797,  0.03431095],
       [ 0.04553585, -0.06386177, -0.159064  ,  0.09195592,  0.20032357,
         0.05248308,  0.05274323, -0.09328806, -0.02849531,  0.10636853],
       [-0.08788846, -0.05706687, -0.27519208,  0.12941426,  0.1730625 ,
         0.00562337,  0.03862702, -0.3364083 ,  0.01087172,  0.03377784],
       [-0.08110045, -0.06666276, -0.34764278,  0.25369477,  0.26242447,
         0.03672977,  0.07488421, -0.11382174,  0.03446682,  0.20799701],
       [-0.02429771, -0.0130821 , -0.28549588,  0.09956603,  0.19093114,
         0.09172641, -0.01084431, -0.26826024, -0.09550276, -0.09001306],
       [-0.0405377 ,  0.02302578, -0.16092977,  0.12650998,  0.10584372,
         0.0598565 ,  0.0370068 , -0.13375495, -0.05769489,  0.04597083],
       [-0.08379065, -0.12666067, -0.23740488,  0.08539408,  0.19100066,
        -0.19001569, -0.03504099, -0.2954648 , -0.00778607, -0.10035929],
       [-0.06841633, -0.02935523, -0.27325606,  0.07019119,  0.13153824,
         0.03444952, -0.07040955, -0.16061744, -0.05776489, -0.02386798],
       [ 0.02282005, -0.03760834, -0.17803052,  0.09008945,  0.15709753,
        -0.02815568, -0.01385967, -0.2636196 , -0.06011615, -0.04417434],
       [ 0.05103182, -0.0073192 , -0.2492007 ,  0.09097242,  0.2589297 ,
         0.03582668, -0.05287637, -0.1023304 , -0.10472505, -0.02360192],
       [-0.04446318, -0.00104156, -0.22680247,  0.0975772 ,  0.25874364,
         0.07281871,  0.14879908, -0.21233654, -0.11104408,  0.1596871 ],
       [-0.16542982, -0.02617702, -0.2530758 ,  0.09354755,  0.19404459,
         0.0228528 , -0.03458656, -0.3274249 , -0.08492248,  0.07104953],
       [-0.04432368, -0.01551367, -0.30958706,  0.08279304,  0.15877493,
         0.14097705,  0.0056034 , -0.2121813 , -0.10417398,  0.13372038],
       [ 0.00872401,  0.02290398, -0.18306321,  0.11926699,  0.0969364 ,
        -0.04007095,  0.01660407, -0.28434896, -0.15929542,  0.01083255],
       [ 0.07433248, -0.14991361, -0.2220522 ,  0.00625274,  0.39078072,
         0.03646233,  0.10941336, -0.20384778, -0.02929106,  0.03544597],
       [-0.00069001, -0.0680518 , -0.11302898,  0.11793397,  0.11893341,
        -0.05947986, -0.02543334, -0.24527295, -0.09240474, -0.00762735],
       [ 0.01683525,  0.03738175, -0.18935157,  0.07978748,  0.23876491,
         0.15589894, -0.00638897, -0.25770593, -0.11232982, -0.0446422 ],
       [-0.01690136, -0.19515185, -0.2338915 , -0.00964288,  0.17318843,
        -0.02175554,  0.07482283, -0.19234088, -0.0229656 ,  0.11406161],
       [-0.00661898,  0.00870193, -0.11167589,  0.15103012,  0.06432639,
        -0.12180559,  0.04999296, -0.2667799 , -0.17659347, -0.04285187],
       [-0.01717829,  0.02375691, -0.14970137,  0.1191919 ,  0.10172842,
        -0.07352136,  0.02696884, -0.11598936, -0.1331213 , -0.00928868],
       [-0.05850236,  0.03356444, -0.24372646,  0.14034908,  0.22228894,
         0.04799255, -0.01023421, -0.23915118, -0.07773915,  0.01665494],
       [-0.04828071, -0.00198432, -0.21945187,  0.14940068,  0.26243302,
         0.04732714, -0.03919668, -0.3767312 , -0.04807761,  0.04837478],
       [ 0.08090632,  0.02816604, -0.31061617,  0.04813545,  0.17886776,
         0.10947818,  0.0324835 , -0.22861008, -0.01619428, -0.00963937],
       [ 0.01237603, -0.07633115, -0.20681188,  0.08626392,  0.16251579,
         0.05692254,  0.00641025, -0.027444  ,  0.05301347,  0.00296039],
       [-0.03114549, -0.03946134, -0.20575103,  0.158873  ,  0.19106835,
        -0.00628418, -0.06812906, -0.29752672, -0.12863883,  0.00519179],
       [-0.02839492,  0.00197193, -0.38123846,  0.12928526,  0.4360217 ,
         0.06745887, -0.01924693, -0.3610945 ,  0.02880143,  0.00938179],
       [-0.10277586,  0.01430387, -0.24793717, -0.02120358,  0.20257095,
         0.10856566,  0.08017994, -0.21743834,  0.02736677,  0.01270235],
       [ 0.00209297, -0.04658009, -0.10872659,  0.00873713,  0.12002683,
        -0.01763269,  0.00062436, -0.07574805,  0.00423002,  0.09696378],
       [-0.0030484 ,  0.00373926, -0.20884912,  0.03331832,  0.37477142,
         0.14008212,  0.031428  , -0.40348598, -0.02555457,  0.05203115],
       [ 0.06917666, -0.07515088, -0.15344585,  0.08451273,  0.16555418,
        -0.00663652, -0.03506049, -0.19360425, -0.01485892, -0.1411201 ],
       [ 0.08957651, -0.0336723 , -0.16066113,  0.09386282,  0.21388392,
        -0.01653587, -0.02893457, -0.04395334, -0.03723653,  0.07710503]],
      dtype=float32)>}
2021-10-26 01:27:16.715879: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Вы также можете загружать и выполнять логический вывод распределенным образом:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:16.888897: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Вызов восстановленной функции - это просто прямой переход к сохраненной модели (прогноз). Что делать, если вы не хотите продолжать тренировку загруженной функции? Или встроить загруженную функцию в более крупную модель? Обычной практикой является перенос этого загруженного объекта в слой Keras для достижения этой цели. К счастью, TF - концентратор имеет hub.KerasLayer для этой цели, как показано здесь:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:18.637232: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.2057 - sparse_categorical_accuracy: 0.9392
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9802

Как вы можете видеть, hub.KerasLayer оборачивает результат загружается обратно из tf.saved_model.load() в слое Keras , который может быть использован для создания другой модели. Это очень полезно для трансферного обучения.

Какой API мне следует использовать?

Для экономии, если вы работаете с keras модели, она почти всегда рекомендуется использовать в Keras в model.save() API. Если то, что вы сохраняете, не является моделью Keras, тогда ваш единственный выбор - API нижнего уровня.

Какой API вы используете для загрузки, зависит от того, что вы хотите получить от API загрузки. Если вы не можете (или не хотите) получить модель Keras, а затем использовать tf.saved_model.load() . В противном случае, использование tf.keras.models.load_model() . Обратите внимание, что вы можете вернуть модель Keras, только если вы сохранили модель Keras.

Можно смешивать и сопоставлять API. Вы можете сохранить модель Keras с model.save и загрузить модель без Keras с низким уровнем API, tf.saved_model.load .

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Сохранение / загрузка с локального устройства

При сохранении и загрузке из локального устройства ввода - вывода при работе удаленно, например , с использованием облака ТПУ, вариант experimental_io_device должен использоваться , чтобы установить Io устройство для локального хоста.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Предостережения

Особый случай - это когда у вас есть модель Keras, в которой нет четко определенных входных данных. Например, последовательная модель может быть создана без каких - либо входных форм ( Sequential([Dense(3), ...] ). Подклассы модель также не имеет четко определенные входов после инициализации. В этом случае, вы должны придерживаться с API нижнего уровня как при сохранении, так и при загрузке, иначе вы получите ошибку.

Для того, чтобы проверить , если ваша модель имеет четко определенные входы, просто проверить , если model.inputs не None . Если это не является None , вы все хорошо. Входные формы автоматически определяются , когда модель используется в .fit , .evaluate , .predict , или при вызове модели ( model(inputs) ).

Вот пример:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa4f68ee5d0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa4f68ee5d0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.Dense object at 0x7fa4f68ee490>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.Dense object at 0x7fa4f68ee490>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets