Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 10s 4ms/step - loss: 0.2033 - sparse_categorical_accuracy: 0.9408
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9812
<tensorflow.python.keras.callbacks.History at 0x7f07cc109790>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras API (model.save and tf.keras.models.load_model) and the low-level API (tf.saved_model.save and tf.saved_model.load).

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9896
<tensorflow.python.keras.callbacks.History at 0x7f08e8347e90>

After restoring the model, you can continue training it, without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel format. For more information, refer to the SavedModel format guide.
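
As a minimal sketch of this, you can evaluate the restored model right away; it is already compiled, so no compile() call is needed first:

# The restored model is already compiled, so evaluating (or continuing to
# train) works without calling compile() again.
eval_loss, eval_acc = restored_keras_model.evaluate(eval_dataset)
print(eval_loss, eval_acc)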

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9856
Epoch 2/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
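
As a quick sanity check (a sketch, not part of the original notebook), a freshly loaded copy reproduces the predictions of the model that was saved, regardless of the strategy used at load time:

import numpy as np

# Load a fresh copy and compare its predictions against the original model.
fresh_copy = tf.keras.models.load_model(keras_model_path)
sample_images = next(iter(eval_dataset))[0]
np.testing.assert_allclose(model.predict(sample_images),
                           fresh_copy.predict(sample_images),
                           rtol=1e-5, atol=1e-5)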

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
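
As a quick sketch, you can inspect what the loaded object exposes before calling it:

# Inspect the available signatures and the inference function's structure.
print(list(loaded.signatures.keys()))
print(inference_func.structured_input_signature)
print(inference_func.structured_outputs)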

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.87175977e-02, -1.61857940e-02,  4.99733090e-02,
         1.36069462e-01,  1.08132266e-01,  7.42500275e-02,
         7.73677081e-02, -4.71051671e-02,  6.41952083e-02,
         1.25855744e-01],
       [-6.54747933e-02,  2.14797929e-01, -8.99797529e-02,
        -7.83173367e-04,  8.92944187e-02,  1.73829526e-01,
         1.60048857e-01,  6.56097680e-02,  9.97588038e-03,
         1.09826915e-01],
       [-1.54433712e-01,  8.42373446e-03,  5.25225215e-02,
         4.81541678e-02,  2.23428048e-02,  7.69063532e-02,
        -3.99671011e-02, -4.56253998e-02,  5.28362878e-02,
         1.42020926e-01],
       [-9.43576694e-02, -9.26742554e-02,  7.42801651e-02,
         7.15272427e-02,  4.92905118e-02,  6.36474639e-02,
        -4.97807935e-03, -8.52353871e-02,  4.42789420e-02,
         1.56496197e-01],
       [-2.41690576e-02, -1.11253485e-01,  2.60174796e-02,
        -2.53655966e-02, -3.06888595e-02,  7.04614669e-02,
        -2.01677606e-02, -8.55238289e-02,  4.99680266e-02,
         1.76511839e-01],
       [-4.30272967e-02, -2.32194141e-02, -4.67026755e-02,
         5.49829267e-02,  3.00423354e-02,  1.36088327e-01,
         2.45837178e-02, -3.89227420e-02,  4.13372666e-02,
         7.04330802e-02],
       [-7.81589746e-03,  8.88681635e-02,  1.90377012e-02,
         2.47566961e-02,  6.73517808e-02,  9.59524438e-02,
         9.72996876e-02, -6.45323768e-02, -5.78478463e-02,
         1.67243421e-01],
       [ 2.59691626e-02,  2.44912747e-02,  2.06067227e-03,
         1.70634389e-02,  2.18513757e-02,  1.37382165e-01,
         2.51548998e-02,  9.26546752e-03,  9.15516615e-02,
         1.02677807e-01],
       [-1.27567410e-01,  2.68350840e-02,  1.16995871e-02,
         9.84568521e-02,  5.41514680e-02,  1.51651859e-01,
         4.64795753e-02, -2.91735604e-02,  1.25106841e-01,
         1.71144456e-01],
       [-8.05673450e-02,  2.94806808e-02, -1.10152550e-01,
        -7.07100704e-02, -1.24747120e-02,  1.40526205e-01,
        -1.23536028e-02, -5.82108870e-02,  7.77795464e-02,
        -2.34171711e-02],
       [ 6.20728806e-02,  2.07295828e-02, -6.43237084e-02,
        -2.05786452e-02,  4.67800628e-03,  8.61115605e-02,
         4.20134440e-02, -4.34638187e-02,  3.63051891e-03,
         7.10046589e-02],
       [-6.08086586e-04,  3.42310257e-02, -8.34690034e-02,
         9.08722728e-03,  1.05238423e-01,  1.47501156e-01,
         1.42161340e-01,  6.30808845e-02, -5.24729490e-04,
         1.36735335e-01],
       [-1.91496432e-01, -3.44831869e-03, -1.86069421e-02,
        -5.05818650e-02,  6.41704351e-03,  4.03927974e-02,
         2.21725628e-02, -1.62169725e-01,  1.76057145e-02,
         1.85961321e-01],
       [-2.49194503e-02, -2.50838101e-02, -3.03340796e-02,
         2.27186158e-02, -3.09922658e-02,  1.24549076e-01,
        -2.09298395e-02, -3.32995206e-02,  5.32348230e-02,
         1.50542200e-01],
       [-4.28871512e-02, -9.22830626e-02,  2.90450491e-02,
         1.12102680e-01,  3.04660778e-02,  7.70893693e-02,
         6.90545812e-02, -3.36272307e-02,  5.62672317e-02,
         2.48093754e-01],
       [-1.07569546e-02, -7.48349279e-02,  2.02673525e-02,
         8.59332830e-02,  1.17068969e-01,  7.48098046e-02,
         6.08244948e-02,  1.01280659e-02,  1.91551913e-02,
         2.11404741e-01],
       [ 7.62790143e-02,  2.29362398e-03, -5.91743030e-02,
         1.31748244e-02,  3.23122516e-02,  7.41088167e-02,
        -6.43679202e-02, -4.93042730e-03,  7.44977593e-02,
         1.33817106e-01],
       [-1.31239325e-01, -1.13535374e-02, -4.06663343e-02,
        -1.28620975e-02, -5.99850900e-04,  1.77248433e-01,
        -1.11874640e-02, -1.12869248e-01,  2.11616959e-02,
         1.17670909e-01],
       [-7.98342153e-02,  1.17193013e-02, -5.84214628e-02,
         4.45103496e-02, -2.35349871e-03,  9.78477970e-02,
        -1.31990746e-01, -9.62637737e-02, -4.68819216e-03,
         4.20698449e-02],
       [ 2.33993679e-02, -6.48612902e-02,  3.57650109e-02,
         4.70414013e-02, -9.87344980e-03,  8.05342346e-02,
         8.79075974e-02, -3.31879668e-02,  4.50132787e-03,
         1.80063233e-01],
       [-4.68377843e-02, -4.20477986e-03, -6.34167045e-02,
         1.02471206e-02,  7.62451440e-02,  4.69596758e-02,
         1.50972791e-02,  2.54708696e-02,  6.22914918e-02,
        -7.17046950e-03],
       [-7.52374083e-02, -5.75587116e-02,  5.98922297e-02,
        -2.73536984e-02, -9.00403410e-03,  1.24273732e-01,
         3.02318707e-02, -8.23312476e-02,  4.71908152e-02,
         6.14355356e-02],
       [-1.15195602e-01, -1.64229050e-02,  1.29237305e-03,
        -1.09596789e-01,  3.80243734e-02,  1.29902706e-01,
         2.34013163e-02, -1.05847850e-01,  5.61432168e-02,
         6.71504661e-02],
       [-6.54574633e-02,  1.53685883e-01, -9.75284129e-02,
         7.12753274e-03,  8.82450342e-02,  1.34998679e-01,
         1.54644206e-01,  6.79479688e-02, -1.72668919e-02,
        -9.18642431e-03],
       [-2.10184306e-01,  1.36045277e-01, -8.87301117e-02,
        -1.74187630e-01,  4.86502573e-02,  2.76032418e-01,
         6.05597720e-02, -7.86263272e-02,  1.03168473e-01,
        -1.37950957e-01],
       [-1.14111096e-01,  2.64648199e-02,  3.91238183e-02,
         1.27296187e-02,  1.35679528e-01,  1.77840948e-01,
        -1.35071129e-02,  2.69818418e-02,  3.42153087e-02,
         8.82407650e-02],
       [-3.01946551e-02, -2.40282975e-02, -6.38710558e-02,
        -3.98451164e-02, -2.92418208e-02,  1.03236027e-01,
        -6.42294213e-02, -6.60998672e-02,  8.01952481e-02,
         4.51336950e-02],
       [-7.84322768e-02,  3.11397053e-02, -5.89536875e-02,
         5.54301403e-03,  4.19510715e-02,  1.42857507e-01,
         3.19332965e-02,  8.02233815e-03,  4.79588211e-02,
         3.57780866e-02],
       [ 4.59235013e-02, -5.12123033e-02,  7.95385391e-02,
        -2.55374219e-02,  4.68068533e-02,  4.43704054e-02,
         7.18893707e-02,  1.64316930e-02,  1.46261796e-01,
         1.57745421e-01],
       [-7.41346031e-02, -2.22236216e-02,  5.15953489e-02,
         3.52352336e-02,  3.60111147e-02,  1.15427203e-01,
         2.57897153e-02, -1.20676845e-01,  5.01920506e-02,
         5.08756042e-02],
       [-8.63073617e-02,  3.17134708e-03,  7.40556046e-03,
        -4.21377048e-02, -1.40033364e-02,  1.12993240e-01,
         3.54410820e-02, -6.92673773e-03, -1.53519716e-02,
         1.26333967e-01],
       [-1.94949329e-01,  3.92537005e-02, -3.54437791e-02,
         2.31000036e-03,  3.79432067e-02,  1.76356524e-01,
        -2.87505463e-02, -1.43955618e-01,  5.22137731e-02,
         1.74010634e-01],
       [-7.06951618e-02,  4.39597517e-02, -6.54652417e-02,
         8.82245004e-02,  2.06859671e-02,  2.03700036e-01,
         4.97581288e-02, -9.99335721e-02,  1.11235633e-01,
         1.77167300e-02],
       [-3.53763551e-02,  1.12260692e-02, -1.11665837e-02,
         2.38924790e-02, -2.76037231e-02,  5.51242232e-02,
         1.64862610e-02, -7.48966262e-02,  8.12724680e-02,
         1.43957604e-03],
       [-6.88709617e-02,  8.31975043e-02, -1.29381031e-01,
         3.24329957e-02,  7.93049186e-02,  1.07140720e-01,
        -1.81627274e-03, -6.15758151e-02, -4.92713787e-03,
         1.36734053e-01],
       [-1.40681088e-01, -3.15807723e-02,  2.60454044e-02,
        -4.27230299e-02,  9.93238837e-02,  8.53562057e-02,
        -5.72501905e-02, -4.54951562e-02,  5.40010631e-02,
         1.60534859e-01],
       [-9.79049355e-02,  6.72584400e-03,  8.45115632e-03,
         6.96118921e-04,  6.83328062e-02,  7.55404532e-02,
         4.08902764e-02, -4.68667373e-02, -1.57676004e-02,
         1.63363129e-01],
       [ 8.52064192e-02, -4.55360785e-02, -1.45583451e-02,
         3.04546617e-02,  8.79652798e-04,  4.33428399e-02,
        -1.49330217e-02, -8.25205147e-02,  6.79182187e-02,
         1.38463050e-01],
       [-6.81671500e-03,  2.08939444e-02, -4.36149202e-02,
         4.71097752e-02, -4.12514098e-02,  7.67979622e-02,
        -3.79576795e-02, -5.68123832e-02,  8.52056891e-02,
         8.15685652e-03],
       [-3.65775973e-02, -3.06402445e-02,  6.65950254e-02,
         4.47403751e-02,  1.30895078e-01,  1.50094643e-01,
        -7.96291754e-02, -2.53046192e-02,  1.07343026e-01,
         1.23134412e-01],
       [-7.62542933e-02,  8.05391073e-02, -4.92790192e-02,
         9.18199718e-02, -2.56493650e-02,  1.11088946e-01,
         1.23991318e-01, -7.41875842e-02,  1.47743657e-01,
         1.12662911e-01],
       [-1.30924150e-01, -4.17556763e-02,  9.19794068e-02,
         1.42487824e-01,  5.94666004e-02,  8.11570883e-02,
         8.41576308e-02, -9.05135348e-02,  6.55059814e-02,
         1.40432447e-01],
       [-1.17066503e-02,  2.23872401e-02,  8.04524869e-02,
         4.16423976e-02,  2.44700648e-02,  1.15135796e-01,
        -4.52195890e-02, -3.85435522e-02,  5.72163910e-02,
        -1.70241687e-02],
       [-1.59098431e-01, -2.48743258e-02,  1.15704350e-03,
        -1.94135085e-02,  7.07440227e-02,  1.02726325e-01,
         7.48141706e-02, -4.92161885e-02, -4.82953712e-03,
         7.76504800e-02],
       [-5.23527563e-02, -6.95675313e-02,  6.23273998e-02,
         2.89142895e-02,  5.82050942e-02,  3.41961756e-02,
         5.93537465e-02, -4.88524139e-03,  6.03169389e-02,
         1.83362126e-01],
       [-1.74306586e-01,  8.50927830e-03,  5.59768602e-02,
         6.93600252e-03,  1.01455852e-01,  1.91212326e-01,
         1.01702012e-01, -4.06461619e-02,  7.65661225e-02,
         5.99906780e-02],
       [-1.77252740e-01, -8.87014568e-02,  2.68679634e-02,
        -9.68291797e-03,  3.35638374e-02,  5.05711734e-02,
        -8.12163353e-02, -1.41850561e-01,  1.26373529e-01,
         1.44208223e-01],
       [ 8.87322426e-03,  4.11978438e-02, -1.41141713e-02,
         4.23451513e-02,  1.00925714e-01,  1.82571277e-01,
         1.03720605e-01,  7.12942332e-03,  8.42842646e-03,
         1.58599243e-01],
       [-2.13406682e-02, -1.75910015e-02, -5.12900352e-02,
         2.48414166e-02,  1.27789266e-02,  1.57729238e-01,
        -9.69544426e-03, -8.99763852e-02,  1.03400812e-01,
         1.43311873e-01],
       [-8.07911009e-02, -7.29364436e-03,  8.38935375e-03,
        -2.07424574e-02,  5.76308332e-02,  1.60505801e-01,
         1.78787038e-02, -3.16766389e-02,  9.31937695e-02,
        -1.85308605e-02],
       [-9.09739435e-02,  9.12447274e-02, -2.94412468e-02,
        -4.54112887e-05,  5.61082549e-02,  1.87468201e-01,
         1.19883358e-01, -1.82710923e-02,  6.45193905e-02,
         2.49637775e-02],
       [-1.29598469e-01, -3.25308368e-03, -4.12912332e-02,
        -5.08361273e-02, -1.04823690e-02,  1.82079896e-02,
        -6.16563670e-02, -7.23700672e-02,  1.43699115e-02,
         1.14841349e-01],
       [-6.54844195e-02,  1.32181123e-02, -1.93353333e-02,
         3.21756229e-02,  1.93038825e-02,  6.58847988e-02,
        -8.00077915e-02, -8.77422094e-02,  6.08410016e-02,
         7.79580325e-02],
       [-1.41706705e-01, -2.03647800e-02, -7.27154315e-03,
         6.24460652e-02,  4.42662910e-02,  3.27514187e-02,
        -1.90984663e-02, -2.80730426e-02,  6.94331005e-02,
         1.08589470e-01],
       [-1.47680700e-01,  1.48287900e-02,  4.51873392e-02,
         9.84134525e-03, -3.12292501e-02,  7.53606334e-02,
        -3.28539237e-02, -8.99436623e-02,  9.34160203e-02,
         1.19600557e-01],
       [-1.06706396e-01, -7.91918635e-02, -3.81190814e-02,
         7.73421675e-03,  5.89197949e-02,  1.48285478e-01,
         7.33993948e-04, -7.37952963e-02,  7.03667626e-02,
         1.29369870e-02],
       [-1.18764386e-01,  6.67049065e-02, -1.84375793e-03,
         9.65175256e-02,  2.87358947e-02,  1.37036711e-01,
        -1.75398681e-02, -3.47253568e-02, -1.17472522e-02,
         1.64805949e-01],
       [-2.96032578e-02,  7.04797730e-02, -5.30908853e-02,
         3.15117575e-02,  1.10758804e-02,  1.52998209e-01,
        -4.00629006e-02, -8.53683874e-02,  8.19639489e-02,
         8.44380781e-02],
       [ 5.15967757e-02, -7.42607191e-03, -5.11791557e-03,
         6.67313561e-02,  5.96174449e-02, -7.76399858e-03,
         1.19835325e-01,  2.59960741e-02,  3.64723615e-02,
         2.43333012e-01],
       [-1.81520402e-01, -2.21860819e-02, -9.37249511e-04,
        -4.36494425e-02,  1.48944080e-01,  1.20623425e-01,
         9.27821100e-02,  8.10635090e-03,  1.05029367e-01,
         5.49212396e-02],
       [-1.62024647e-01, -7.55626708e-03,  8.90679806e-02,
         1.09557875e-01,  4.00451683e-02,  5.88794537e-02,
        -5.02957255e-02, -9.45830047e-02,  6.43425062e-02,
         6.01378679e-02],
       [ 2.44281888e-02, -2.44084001e-03,  7.41914511e-02,
         1.29617304e-01,  7.07155839e-03,  6.66829422e-02,
         2.52056420e-02, -9.07487422e-02,  5.63489906e-02,
         1.43779904e-01],
       [ 4.61201221e-02, -6.86585605e-02, -7.71630462e-03,
        -2.53528468e-02,  1.99609250e-02,  9.59918946e-02,
         1.04020424e-02, -5.57698309e-02,  2.24557221e-02,
         1.73324734e-01],
       [-8.96764472e-02, -2.28214562e-02,  1.30368788e-02,
         6.01692796e-02,  1.63236298e-02,  9.95141268e-02,
        -1.29274517e-01, -2.46923715e-02,  3.26206349e-02,
         8.57910439e-02]], dtype=float32)>}

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object into a Keras layer. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.1955 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9815

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
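
If you only want to reuse the loaded weights as a frozen feature extractor (a common transfer-learning setup), a variant of the wrapper above could look like the sketch below; the function name and the extra head are purely illustrative:

def build_transfer_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Freeze the loaded weights and train only a new head on top.
  features = hub.KerasLayer(loaded, trainable=False)(x)
  outputs = tf.keras.layers.Dense(10)(features)
  return tf.keras.Model(x, outputs)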

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.
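
For example, a plain tf.Module (not a Keras model) can only be saved with the lower-level API. A minimal, illustrative sketch (class name and path are made up for this example):

class Adder(tf.Module):
  # A simple non-Keras object with a traced function.
  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  def add_one(self, x):
    return x + 1.0

tf.saved_model.save(Adder(), "/tmp/tf_module_save")  # illustrative path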

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
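
Even though this model was saved with the Keras API, the object loaded with the low-level API still exposes the usual serving signature; a quick sketch to confirm:

# The Keras-saved model exposes the standard serving signature after loading.
inference_func = loaded.signatures["serving_default"]
print(inference_func.structured_outputs)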

Saving/Loading from a local device

When saving to and loading from a local I/O device while running remotely (for example, when using a Cloud TPU), you must use the option experimental_io_device to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading, otherwise you will get an error.

To check whether your model has well-defined inputs, just check if model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, .predict, or when calling the model (model(inputs)).
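
As a quick sketch of this check, the model built earlier with get_model() (which specifies an input_shape) has well-defined inputs, so the Keras saving APIs are safe to use for it:

# The Sequential model above was built with an input_shape, so its inputs
# are defined and this prints False.
print(get_model().inputs is None)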

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f07bc363410>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f076c5f2750>, because it is not built.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
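
Following the note above about input shapes, once the subclassed model has been called on data, its input shapes are defined and the Keras saving API can be used as well. The sketch below is illustrative (the path is made up for this example):

# Calling the model on data defines its input shapes, so the Keras API works.
my_model(tf.ones((1, 28, 28)))
my_model.save("/tmp/keras_save_subclassed")  # illustrative path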