
Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial shows how to use the SavedModel APIs with tf.distribute.Strategy. To learn about SavedModel and serialization in general, read the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
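
The global batch size in get_data is the per-replica batch size multiplied by the number of replicas in sync. As a rough illustration of that arithmetic, the plain-Python sketch below (no TensorFlow needed) computes the global batch size and the resulting steps per epoch. The helper names global_batch_size and steps_per_epoch are hypothetical, and the figure of 60,000 examples assumes the standard MNIST training split.

```python
import math

def global_batch_size(batch_size_per_replica, num_replicas_in_sync):
  # Each replica processes its own slice, so the global batch grows
  # linearly with the number of replicas.
  return batch_size_per_replica * num_replicas_in_sync

def steps_per_epoch(num_examples, batch_size):
  # The last batch may be partial, so round up.
  return math.ceil(num_examples / batch_size)

MNIST_TRAIN_EXAMPLES = 60000  # assumption: standard MNIST training split

one_replica_steps = steps_per_epoch(MNIST_TRAIN_EXAMPLES, global_batch_size(64, 1))
two_replica_steps = steps_per_epoch(MNIST_TRAIN_EXAMPLES, global_batch_size(64, 2))
print(one_replica_steps, two_replica_steps)
```

With the single replica reported in the log, the global batch size stays at 64, which corresponds to the 938 steps per epoch that Keras reports during training.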

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
938/938 [==============================] - 10s 4ms/step - loss: 0.2033 - sparse_categorical_accuracy: 0.9408
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9812
<tensorflow.python.keras.callbacks.History at 0x7f07cc109790>

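Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. Two sets of APIs are available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level tf.saved_model APIs (tf.saved_model.save and tf.saved_model.load).

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs: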

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9896
<tensorflow.python.keras.callbacks.History at 0x7f08e8347e90>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel format. For more information, refer to the guide to the saved_model format.

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9856
Epoch 2/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.87175977e-02, -1.61857940e-02,  4.99733090e-02,
         1.36069462e-01,  1.08132266e-01,  7.42500275e-02,
         7.73677081e-02, -4.71051671e-02,  6.41952083e-02,
         1.25855744e-01],
       [-6.54747933e-02,  2.14797929e-01, -8.99797529e-02,
        -7.83173367e-04,  8.92944187e-02,  1.73829526e-01,
         1.60048857e-01,  6.56097680e-02,  9.97588038e-03,
         1.09826915e-01],
       [-1.54433712e-01,  8.42373446e-03,  5.25225215e-02,
         4.81541678e-02,  2.23428048e-02,  7.69063532e-02,
        -3.99671011e-02, -4.56253998e-02,  5.28362878e-02,
         1.42020926e-01],
       [-9.43576694e-02, -9.26742554e-02,  7.42801651e-02,
         7.15272427e-02,  4.92905118e-02,  6.36474639e-02,
        -4.97807935e-03, -8.52353871e-02,  4.42789420e-02,
         1.56496197e-01],
       [-2.41690576e-02, -1.11253485e-01,  2.60174796e-02,
        -2.53655966e-02, -3.06888595e-02,  7.04614669e-02,
        -2.01677606e-02, -8.55238289e-02,  4.99680266e-02,
         1.76511839e-01],
       [-4.30272967e-02, -2.32194141e-02, -4.67026755e-02,
         5.49829267e-02,  3.00423354e-02,  1.36088327e-01,
         2.45837178e-02, -3.89227420e-02,  4.13372666e-02,
         7.04330802e-02],
       [-7.81589746e-03,  8.88681635e-02,  1.90377012e-02,
         2.47566961e-02,  6.73517808e-02,  9.59524438e-02,
         9.72996876e-02, -6.45323768e-02, -5.78478463e-02,
         1.67243421e-01],
       [ 2.59691626e-02,  2.44912747e-02,  2.06067227e-03,
         1.70634389e-02,  2.18513757e-02,  1.37382165e-01,
         2.51548998e-02,  9.26546752e-03,  9.15516615e-02,
         1.02677807e-01],
       [-1.27567410e-01,  2.68350840e-02,  1.16995871e-02,
         9.84568521e-02,  5.41514680e-02,  1.51651859e-01,
         4.64795753e-02, -2.91735604e-02,  1.25106841e-01,
         1.71144456e-01],
       [-8.05673450e-02,  2.94806808e-02, -1.10152550e-01,
        -7.07100704e-02, -1.24747120e-02,  1.40526205e-01,
        -1.23536028e-02, -5.82108870e-02,  7.77795464e-02,
        -2.34171711e-02],
       [ 6.20728806e-02,  2.07295828e-02, -6.43237084e-02,
        -2.05786452e-02,  4.67800628e-03,  8.61115605e-02,
         4.20134440e-02, -4.34638187e-02,  3.63051891e-03,
         7.10046589e-02],
       [-6.08086586e-04,  3.42310257e-02, -8.34690034e-02,
         9.08722728e-03,  1.05238423e-01,  1.47501156e-01,
         1.42161340e-01,  6.30808845e-02, -5.24729490e-04,
         1.36735335e-01],
       [-1.91496432e-01, -3.44831869e-03, -1.86069421e-02,
        -5.05818650e-02,  6.41704351e-03,  4.03927974e-02,
         2.21725628e-02, -1.62169725e-01,  1.76057145e-02,
         1.85961321e-01],
       [-2.49194503e-02, -2.50838101e-02, -3.03340796e-02,
         2.27186158e-02, -3.09922658e-02,  1.24549076e-01,
        -2.09298395e-02, -3.32995206e-02,  5.32348230e-02,
         1.50542200e-01],
       [-4.28871512e-02, -9.22830626e-02,  2.90450491e-02,
         1.12102680e-01,  3.04660778e-02,  7.70893693e-02,
         6.90545812e-02, -3.36272307e-02,  5.62672317e-02,
         2.48093754e-01],
       [-1.07569546e-02, -7.48349279e-02,  2.02673525e-02,
         8.59332830e-02,  1.17068969e-01,  7.48098046e-02,
         6.08244948e-02,  1.01280659e-02,  1.91551913e-02,
         2.11404741e-01],
       [ 7.62790143e-02,  2.29362398e-03, -5.91743030e-02,
         1.31748244e-02,  3.23122516e-02,  7.41088167e-02,
        -6.43679202e-02, -4.93042730e-03,  7.44977593e-02,
         1.33817106e-01],
       [-1.31239325e-01, -1.13535374e-02, -4.06663343e-02,
        -1.28620975e-02, -5.99850900e-04,  1.77248433e-01,
        -1.11874640e-02, -1.12869248e-01,  2.11616959e-02,
         1.17670909e-01],
       [-7.98342153e-02,  1.17193013e-02, -5.84214628e-02,
         4.45103496e-02, -2.35349871e-03,  9.78477970e-02,
        -1.31990746e-01, -9.62637737e-02, -4.68819216e-03,
         4.20698449e-02],
       [ 2.33993679e-02, -6.48612902e-02,  3.57650109e-02,
         4.70414013e-02, -9.87344980e-03,  8.05342346e-02,
         8.79075974e-02, -3.31879668e-02,  4.50132787e-03,
         1.80063233e-01],
       [-4.68377843e-02, -4.20477986e-03, -6.34167045e-02,
         1.02471206e-02,  7.62451440e-02,  4.69596758e-02,
         1.50972791e-02,  2.54708696e-02,  6.22914918e-02,
        -7.17046950e-03],
       [-7.52374083e-02, -5.75587116e-02,  5.98922297e-02,
        -2.73536984e-02, -9.00403410e-03,  1.24273732e-01,
         3.02318707e-02, -8.23312476e-02,  4.71908152e-02,
         6.14355356e-02],
       [-1.15195602e-01, -1.64229050e-02,  1.29237305e-03,
        -1.09596789e-01,  3.80243734e-02,  1.29902706e-01,
         2.34013163e-02, -1.05847850e-01,  5.61432168e-02,
         6.71504661e-02],
       [-6.54574633e-02,  1.53685883e-01, -9.75284129e-02,
         7.12753274e-03,  8.82450342e-02,  1.34998679e-01,
         1.54644206e-01,  6.79479688e-02, -1.72668919e-02,
        -9.18642431e-03],
       [-2.10184306e-01,  1.36045277e-01, -8.87301117e-02,
        -1.74187630e-01,  4.86502573e-02,  2.76032418e-01,
         6.05597720e-02, -7.86263272e-02,  1.03168473e-01,
        -1.37950957e-01],
       [-1.14111096e-01,  2.64648199e-02,  3.91238183e-02,
         1.27296187e-02,  1.35679528e-01,  1.77840948e-01,
        -1.35071129e-02,  2.69818418e-02,  3.42153087e-02,
         8.82407650e-02],
       [-3.01946551e-02, -2.40282975e-02, -6.38710558e-02,
        -3.98451164e-02, -2.92418208e-02,  1.03236027e-01,
        -6.42294213e-02, -6.60998672e-02,  8.01952481e-02,
         4.51336950e-02],
       [-7.84322768e-02,  3.11397053e-02, -5.89536875e-02,
         5.54301403e-03,  4.19510715e-02,  1.42857507e-01,
         3.19332965e-02,  8.02233815e-03,  4.79588211e-02,
         3.57780866e-02],
       [ 4.59235013e-02, -5.12123033e-02,  7.95385391e-02,
        -2.55374219e-02,  4.68068533e-02,  4.43704054e-02,
         7.18893707e-02,  1.64316930e-02,  1.46261796e-01,
         1.57745421e-01],
       [-7.41346031e-02, -2.22236216e-02,  5.15953489e-02,
         3.52352336e-02,  3.60111147e-02,  1.15427203e-01,
         2.57897153e-02, -1.20676845e-01,  5.01920506e-02,
         5.08756042e-02],
       [-8.63073617e-02,  3.17134708e-03,  7.40556046e-03,
        -4.21377048e-02, -1.40033364e-02,  1.12993240e-01,
         3.54410820e-02, -6.92673773e-03, -1.53519716e-02,
         1.26333967e-01],
       [-1.94949329e-01,  3.92537005e-02, -3.54437791e-02,
         2.31000036e-03,  3.79432067e-02,  1.76356524e-01,
        -2.87505463e-02, -1.43955618e-01,  5.22137731e-02,
         1.74010634e-01],
       [-7.06951618e-02,  4.39597517e-02, -6.54652417e-02,
         8.82245004e-02,  2.06859671e-02,  2.03700036e-01,
         4.97581288e-02, -9.99335721e-02,  1.11235633e-01,
         1.77167300e-02],
       [-3.53763551e-02,  1.12260692e-02, -1.11665837e-02,
         2.38924790e-02, -2.76037231e-02,  5.51242232e-02,
         1.64862610e-02, -7.48966262e-02,  8.12724680e-02,
         1.43957604e-03],
       [-6.88709617e-02,  8.31975043e-02, -1.29381031e-01,
         3.24329957e-02,  7.93049186e-02,  1.07140720e-01,
        -1.81627274e-03, -6.15758151e-02, -4.92713787e-03,
         1.36734053e-01],
       [-1.40681088e-01, -3.15807723e-02,  2.60454044e-02,
        -4.27230299e-02,  9.93238837e-02,  8.53562057e-02,
        -5.72501905e-02, -4.54951562e-02,  5.40010631e-02,
         1.60534859e-01],
       [-9.79049355e-02,  6.72584400e-03,  8.45115632e-03,
         6.96118921e-04,  6.83328062e-02,  7.55404532e-02,
         4.08902764e-02, -4.68667373e-02, -1.57676004e-02,
         1.63363129e-01],
       [ 8.52064192e-02, -4.55360785e-02, -1.45583451e-02,
         3.04546617e-02,  8.79652798e-04,  4.33428399e-02,
        -1.49330217e-02, -8.25205147e-02,  6.79182187e-02,
         1.38463050e-01],
       [-6.81671500e-03,  2.08939444e-02, -4.36149202e-02,
         4.71097752e-02, -4.12514098e-02,  7.67979622e-02,
        -3.79576795e-02, -5.68123832e-02,  8.52056891e-02,
         8.15685652e-03],
       [-3.65775973e-02, -3.06402445e-02,  6.65950254e-02,
         4.47403751e-02,  1.30895078e-01,  1.50094643e-01,
        -7.96291754e-02, -2.53046192e-02,  1.07343026e-01,
         1.23134412e-01],
       [-7.62542933e-02,  8.05391073e-02, -4.92790192e-02,
         9.18199718e-02, -2.56493650e-02,  1.11088946e-01,
         1.23991318e-01, -7.41875842e-02,  1.47743657e-01,
         1.12662911e-01],
       [-1.30924150e-01, -4.17556763e-02,  9.19794068e-02,
         1.42487824e-01,  5.94666004e-02,  8.11570883e-02,
         8.41576308e-02, -9.05135348e-02,  6.55059814e-02,
         1.40432447e-01],
       [-1.17066503e-02,  2.23872401e-02,  8.04524869e-02,
         4.16423976e-02,  2.44700648e-02,  1.15135796e-01,
        -4.52195890e-02, -3.85435522e-02,  5.72163910e-02,
        -1.70241687e-02],
       [-1.59098431e-01, -2.48743258e-02,  1.15704350e-03,
        -1.94135085e-02,  7.07440227e-02,  1.02726325e-01,
         7.48141706e-02, -4.92161885e-02, -4.82953712e-03,
         7.76504800e-02],
       [-5.23527563e-02, -6.95675313e-02,  6.23273998e-02,
         2.89142895e-02,  5.82050942e-02,  3.41961756e-02,
         5.93537465e-02, -4.88524139e-03,  6.03169389e-02,
         1.83362126e-01],
       [-1.74306586e-01,  8.50927830e-03,  5.59768602e-02,
         6.93600252e-03,  1.01455852e-01,  1.91212326e-01,
         1.01702012e-01, -4.06461619e-02,  7.65661225e-02,
         5.99906780e-02],
       [-1.77252740e-01, -8.87014568e-02,  2.68679634e-02,
        -9.68291797e-03,  3.35638374e-02,  5.05711734e-02,
        -8.12163353e-02, -1.41850561e-01,  1.26373529e-01,
         1.44208223e-01],
       [ 8.87322426e-03,  4.11978438e-02, -1.41141713e-02,
         4.23451513e-02,  1.00925714e-01,  1.82571277e-01,
         1.03720605e-01,  7.12942332e-03,  8.42842646e-03,
         1.58599243e-01],
       [-2.13406682e-02, -1.75910015e-02, -5.12900352e-02,
         2.48414166e-02,  1.27789266e-02,  1.57729238e-01,
        -9.69544426e-03, -8.99763852e-02,  1.03400812e-01,
         1.43311873e-01],
       [-8.07911009e-02, -7.29364436e-03,  8.38935375e-03,
        -2.07424574e-02,  5.76308332e-02,  1.60505801e-01,
         1.78787038e-02, -3.16766389e-02,  9.31937695e-02,
        -1.85308605e-02],
       [-9.09739435e-02,  9.12447274e-02, -2.94412468e-02,
        -4.54112887e-05,  5.61082549e-02,  1.87468201e-01,
         1.19883358e-01, -1.82710923e-02,  6.45193905e-02,
         2.49637775e-02],
       [-1.29598469e-01, -3.25308368e-03, -4.12912332e-02,
        -5.08361273e-02, -1.04823690e-02,  1.82079896e-02,
        -6.16563670e-02, -7.23700672e-02,  1.43699115e-02,
         1.14841349e-01],
       [-6.54844195e-02,  1.32181123e-02, -1.93353333e-02,
         3.21756229e-02,  1.93038825e-02,  6.58847988e-02,
        -8.00077915e-02, -8.77422094e-02,  6.08410016e-02,
         7.79580325e-02],
       [-1.41706705e-01, -2.03647800e-02, -7.27154315e-03,
         6.24460652e-02,  4.42662910e-02,  3.27514187e-02,
        -1.90984663e-02, -2.80730426e-02,  6.94331005e-02,
         1.08589470e-01],
       [-1.47680700e-01,  1.48287900e-02,  4.51873392e-02,
         9.84134525e-03, -3.12292501e-02,  7.53606334e-02,
        -3.28539237e-02, -8.99436623e-02,  9.34160203e-02,
         1.19600557e-01],
       [-1.06706396e-01, -7.91918635e-02, -3.81190814e-02,
         7.73421675e-03,  5.89197949e-02,  1.48285478e-01,
         7.33993948e-04, -7.37952963e-02,  7.03667626e-02,
         1.29369870e-02],
       [-1.18764386e-01,  6.67049065e-02, -1.84375793e-03,
         9.65175256e-02,  2.87358947e-02,  1.37036711e-01,
        -1.75398681e-02, -3.47253568e-02, -1.17472522e-02,
         1.64805949e-01],
       [-2.96032578e-02,  7.04797730e-02, -5.30908853e-02,
         3.15117575e-02,  1.10758804e-02,  1.52998209e-01,
        -4.00629006e-02, -8.53683874e-02,  8.19639489e-02,
         8.44380781e-02],
       [ 5.15967757e-02, -7.42607191e-03, -5.11791557e-03,
         6.67313561e-02,  5.96174449e-02, -7.76399858e-03,
         1.19835325e-01,  2.59960741e-02,  3.64723615e-02,
         2.43333012e-01],
       [-1.81520402e-01, -2.21860819e-02, -9.37249511e-04,
        -4.36494425e-02,  1.48944080e-01,  1.20623425e-01,
         9.27821100e-02,  8.10635090e-03,  1.05029367e-01,
         5.49212396e-02],
       [-1.62024647e-01, -7.55626708e-03,  8.90679806e-02,
         1.09557875e-01,  4.00451683e-02,  5.88794537e-02,
        -5.02957255e-02, -9.45830047e-02,  6.43425062e-02,
         6.01378679e-02],
       [ 2.44281888e-02, -2.44084001e-03,  7.41914511e-02,
         1.29617304e-01,  7.07155839e-03,  6.66829422e-02,
         2.52056420e-02, -9.07487422e-02,  5.63489906e-02,
         1.43779904e-01],
       [ 4.61201221e-02, -6.86585605e-02, -7.71630462e-03,
        -2.53528468e-02,  1.99609250e-02,  9.59918946e-02,
         1.04020424e-02, -5.57698309e-02,  2.24557221e-02,
         1.73324734e-01],
       [-8.96764472e-02, -2.28214562e-02,  1.30368788e-02,
         6.01692796e-02,  1.63236298e-02,  9.95141268e-02,
        -1.29274517e-01, -2.46923715e-02,  3.26206349e-02,
         8.57910439e-02]], dtype=float32)>}
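
To see which keys a loaded object exposes, you can list loaded.signatures before picking one. The sketch below is a minimal illustration that does not use the Keras model from this tutorial: it saves a tiny tf.Module (the Doubler class, the "doubled" output key, and the temporary export directory are all hypothetical) with an explicit "serving_default" signature, then loads it back and calls the signature function with a keyword argument.

```python
import tempfile

import tensorflow as tf

class Doubler(tf.Module):
  """Hypothetical stand-in for a saved model: doubles its input."""

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def __call__(self, x):
    # Returning a dict gives the signature a named output.
    return {"doubled": 2.0 * x}

export_dir = tempfile.mkdtemp()  # assumption: any writable directory works
module = Doubler()
tf.saved_model.save(
    module, export_dir, signatures={"serving_default": module.__call__})

loaded = tf.saved_model.load(export_dir)
signature_keys = list(loaded.signatures.keys())
inference_func = loaded.signatures["serving_default"]

# Signature functions return a dict mapping output names to tensors.
result = inference_func(x=tf.constant([1.0, 2.0]))
print(signature_keys, result["doubled"].numpy())
```

As with the Keras model above, the result is a dictionary keyed by output name rather than a bare tensor.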

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
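
The warning in the log suggests wrapping run inside a tf.function to avoid the overhead of eager execution. The following is one possible sketch of that pattern, under stated assumptions: inference_fn is a hypothetical stand-in for the loaded signature function, and the dataset is a small tensor of ones rather than MNIST.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def inference_fn(batch):
  # Hypothetical stand-in for the loaded inference function.
  return tf.reduce_sum(batch, axis=-1)

@tf.function
def distributed_step(batch):
  # Tracing the call into a graph avoids the eager-mode overhead
  # mentioned in the warning.
  return strategy.run(inference_fn, args=(batch,))

dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

total = 0.0
for batch in dist_dataset:
  per_replica = distributed_step(batch)
  # Collect the per-replica outputs as a tuple of local tensors.
  for replica_output in strategy.experimental_local_results(per_replica):
    total += float(tf.reduce_sum(replica_output))
print(total)
```

The per-replica results are gathered outside the traced function with experimental_local_results, so the same loop works regardless of how many devices the strategy uses.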

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.1955 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9815

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use Keras's model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving and loading from a local I/O device while running remotely, for example when using a Cloud TPU, you must use the experimental_io_device option to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
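
The same options can also be passed to the lower-level tf.saved_model.save and tf.saved_model.load calls. The sketch below illustrates this with a small tf.Module instead of the Keras model; the Scale class and the temporary export directory are hypothetical, and the code runs on a single local machine, where '/job:localhost' is simply the local host.

```python
import tempfile

import tensorflow as tf

class Scale(tf.Module):
  """Hypothetical module holding one variable, used only for illustration."""

  def __init__(self):
    super().__init__()
    self.factor = tf.Variable(3.0)

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def __call__(self, x):
    return self.factor * x

export_dir = tempfile.mkdtemp()  # assumption: a path on the local host

# Pin checkpoint I/O to the local host when saving.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
tf.saved_model.save(Scale(), export_dir, options=save_options)

# Use the matching option when loading.
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
restored = tf.saved_model.load(export_dir, options=load_options)

restored_factor = float(restored.factor.numpy())
scaled = restored(tf.constant([1.0, 2.0])).numpy().tolist()
print(restored_factor, scaled)
```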

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading, otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all good. Input shapes are automatically defined when the model is used in .fit, .evaluate, .predict, or when calling the model (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f07bc363410>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f076c5f2750>, because it is not built.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets