Save and load a model using a distribution strategy

Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
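In get_data and get_model, the strategy matters in exactly two places: the global batch size and the scope in which the model's variables are created. The sketch below isolates both outside the MNIST example. The names weight and plain_weight are illustrative only, and it assumes that on a machine without GPUs MirroredStrategy falls back to a single CPU replica, so num_replicas_in_sync is then 1.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Each replica receives BATCH_SIZE_PER_REPLICA examples, so the batch size
# passed to Dataset.batch() is the per-replica size times the replica count.
BATCH_SIZE_PER_REPLICA = 64
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Variables created inside strategy.scope() are owned by the strategy
# (mirrored across replicas for MirroredStrategy); variables created
# outside any scope are ordinary, unmanaged variables.
with strategy.scope():
    weight = tf.Variable(1.0, name="weight")
plain_weight = tf.Variable(1.0, name="plain_weight")

in_scope = strategy.extended.variable_created_in_scope(weight)
out_of_scope = strategy.extended.variable_created_in_scope(plain_weight)

print("replicas in sync:", strategy.num_replicas_in_sync)
print("global batch size:", global_batch_size)
print("weight created in scope:", in_scope)
print("plain_weight created in scope:", out_of_scope)
```

This is why get_model builds and compiles the model inside mirrored_strategy.scope(): the layers' variables must be created while the scope is active.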

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2021-10-26 01:26:36.109959: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 13s 3ms/step - loss: 0.2015 - sparse_categorical_accuracy: 0.9410
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7fa92037bc90>

Save and load the model

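Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level tf.saved_model APIs (tf.saved_model.save and tf.saved_model.load).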

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2021-10-26 01:26:52.520058: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0356 - sparse_categorical_accuracy: 0.9890
<keras.callbacks.History at 0x7fa8dc6d6690>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the SavedModel format.
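A SavedModel is a directory rather than a single file. As a rough illustration, the sketch below saves a tiny, hypothetical Scale module (a tf.Module standing in for the Keras model) with tf.saved_model.save and lists what ends up on disk. The exact set of files can vary between TensorFlow versions, so only the core entries are relied on.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical stand-in for a trained model: multiplies its input by a variable."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), "scale_model")
tf.saved_model.save(Scale(), export_dir)

# saved_model.pb holds the serialized functions and object graph (a protocol
# buffer); the variables/ subdirectory holds the checkpointed variable values.
contents = sorted(os.listdir(export_dir))
variable_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
print(contents)
print(variable_files)
```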

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
2021-10-26 01:26:57.965185: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2021-10-26 01:26:58.004038: W tensorflow/core/framework/dataset.cc:679] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9846
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0345 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
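The same independence holds for the low-level APIs. A minimal sketch, using a hypothetical Affine module instead of the Keras model: its variables are created under MirroredStrategy, it is loaded back under OneDeviceStrategy, and both copies should produce the same output for the same input.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Affine(tf.Module):
    """Hypothetical stand-in for a trained model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[1.0], [2.0]])
        self.b = tf.Variable([0.5])

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


export_dir = os.path.join(tempfile.mkdtemp(), "affine")
x = tf.constant([[1.0, 1.0], [0.0, 3.0]])

# Create the variables under one strategy...
save_strategy = tf.distribute.MirroredStrategy()
with save_strategy.scope():
    module = Affine()
    original_out = module(x).numpy()
tf.saved_model.save(module, export_dir)

# ...and load them under a different one.
load_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with load_strategy.scope():
    restored = tf.saved_model.load(export_dir)
    restored_out = restored(x).numpy()

print(original_out)
print(restored_out)
```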

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.11688858, -0.05038287, -0.2585946 ,  0.04893515,  0.27253783,
         0.1022947 , -0.06840641, -0.33529347, -0.07071295,  0.06517357],
       [ 0.10904782, -0.23611397, -0.16135186,  0.10045648,  0.26082516,
        -0.02260189,  0.0424989 , -0.09468129,  0.05540806,  0.10558474],
       [-0.0491788 , -0.04070761, -0.23004392,  0.17719601,  0.20461476,
        -0.05333536, -0.02240408, -0.21509385, -0.05161493,  0.12337525],
       [ 0.00487803, -0.05770147, -0.23551641,  0.05988425,  0.15881103,
        -0.05608599, -0.04135028, -0.3390705 , -0.07579579, -0.08983649],
       [-0.04663972, -0.13439807, -0.19048163,  0.13628994,  0.05608338,
        -0.06012772, -0.03063064, -0.32014394, -0.16421723,  0.08930477],
       [ 0.02328245,  0.05272574, -0.34110764,  0.12926938,  0.33982378,
         0.12486804, -0.04870659, -0.45755434, -0.05433567,  0.14137071],
       [ 0.06421333, -0.20211999, -0.14309192,  0.00360708,  0.23210834,
         0.00101324, -0.01692696, -0.15713055,  0.00623474, -0.02222142],
       [ 0.08059486,  0.0456247 , -0.15926833,  0.05546484,  0.09179395,
         0.06136999, -0.07209414, -0.2553306 , -0.04975087,  0.06797761],
       [ 0.05864911, -0.10561213, -0.23619679,  0.11069187,  0.13890924,
         0.04969782, -0.05587994, -0.26131746, -0.0363602 ,  0.02788973],
       [ 0.0296779 ,  0.06670297, -0.12159262,  0.06834705,  0.19103828,
         0.14597046,  0.00285575, -0.19362533, -0.06905006,  0.097047  ],
       [ 0.05100356, -0.03875454, -0.31727186,  0.01787528,  0.20725562,
        -0.01677462, -0.00129463, -0.17944467,  0.05812614,  0.04979762],
       [-0.03301986, -0.10880841, -0.21802825,  0.0578297 ,  0.41345048,
         0.10376748,  0.03452782, -0.27389282, -0.06923576,  0.14353925],
       [-0.02203556, -0.08816119, -0.15965816,  0.07572726,  0.018046  ,
        -0.10299203,  0.01126328, -0.21401492, -0.17861444,  0.05669294],
       [-0.0245089 , -0.03849422, -0.2968499 ,  0.23396973,  0.22189453,
         0.00512835, -0.00468208, -0.29407114, -0.14926936, -0.02818882],
       [-0.02376807, -0.05931192, -0.31774518,  0.15711312,  0.31248903,
        -0.04320139, -0.08301807, -0.4610513 , -0.10252888, -0.03784092],
       [-0.03953424, -0.08268867, -0.3604463 ,  0.14048189,  0.33057037,
         0.01373108, -0.12093162, -0.38173944,  0.01771745, -0.07451382],
       [-0.05658644,  0.0519563 , -0.20794927,  0.10203589,  0.2135886 ,
         0.14241108, -0.04007911, -0.26177728, -0.08082938,  0.00216334],
       [-0.06207625, -0.01838757, -0.21708131,  0.10756977,  0.25599915,
         0.03101911,  0.05593228, -0.25550944, -0.11642678,  0.09014311],
       [ 0.05197014,  0.03873106, -0.1469059 ,  0.08044868,  0.12293777,
        -0.00388163,  0.00324975, -0.08145286, -0.12639561, -0.03596487],
       [-0.10676757, -0.05767517, -0.20481907,  0.14739943,  0.17379019,
        -0.08260865, -0.09114882, -0.38688654, -0.1448748 ,  0.03397277],
       [-0.03770879,  0.04663504, -0.30894646,  0.05933709,  0.09536786,
         0.1006383 ,  0.00984312, -0.3204393 , -0.01170056, -0.03391666],
       [ 0.0231554 ,  0.12106506, -0.255493  ,  0.04387057,  0.12491666,
         0.03297757, -0.03934925, -0.17047551,  0.00603533,  0.02295396],
       [-0.0137163 , -0.08226999, -0.3219023 ,  0.1111999 ,  0.15005693,
        -0.10358538, -0.04351711, -0.24015021, -0.08079101,  0.01281704],
       [ 0.08698535, -0.17155564, -0.19832517, -0.0417797 ,  0.24460419,
        -0.00698967,  0.08663791, -0.20004068,  0.02847612,  0.12739052],
       [ 0.0248102 , -0.07629397, -0.10130948,  0.00225735,  0.14270194,
         0.01750292,  0.03144339, -0.1429488 , -0.02819812,  0.24307509],
       [-0.06557162, -0.06485987, -0.36512223,  0.18774748,  0.25643086,
         0.0340823 , -0.01398754, -0.19010906, -0.07261477,  0.05117159],
       [ 0.04187369,  0.0132397 , -0.16233045,  0.10300563,  0.06598518,
         0.05728842, -0.02450454, -0.22889516, -0.03530695,  0.08300389],
       [ 0.15359762, -0.06493542, -0.22839671,  0.05915322,  0.26544052,
         0.15312935, -0.05132065, -0.34682024, -0.0181414 ,  0.08866596],
       [-0.06705338, -0.05590982, -0.21037713,  0.05252159,  0.22411834,
         0.06072947, -0.01180699, -0.31283215, -0.06644081, -0.02687445],
       [-0.01673558, -0.04322004, -0.22221681,  0.11640421,  0.27585298,
        -0.00789917, -0.03705985, -0.12847525, -0.14132528, -0.01258589],
       [ 0.05363014, -0.11879475, -0.08204994,  0.16474688,  0.09248446,
        -0.09719495, -0.07723137, -0.23136492, -0.05618468,  0.10164495],
       [-0.02539362, -0.14454898, -0.32296312,  0.2053542 ,  0.18563472,
        -0.0445538 , -0.13633929, -0.12712947, -0.06732591,  0.05459897],
       [-0.02403368, -0.09293792, -0.22012895,  0.09356467,  0.3415923 ,
        -0.09844425, -0.04539915, -0.28688133, -0.14435257,  0.05483858],
       [ 0.03492264,  0.04167182, -0.08564096,  0.01466741,  0.14968738,
         0.01946784, -0.04962645, -0.09357765, -0.03180797,  0.03431095],
       [ 0.04553585, -0.06386177, -0.159064  ,  0.09195592,  0.20032357,
         0.05248308,  0.05274323, -0.09328806, -0.02849531,  0.10636853],
       [-0.08788846, -0.05706687, -0.27519208,  0.12941426,  0.1730625 ,
         0.00562337,  0.03862702, -0.3364083 ,  0.01087172,  0.03377784],
       [-0.08110045, -0.06666276, -0.34764278,  0.25369477,  0.26242447,
         0.03672977,  0.07488421, -0.11382174,  0.03446682,  0.20799701],
       [-0.02429771, -0.0130821 , -0.28549588,  0.09956603,  0.19093114,
         0.09172641, -0.01084431, -0.26826024, -0.09550276, -0.09001306],
       [-0.0405377 ,  0.02302578, -0.16092977,  0.12650998,  0.10584372,
         0.0598565 ,  0.0370068 , -0.13375495, -0.05769489,  0.04597083],
       [-0.08379065, -0.12666067, -0.23740488,  0.08539408,  0.19100066,
        -0.19001569, -0.03504099, -0.2954648 , -0.00778607, -0.10035929],
       [-0.06841633, -0.02935523, -0.27325606,  0.07019119,  0.13153824,
         0.03444952, -0.07040955, -0.16061744, -0.05776489, -0.02386798],
       [ 0.02282005, -0.03760834, -0.17803052,  0.09008945,  0.15709753,
        -0.02815568, -0.01385967, -0.2636196 , -0.06011615, -0.04417434],
       [ 0.05103182, -0.0073192 , -0.2492007 ,  0.09097242,  0.2589297 ,
         0.03582668, -0.05287637, -0.1023304 , -0.10472505, -0.02360192],
       [-0.04446318, -0.00104156, -0.22680247,  0.0975772 ,  0.25874364,
         0.07281871,  0.14879908, -0.21233654, -0.11104408,  0.1596871 ],
       [-0.16542982, -0.02617702, -0.2530758 ,  0.09354755,  0.19404459,
         0.0228528 , -0.03458656, -0.3274249 , -0.08492248,  0.07104953],
       [-0.04432368, -0.01551367, -0.30958706,  0.08279304,  0.15877493,
         0.14097705,  0.0056034 , -0.2121813 , -0.10417398,  0.13372038],
       [ 0.00872401,  0.02290398, -0.18306321,  0.11926699,  0.0969364 ,
        -0.04007095,  0.01660407, -0.28434896, -0.15929542,  0.01083255],
       [ 0.07433248, -0.14991361, -0.2220522 ,  0.00625274,  0.39078072,
         0.03646233,  0.10941336, -0.20384778, -0.02929106,  0.03544597],
       [-0.00069001, -0.0680518 , -0.11302898,  0.11793397,  0.11893341,
        -0.05947986, -0.02543334, -0.24527295, -0.09240474, -0.00762735],
       [ 0.01683525,  0.03738175, -0.18935157,  0.07978748,  0.23876491,
         0.15589894, -0.00638897, -0.25770593, -0.11232982, -0.0446422 ],
       [-0.01690136, -0.19515185, -0.2338915 , -0.00964288,  0.17318843,
        -0.02175554,  0.07482283, -0.19234088, -0.0229656 ,  0.11406161],
       [-0.00661898,  0.00870193, -0.11167589,  0.15103012,  0.06432639,
        -0.12180559,  0.04999296, -0.2667799 , -0.17659347, -0.04285187],
       [-0.01717829,  0.02375691, -0.14970137,  0.1191919 ,  0.10172842,
        -0.07352136,  0.02696884, -0.11598936, -0.1331213 , -0.00928868],
       [-0.05850236,  0.03356444, -0.24372646,  0.14034908,  0.22228894,
         0.04799255, -0.01023421, -0.23915118, -0.07773915,  0.01665494],
       [-0.04828071, -0.00198432, -0.21945187,  0.14940068,  0.26243302,
         0.04732714, -0.03919668, -0.3767312 , -0.04807761,  0.04837478],
       [ 0.08090632,  0.02816604, -0.31061617,  0.04813545,  0.17886776,
         0.10947818,  0.0324835 , -0.22861008, -0.01619428, -0.00963937],
       [ 0.01237603, -0.07633115, -0.20681188,  0.08626392,  0.16251579,
         0.05692254,  0.00641025, -0.027444  ,  0.05301347,  0.00296039],
       [-0.03114549, -0.03946134, -0.20575103,  0.158873  ,  0.19106835,
        -0.00628418, -0.06812906, -0.29752672, -0.12863883,  0.00519179],
       [-0.02839492,  0.00197193, -0.38123846,  0.12928526,  0.4360217 ,
         0.06745887, -0.01924693, -0.3610945 ,  0.02880143,  0.00938179],
       [-0.10277586,  0.01430387, -0.24793717, -0.02120358,  0.20257095,
         0.10856566,  0.08017994, -0.21743834,  0.02736677,  0.01270235],
       [ 0.00209297, -0.04658009, -0.10872659,  0.00873713,  0.12002683,
        -0.01763269,  0.00062436, -0.07574805,  0.00423002,  0.09696378],
       [-0.0030484 ,  0.00373926, -0.20884912,  0.03331832,  0.37477142,
         0.14008212,  0.031428  , -0.40348598, -0.02555457,  0.05203115],
       [ 0.06917666, -0.07515088, -0.15344585,  0.08451273,  0.16555418,
        -0.00663652, -0.03506049, -0.19360425, -0.01485892, -0.1411201 ],
       [ 0.08957651, -0.0336723 , -0.16066113,  0.09386282,  0.21388392,
        -0.01653587, -0.02893457, -0.04395334, -0.03723653,  0.07710503]],
      dtype=float32)>}
2021-10-26 01:27:16.715879: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
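A loaded object exposes its signatures as a mapping from key to function, and each signature returns a dictionary keyed by output name (dense_3 above is the name of the model's last Dense layer). The sketch below reproduces this with a hypothetical Doubler module whose single signature is registered under the serving_default key, and calls it with a keyword argument matching the input name.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module with one function exported as a signature."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def double(self, x):
        # Signature functions return a dict; its keys become the output names.
        return {"doubled": x * 2.0}


module = Doubler()
export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.double})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))

infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)  # output names mapped to TensorSpecs
result = infer(x=tf.constant([1.0, 2.0, 3.0]))
print(result["doubled"].numpy())
```

Inspecting structured_outputs is a quick way to discover output keys such as dense_3 without running inference first.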

You can also load the model and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:16.888897: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
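The warnings above appear because strategy.run is called eagerly for every batch, and they suggest wrapping the call in a tf.function. A minimal sketch of that pattern follows. It assumes a hypothetical Scaler module in place of the Keras model and uses strategy.experimental_local_results to collect each replica's output on the host; replica_fn and distributed_infer are illustrative names.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for the saved model: y = x * factor."""

    def __init__(self, factor):
        super().__init__()
        self.factor = tf.Variable(factor)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(Scaler(2.0), export_dir)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    restored = tf.saved_model.load(export_dir)


def replica_fn(x):
    # Runs once per replica on that replica's slice of the global batch.
    return restored(x)


@tf.function
def distributed_infer(batch):
    # Tracing strategy.run inside a tf.function avoids the per-batch
    # eager overhead reported in the warnings.
    return strategy.run(replica_fn, args=(batch,))


dataset = tf.data.Dataset.from_tensor_slices(tf.range(8, dtype=tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

outputs = []
for batch in dist_dataset:
    per_replica = distributed_infer(batch)
    # Pull each replica's local result back to the host as a NumPy array.
    for local in strategy.experimental_local_results(per_replica):
        outputs.append(local.numpy())

predictions = np.concatenate(outputs)
print(predictions)
```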

Calling the restored function is just a forward pass on the saved model (like predict). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2021-10-26 01:27:18.637232: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.2057 - sparse_categorical_accuracy: 0.9392
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9802

As you can see, hub.KerasLayer wraps the object loaded back by tf.saved_model.load() in a Keras layer that can be used to build another model. This is very useful for transfer learning.
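Continuing to train a loaded object, whether through hub.KerasLayer or directly, depends on restored functions being differentiable and restored variables remaining trainable. As a rough, Keras-free illustration of that mechanism (not of how hub.KerasLayer is implemented), the sketch below loads a hypothetical Scaler module with tf.saved_model.load and fits its restored variable to a toy target using tf.GradientTape and a hand-written gradient-descent update.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for the saved model: y = x * factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(Scaler(), export_dir)
restored = tf.saved_model.load(export_dir)

# Toy regression target: the factor that fits it exactly is 3.0.
x = tf.constant([1.0, 2.0, 3.0])
y_true = 3.0 * x
learning_rate = 0.05

for _ in range(100):
    with tf.GradientTape() as tape:
        # Trainable variables read inside the tape are watched automatically.
        loss = tf.reduce_mean(tf.square(restored(x) - y_true))
    grad = tape.gradient(loss, restored.factor)
    restored.factor.assign_sub(learning_rate * grad)  # plain gradient descent

print(float(restored.factor.numpy()))
```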

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load:

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/loading from a local device

When saving to and loading from a local I/O device while running remotely (for example, on a Cloud TPU), you must use the experimental_io_device option to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all good. Input shapes are automatically defined when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).
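A small helper for that check is sketched below. It is an illustration, not a Keras API: has_defined_inputs is a hypothetical name, and it uses getattr with a default because, depending on the Keras version, reading inputs on a model without defined inputs may raise AttributeError instead of returning None.

```python
import tensorflow as tf


def has_defined_inputs(model):
    """Hypothetical helper: True once the model's inputs are defined."""
    return getattr(model, "inputs", None) is not None


# A Sequential model created without any input shape.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
before_call = has_defined_inputs(model)

# Calling the model on data defines its inputs.
model(tf.zeros((2, 4)))
after_call = has_defined_inputs(model)

print("inputs defined before call:", before_call)
print("inputs defined after call:", after_call)
```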

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa4f68ee5d0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.Dense object at 0x7fa4f68ee490>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets