
Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 10s 4ms/step - loss: 0.2033 - sparse_categorical_accuracy: 0.9408
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9812
<tensorflow.python.keras.callbacks.History at 0x7f07cc109790>
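The 938 steps per epoch in the log above follow from the batch-size arithmetic in get_data(). A minimal sketch in plain Python, assuming the 60,000-image MNIST training split and the single-GPU setup shown in the log; the helper names are illustrative and not part of TensorFlow:

```python
import math


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    """Batch size fed to the dataset: each replica receives `per_replica` examples."""
    return per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches per epoch; Dataset.batch() keeps the final partial batch by default."""
    return math.ceil(num_examples / batch_size)


single_gpu = global_batch_size(64, 1)
print(single_gpu, steps_per_epoch(60_000, single_gpu))  # 64 938

# With two replicas the global batch doubles and the step count roughly halves.
two_gpus = global_batch_size(64, 2)
print(two_gpus, steps_per_epoch(60_000, two_gpus))  # 128 469
```

Because BATCH_SIZE scales with num_replicas_in_sync, each replica keeps processing 64 examples per step regardless of how many devices are available.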

Save and load the model

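Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level tf.saved_model APIs (tf.saved_model.save and tf.saved_model.load).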

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9896
<tensorflow.python.keras.callbacks.History at 0x7f08e8347e90>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before it was saved. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the saved_model format.
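A SavedModel is a directory rather than a single file. As a quick way to see this, the sketch below saves a hypothetical toy tf.Module with the low-level tf.saved_model.save and lists what lands on disk. The exact set of entries can vary between TensorFlow versions, but saved_model.pb (the serialized graph and signatures) and variables/ (the checkpointed weights) are expected in every case:

```python
import os
import tempfile

import tensorflow as tf


class Toy(tf.Module):
    """Hypothetical module with one variable and one traced function."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


# tf.saved_model.save creates the target directory if it does not exist.
export_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
tf.saved_model.save(Toy(), export_dir)

# Includes at least saved_model.pb and variables/; other entries depend on the TF version.
entries = sorted(os.listdir(export_dir))
print(entries)
```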

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9856
Epoch 2/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
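The same independence holds for the lower-level API. A minimal sketch, assuming a hypothetical toy tf.Module: its variable is created under MirroredStrategy (which uses the CPU when no GPU is visible), and the SavedModel is then restored under OneDeviceStrategy:

```python
import tempfile

import tensorflow as tf


class Counter(tf.Module):
    """Hypothetical toy module holding a single variable."""

    def __init__(self):
        super().__init__()
        self.total = tf.Variable(5.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def add(self, x):
        return self.total + x


# Create the variable under one strategy (it becomes a mirrored variable)...
with tf.distribute.MirroredStrategy().scope():
    module = Counter()

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# ...and restore it under a different strategy.
with tf.distribute.OneDeviceStrategy("/cpu:0").scope():
    restored = tf.saved_model.load(export_dir)

print(restored.total.numpy(), restored.add(tf.constant(1.0)).numpy())  # 5.0 6.0
```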

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.87175977e-02, -1.61857940e-02,  4.99733090e-02,
         1.36069462e-01,  1.08132266e-01,  7.42500275e-02,
         7.73677081e-02, -4.71051671e-02,  6.41952083e-02,
         1.25855744e-01],
       [-6.54747933e-02,  2.14797929e-01, -8.99797529e-02,
        -7.83173367e-04,  8.92944187e-02,  1.73829526e-01,
         1.60048857e-01,  6.56097680e-02,  9.97588038e-03,
         1.09826915e-01],
       [-1.54433712e-01,  8.42373446e-03,  5.25225215e-02,
         4.81541678e-02,  2.23428048e-02,  7.69063532e-02,
        -3.99671011e-02, -4.56253998e-02,  5.28362878e-02,
         1.42020926e-01],
       [-9.43576694e-02, -9.26742554e-02,  7.42801651e-02,
         7.15272427e-02,  4.92905118e-02,  6.36474639e-02,
        -4.97807935e-03, -8.52353871e-02,  4.42789420e-02,
         1.56496197e-01],
       [-2.41690576e-02, -1.11253485e-01,  2.60174796e-02,
        -2.53655966e-02, -3.06888595e-02,  7.04614669e-02,
        -2.01677606e-02, -8.55238289e-02,  4.99680266e-02,
         1.76511839e-01],
       [-4.30272967e-02, -2.32194141e-02, -4.67026755e-02,
         5.49829267e-02,  3.00423354e-02,  1.36088327e-01,
         2.45837178e-02, -3.89227420e-02,  4.13372666e-02,
         7.04330802e-02],
       [-7.81589746e-03,  8.88681635e-02,  1.90377012e-02,
         2.47566961e-02,  6.73517808e-02,  9.59524438e-02,
         9.72996876e-02, -6.45323768e-02, -5.78478463e-02,
         1.67243421e-01],
       [ 2.59691626e-02,  2.44912747e-02,  2.06067227e-03,
         1.70634389e-02,  2.18513757e-02,  1.37382165e-01,
         2.51548998e-02,  9.26546752e-03,  9.15516615e-02,
         1.02677807e-01],
       [-1.27567410e-01,  2.68350840e-02,  1.16995871e-02,
         9.84568521e-02,  5.41514680e-02,  1.51651859e-01,
         4.64795753e-02, -2.91735604e-02,  1.25106841e-01,
         1.71144456e-01],
       [-8.05673450e-02,  2.94806808e-02, -1.10152550e-01,
        -7.07100704e-02, -1.24747120e-02,  1.40526205e-01,
        -1.23536028e-02, -5.82108870e-02,  7.77795464e-02,
        -2.34171711e-02],
       [ 6.20728806e-02,  2.07295828e-02, -6.43237084e-02,
        -2.05786452e-02,  4.67800628e-03,  8.61115605e-02,
         4.20134440e-02, -4.34638187e-02,  3.63051891e-03,
         7.10046589e-02],
       [-6.08086586e-04,  3.42310257e-02, -8.34690034e-02,
         9.08722728e-03,  1.05238423e-01,  1.47501156e-01,
         1.42161340e-01,  6.30808845e-02, -5.24729490e-04,
         1.36735335e-01],
       [-1.91496432e-01, -3.44831869e-03, -1.86069421e-02,
        -5.05818650e-02,  6.41704351e-03,  4.03927974e-02,
         2.21725628e-02, -1.62169725e-01,  1.76057145e-02,
         1.85961321e-01],
       [-2.49194503e-02, -2.50838101e-02, -3.03340796e-02,
         2.27186158e-02, -3.09922658e-02,  1.24549076e-01,
        -2.09298395e-02, -3.32995206e-02,  5.32348230e-02,
         1.50542200e-01],
       [-4.28871512e-02, -9.22830626e-02,  2.90450491e-02,
         1.12102680e-01,  3.04660778e-02,  7.70893693e-02,
         6.90545812e-02, -3.36272307e-02,  5.62672317e-02,
         2.48093754e-01],
       [-1.07569546e-02, -7.48349279e-02,  2.02673525e-02,
         8.59332830e-02,  1.17068969e-01,  7.48098046e-02,
         6.08244948e-02,  1.01280659e-02,  1.91551913e-02,
         2.11404741e-01],
       [ 7.62790143e-02,  2.29362398e-03, -5.91743030e-02,
         1.31748244e-02,  3.23122516e-02,  7.41088167e-02,
        -6.43679202e-02, -4.93042730e-03,  7.44977593e-02,
         1.33817106e-01],
       [-1.31239325e-01, -1.13535374e-02, -4.06663343e-02,
        -1.28620975e-02, -5.99850900e-04,  1.77248433e-01,
        -1.11874640e-02, -1.12869248e-01,  2.11616959e-02,
         1.17670909e-01],
       [-7.98342153e-02,  1.17193013e-02, -5.84214628e-02,
         4.45103496e-02, -2.35349871e-03,  9.78477970e-02,
        -1.31990746e-01, -9.62637737e-02, -4.68819216e-03,
         4.20698449e-02],
       [ 2.33993679e-02, -6.48612902e-02,  3.57650109e-02,
         4.70414013e-02, -9.87344980e-03,  8.05342346e-02,
         8.79075974e-02, -3.31879668e-02,  4.50132787e-03,
         1.80063233e-01],
       [-4.68377843e-02, -4.20477986e-03, -6.34167045e-02,
         1.02471206e-02,  7.62451440e-02,  4.69596758e-02,
         1.50972791e-02,  2.54708696e-02,  6.22914918e-02,
        -7.17046950e-03],
       [-7.52374083e-02, -5.75587116e-02,  5.98922297e-02,
        -2.73536984e-02, -9.00403410e-03,  1.24273732e-01,
         3.02318707e-02, -8.23312476e-02,  4.71908152e-02,
         6.14355356e-02],
       [-1.15195602e-01, -1.64229050e-02,  1.29237305e-03,
        -1.09596789e-01,  3.80243734e-02,  1.29902706e-01,
         2.34013163e-02, -1.05847850e-01,  5.61432168e-02,
         6.71504661e-02],
       [-6.54574633e-02,  1.53685883e-01, -9.75284129e-02,
         7.12753274e-03,  8.82450342e-02,  1.34998679e-01,
         1.54644206e-01,  6.79479688e-02, -1.72668919e-02,
        -9.18642431e-03],
       [-2.10184306e-01,  1.36045277e-01, -8.87301117e-02,
        -1.74187630e-01,  4.86502573e-02,  2.76032418e-01,
         6.05597720e-02, -7.86263272e-02,  1.03168473e-01,
        -1.37950957e-01],
       [-1.14111096e-01,  2.64648199e-02,  3.91238183e-02,
         1.27296187e-02,  1.35679528e-01,  1.77840948e-01,
        -1.35071129e-02,  2.69818418e-02,  3.42153087e-02,
         8.82407650e-02],
       [-3.01946551e-02, -2.40282975e-02, -6.38710558e-02,
        -3.98451164e-02, -2.92418208e-02,  1.03236027e-01,
        -6.42294213e-02, -6.60998672e-02,  8.01952481e-02,
         4.51336950e-02],
       [-7.84322768e-02,  3.11397053e-02, -5.89536875e-02,
         5.54301403e-03,  4.19510715e-02,  1.42857507e-01,
         3.19332965e-02,  8.02233815e-03,  4.79588211e-02,
         3.57780866e-02],
       [ 4.59235013e-02, -5.12123033e-02,  7.95385391e-02,
        -2.55374219e-02,  4.68068533e-02,  4.43704054e-02,
         7.18893707e-02,  1.64316930e-02,  1.46261796e-01,
         1.57745421e-01],
       [-7.41346031e-02, -2.22236216e-02,  5.15953489e-02,
         3.52352336e-02,  3.60111147e-02,  1.15427203e-01,
         2.57897153e-02, -1.20676845e-01,  5.01920506e-02,
         5.08756042e-02],
       [-8.63073617e-02,  3.17134708e-03,  7.40556046e-03,
        -4.21377048e-02, -1.40033364e-02,  1.12993240e-01,
         3.54410820e-02, -6.92673773e-03, -1.53519716e-02,
         1.26333967e-01],
       [-1.94949329e-01,  3.92537005e-02, -3.54437791e-02,
         2.31000036e-03,  3.79432067e-02,  1.76356524e-01,
        -2.87505463e-02, -1.43955618e-01,  5.22137731e-02,
         1.74010634e-01],
       [-7.06951618e-02,  4.39597517e-02, -6.54652417e-02,
         8.82245004e-02,  2.06859671e-02,  2.03700036e-01,
         4.97581288e-02, -9.99335721e-02,  1.11235633e-01,
         1.77167300e-02],
       [-3.53763551e-02,  1.12260692e-02, -1.11665837e-02,
         2.38924790e-02, -2.76037231e-02,  5.51242232e-02,
         1.64862610e-02, -7.48966262e-02,  8.12724680e-02,
         1.43957604e-03],
       [-6.88709617e-02,  8.31975043e-02, -1.29381031e-01,
         3.24329957e-02,  7.93049186e-02,  1.07140720e-01,
        -1.81627274e-03, -6.15758151e-02, -4.92713787e-03,
         1.36734053e-01],
       [-1.40681088e-01, -3.15807723e-02,  2.60454044e-02,
        -4.27230299e-02,  9.93238837e-02,  8.53562057e-02,
        -5.72501905e-02, -4.54951562e-02,  5.40010631e-02,
         1.60534859e-01],
       [-9.79049355e-02,  6.72584400e-03,  8.45115632e-03,
         6.96118921e-04,  6.83328062e-02,  7.55404532e-02,
         4.08902764e-02, -4.68667373e-02, -1.57676004e-02,
         1.63363129e-01],
       [ 8.52064192e-02, -4.55360785e-02, -1.45583451e-02,
         3.04546617e-02,  8.79652798e-04,  4.33428399e-02,
        -1.49330217e-02, -8.25205147e-02,  6.79182187e-02,
         1.38463050e-01],
       [-6.81671500e-03,  2.08939444e-02, -4.36149202e-02,
         4.71097752e-02, -4.12514098e-02,  7.67979622e-02,
        -3.79576795e-02, -5.68123832e-02,  8.52056891e-02,
         8.15685652e-03],
       [-3.65775973e-02, -3.06402445e-02,  6.65950254e-02,
         4.47403751e-02,  1.30895078e-01,  1.50094643e-01,
        -7.96291754e-02, -2.53046192e-02,  1.07343026e-01,
         1.23134412e-01],
       [-7.62542933e-02,  8.05391073e-02, -4.92790192e-02,
         9.18199718e-02, -2.56493650e-02,  1.11088946e-01,
         1.23991318e-01, -7.41875842e-02,  1.47743657e-01,
         1.12662911e-01],
       [-1.30924150e-01, -4.17556763e-02,  9.19794068e-02,
         1.42487824e-01,  5.94666004e-02,  8.11570883e-02,
         8.41576308e-02, -9.05135348e-02,  6.55059814e-02,
         1.40432447e-01],
       [-1.17066503e-02,  2.23872401e-02,  8.04524869e-02,
         4.16423976e-02,  2.44700648e-02,  1.15135796e-01,
        -4.52195890e-02, -3.85435522e-02,  5.72163910e-02,
        -1.70241687e-02],
       [-1.59098431e-01, -2.48743258e-02,  1.15704350e-03,
        -1.94135085e-02,  7.07440227e-02,  1.02726325e-01,
         7.48141706e-02, -4.92161885e-02, -4.82953712e-03,
         7.76504800e-02],
       [-5.23527563e-02, -6.95675313e-02,  6.23273998e-02,
         2.89142895e-02,  5.82050942e-02,  3.41961756e-02,
         5.93537465e-02, -4.88524139e-03,  6.03169389e-02,
         1.83362126e-01],
       [-1.74306586e-01,  8.50927830e-03,  5.59768602e-02,
         6.93600252e-03,  1.01455852e-01,  1.91212326e-01,
         1.01702012e-01, -4.06461619e-02,  7.65661225e-02,
         5.99906780e-02],
       [-1.77252740e-01, -8.87014568e-02,  2.68679634e-02,
        -9.68291797e-03,  3.35638374e-02,  5.05711734e-02,
        -8.12163353e-02, -1.41850561e-01,  1.26373529e-01,
         1.44208223e-01],
       [ 8.87322426e-03,  4.11978438e-02, -1.41141713e-02,
         4.23451513e-02,  1.00925714e-01,  1.82571277e-01,
         1.03720605e-01,  7.12942332e-03,  8.42842646e-03,
         1.58599243e-01],
       [-2.13406682e-02, -1.75910015e-02, -5.12900352e-02,
         2.48414166e-02,  1.27789266e-02,  1.57729238e-01,
        -9.69544426e-03, -8.99763852e-02,  1.03400812e-01,
         1.43311873e-01],
       [-8.07911009e-02, -7.29364436e-03,  8.38935375e-03,
        -2.07424574e-02,  5.76308332e-02,  1.60505801e-01,
         1.78787038e-02, -3.16766389e-02,  9.31937695e-02,
        -1.85308605e-02],
       [-9.09739435e-02,  9.12447274e-02, -2.94412468e-02,
        -4.54112887e-05,  5.61082549e-02,  1.87468201e-01,
         1.19883358e-01, -1.82710923e-02,  6.45193905e-02,
         2.49637775e-02],
       [-1.29598469e-01, -3.25308368e-03, -4.12912332e-02,
        -5.08361273e-02, -1.04823690e-02,  1.82079896e-02,
        -6.16563670e-02, -7.23700672e-02,  1.43699115e-02,
         1.14841349e-01],
       [-6.54844195e-02,  1.32181123e-02, -1.93353333e-02,
         3.21756229e-02,  1.93038825e-02,  6.58847988e-02,
        -8.00077915e-02, -8.77422094e-02,  6.08410016e-02,
         7.79580325e-02],
       [-1.41706705e-01, -2.03647800e-02, -7.27154315e-03,
         6.24460652e-02,  4.42662910e-02,  3.27514187e-02,
        -1.90984663e-02, -2.80730426e-02,  6.94331005e-02,
         1.08589470e-01],
       [-1.47680700e-01,  1.48287900e-02,  4.51873392e-02,
         9.84134525e-03, -3.12292501e-02,  7.53606334e-02,
        -3.28539237e-02, -8.99436623e-02,  9.34160203e-02,
         1.19600557e-01],
       [-1.06706396e-01, -7.91918635e-02, -3.81190814e-02,
         7.73421675e-03,  5.89197949e-02,  1.48285478e-01,
         7.33993948e-04, -7.37952963e-02,  7.03667626e-02,
         1.29369870e-02],
       [-1.18764386e-01,  6.67049065e-02, -1.84375793e-03,
         9.65175256e-02,  2.87358947e-02,  1.37036711e-01,
        -1.75398681e-02, -3.47253568e-02, -1.17472522e-02,
         1.64805949e-01],
       [-2.96032578e-02,  7.04797730e-02, -5.30908853e-02,
         3.15117575e-02,  1.10758804e-02,  1.52998209e-01,
        -4.00629006e-02, -8.53683874e-02,  8.19639489e-02,
         8.44380781e-02],
       [ 5.15967757e-02, -7.42607191e-03, -5.11791557e-03,
         6.67313561e-02,  5.96174449e-02, -7.76399858e-03,
         1.19835325e-01,  2.59960741e-02,  3.64723615e-02,
         2.43333012e-01],
       [-1.81520402e-01, -2.21860819e-02, -9.37249511e-04,
        -4.36494425e-02,  1.48944080e-01,  1.20623425e-01,
         9.27821100e-02,  8.10635090e-03,  1.05029367e-01,
         5.49212396e-02],
       [-1.62024647e-01, -7.55626708e-03,  8.90679806e-02,
         1.09557875e-01,  4.00451683e-02,  5.88794537e-02,
        -5.02957255e-02, -9.45830047e-02,  6.43425062e-02,
         6.01378679e-02],
       [ 2.44281888e-02, -2.44084001e-03,  7.41914511e-02,
         1.29617304e-01,  7.07155839e-03,  6.66829422e-02,
         2.52056420e-02, -9.07487422e-02,  5.63489906e-02,
         1.43779904e-01],
       [ 4.61201221e-02, -6.86585605e-02, -7.71630462e-03,
        -2.53528468e-02,  1.99609250e-02,  9.59918946e-02,
         1.04020424e-02, -5.57698309e-02,  2.24557221e-02,
         1.73324734e-01],
       [-8.96764472e-02, -2.28214562e-02,  1.30368788e-02,
         6.01692796e-02,  1.63236298e-02,  9.95141268e-02,
        -1.29274517e-01, -2.46923715e-02,  3.26206349e-02,
         8.57910439e-02]], dtype=float32)>}
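To see where a key such as "serving_default" comes from, here is a minimal sketch with a hypothetical toy tf.Module that registers its own signature explicitly when saving. As described above, tf.saved_model.load returns a generic object rather than a Keras model, and the signature function returns a dictionary keyed by output name, like the {'dense_3': ...} output shown above:

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module: multiplies its input by a variable."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        # Signature functions return a dict; the key becomes the output name.
        return {"doubled": x * self.factor}


export_dir = tempfile.mkdtemp()
module = Doubler()
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.__call__})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))  # ['serving_default']
print(hasattr(loaded, "fit"))          # False: not a Keras model

infer = loaded.signatures["serving_default"]
# Signature inputs are named after the traced function's arguments (here `x`).
outputs = infer(x=tf.constant([1.0, 2.5]))
print(outputs["doubled"].numpy())      # [2. 5.]
```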

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
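The warnings above recommend wrapping strategy.run in a tf.function. A minimal sketch of that pattern, assuming a hypothetical infer_step that stands in for the loaded inference_func and a small synthetic dataset; strategy.experimental_local_results unpacks the per-replica outputs so they can be concatenated on the host:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # uses the CPU if no GPU is visible


def infer_step(batch):
    # Hypothetical stand-in for inference_func; runs once per replica.
    return batch * 2.0


@tf.function
def distributed_infer(dist_batch):
    # Tracing strategy.run inside tf.function avoids the eager overhead.
    return strategy.run(infer_step, args=(dist_batch,))


dataset = tf.data.Dataset.range(8).map(lambda v: tf.cast(v, tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

predictions = []
for dist_batch in dist_dataset:
    per_replica = distributed_infer(dist_batch)
    # One tensor per local replica; stitch them back along the batch axis.
    local = strategy.experimental_local_results(per_replica)
    predictions.append(tf.concat(local, axis=0))

print(tf.concat(predictions, axis=0).numpy())
```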

Calling the restored function is just a forward pass on the saved model (like predict). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.1955 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9815

As you can see, hub.KerasLayer wraps the object returned by tf.saved_model.load() in a Keras layer that can be used to build another model. This is very useful for transfer learning.
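Continuing to train a loaded object is possible because the functions restored by tf.saved_model.load can be differentiated with respect to the restored variables. If you prefer not to depend on TF Hub, that idea can be sketched directly with tf.GradientTape. This is an illustrative alternative to hub.KerasLayer, not a description of its internals; the toy module, the target function y = 3x, and the learning rate are all assumptions:

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module with one trainable variable."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.scale


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scaler(), export_dir)
loaded = tf.saved_model.load(export_dir)

# Fine-tune the restored variable so that loaded(x) approximates 3 * x.
x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x
learning_rate = 0.05
for _ in range(200):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(loaded(x) - y))
    # Gradients flow through the restored function to the restored variable.
    grad = tape.gradient(loss, loaded.scale)
    loaded.scale.assign_sub(learning_rate * grad)

print(float(loaded.scale.numpy()))  # approximately 3.0
```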

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/loading from a local device

When saving to and loading from a local I/O device while running remotely (for example, when using a Cloud TPU), you must use the experimental_io_device option to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
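The same option is available on the lower-level API: tf.saved_model.SaveOptions and tf.saved_model.LoadOptions can be passed to tf.saved_model.save and tf.saved_model.load. A minimal sketch, assuming a hypothetical toy tf.Module and a temporary directory on the local machine:

```python
import tempfile

import tensorflow as tf


class Bias(tf.Module):
    """Hypothetical toy module that adds a learned offset."""

    def __init__(self):
        super().__init__()
        self.b = tf.Variable(0.5)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x + self.b


export_dir = tempfile.mkdtemp()

# Route file I/O through the local host, as in the Keras example above.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Bias(), export_dir, options=save_options)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
    restored = tf.saved_model.load(export_dir, options=load_options)

print(restored(tf.constant([1.0, 2.0])).numpy())  # [1.5 2.5]
```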

Caveats

One special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shape (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f07bc363410>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f076c5f2750>, because it is not built.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets