Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
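
The global batch size in get_data() grows with the number of replicas, which in turn changes how many steps make up one epoch. The following sketch reproduces that arithmetic in plain Python. It assumes the standard 60,000-example MNIST training split; the helper name steps_per_epoch is hypothetical and not part of the tutorial.

```python
import math

NUM_TRAIN_EXAMPLES = 60000    # size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64   # same value as in get_data()


def steps_per_epoch(num_replicas):
    """Number of batches per epoch when the last partial batch is kept."""
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    return math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)


for num_replicas in (1, 2, 4, 8):
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    print(num_replicas, global_batch_size, steps_per_epoch(num_replicas))
```

With the single replica reported in the log above, this gives 938 steps, which matches the 938/938 progress bars in the training output of this run.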

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
938/938 [==============================] - 10s 4ms/step - loss: 0.2033 - sparse_categorical_accuracy: 0.9408
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9812
<tensorflow.python.keras.callbacks.History at 0x7f07cc109790>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving and loading APIs. There are two sets of APIs available: the high-level Keras APIs and the low-level tf.saved_model APIs.

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9896
<tensorflow.python.keras.callbacks.History at 0x7f08e8347e90>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide on the saved_model format.

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9856
Epoch 2/2
938/938 [==============================] - 8s 8ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9898

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.87175977e-02, -1.61857940e-02,  4.99733090e-02,
         1.36069462e-01,  1.08132266e-01,  7.42500275e-02,
         7.73677081e-02, -4.71051671e-02,  6.41952083e-02,
         1.25855744e-01],
       [-6.54747933e-02,  2.14797929e-01, -8.99797529e-02,
        -7.83173367e-04,  8.92944187e-02,  1.73829526e-01,
         1.60048857e-01,  6.56097680e-02,  9.97588038e-03,
         1.09826915e-01],
       [-1.54433712e-01,  8.42373446e-03,  5.25225215e-02,
         4.81541678e-02,  2.23428048e-02,  7.69063532e-02,
        -3.99671011e-02, -4.56253998e-02,  5.28362878e-02,
         1.42020926e-01],
       [-9.43576694e-02, -9.26742554e-02,  7.42801651e-02,
         7.15272427e-02,  4.92905118e-02,  6.36474639e-02,
        -4.97807935e-03, -8.52353871e-02,  4.42789420e-02,
         1.56496197e-01],
       [-2.41690576e-02, -1.11253485e-01,  2.60174796e-02,
        -2.53655966e-02, -3.06888595e-02,  7.04614669e-02,
        -2.01677606e-02, -8.55238289e-02,  4.99680266e-02,
         1.76511839e-01],
       [-4.30272967e-02, -2.32194141e-02, -4.67026755e-02,
         5.49829267e-02,  3.00423354e-02,  1.36088327e-01,
         2.45837178e-02, -3.89227420e-02,  4.13372666e-02,
         7.04330802e-02],
       [-7.81589746e-03,  8.88681635e-02,  1.90377012e-02,
         2.47566961e-02,  6.73517808e-02,  9.59524438e-02,
         9.72996876e-02, -6.45323768e-02, -5.78478463e-02,
         1.67243421e-01],
       [ 2.59691626e-02,  2.44912747e-02,  2.06067227e-03,
         1.70634389e-02,  2.18513757e-02,  1.37382165e-01,
         2.51548998e-02,  9.26546752e-03,  9.15516615e-02,
         1.02677807e-01],
       [-1.27567410e-01,  2.68350840e-02,  1.16995871e-02,
         9.84568521e-02,  5.41514680e-02,  1.51651859e-01,
         4.64795753e-02, -2.91735604e-02,  1.25106841e-01,
         1.71144456e-01],
       [-8.05673450e-02,  2.94806808e-02, -1.10152550e-01,
        -7.07100704e-02, -1.24747120e-02,  1.40526205e-01,
        -1.23536028e-02, -5.82108870e-02,  7.77795464e-02,
        -2.34171711e-02],
       [ 6.20728806e-02,  2.07295828e-02, -6.43237084e-02,
        -2.05786452e-02,  4.67800628e-03,  8.61115605e-02,
         4.20134440e-02, -4.34638187e-02,  3.63051891e-03,
         7.10046589e-02],
       [-6.08086586e-04,  3.42310257e-02, -8.34690034e-02,
         9.08722728e-03,  1.05238423e-01,  1.47501156e-01,
         1.42161340e-01,  6.30808845e-02, -5.24729490e-04,
         1.36735335e-01],
       [-1.91496432e-01, -3.44831869e-03, -1.86069421e-02,
        -5.05818650e-02,  6.41704351e-03,  4.03927974e-02,
         2.21725628e-02, -1.62169725e-01,  1.76057145e-02,
         1.85961321e-01],
       [-2.49194503e-02, -2.50838101e-02, -3.03340796e-02,
         2.27186158e-02, -3.09922658e-02,  1.24549076e-01,
        -2.09298395e-02, -3.32995206e-02,  5.32348230e-02,
         1.50542200e-01],
       [-4.28871512e-02, -9.22830626e-02,  2.90450491e-02,
         1.12102680e-01,  3.04660778e-02,  7.70893693e-02,
         6.90545812e-02, -3.36272307e-02,  5.62672317e-02,
         2.48093754e-01],
       [-1.07569546e-02, -7.48349279e-02,  2.02673525e-02,
         8.59332830e-02,  1.17068969e-01,  7.48098046e-02,
         6.08244948e-02,  1.01280659e-02,  1.91551913e-02,
         2.11404741e-01],
       [ 7.62790143e-02,  2.29362398e-03, -5.91743030e-02,
         1.31748244e-02,  3.23122516e-02,  7.41088167e-02,
        -6.43679202e-02, -4.93042730e-03,  7.44977593e-02,
         1.33817106e-01],
       [-1.31239325e-01, -1.13535374e-02, -4.06663343e-02,
        -1.28620975e-02, -5.99850900e-04,  1.77248433e-01,
        -1.11874640e-02, -1.12869248e-01,  2.11616959e-02,
         1.17670909e-01],
       [-7.98342153e-02,  1.17193013e-02, -5.84214628e-02,
         4.45103496e-02, -2.35349871e-03,  9.78477970e-02,
        -1.31990746e-01, -9.62637737e-02, -4.68819216e-03,
         4.20698449e-02],
       [ 2.33993679e-02, -6.48612902e-02,  3.57650109e-02,
         4.70414013e-02, -9.87344980e-03,  8.05342346e-02,
         8.79075974e-02, -3.31879668e-02,  4.50132787e-03,
         1.80063233e-01],
       [-4.68377843e-02, -4.20477986e-03, -6.34167045e-02,
         1.02471206e-02,  7.62451440e-02,  4.69596758e-02,
         1.50972791e-02,  2.54708696e-02,  6.22914918e-02,
        -7.17046950e-03],
       [-7.52374083e-02, -5.75587116e-02,  5.98922297e-02,
        -2.73536984e-02, -9.00403410e-03,  1.24273732e-01,
         3.02318707e-02, -8.23312476e-02,  4.71908152e-02,
         6.14355356e-02],
       [-1.15195602e-01, -1.64229050e-02,  1.29237305e-03,
        -1.09596789e-01,  3.80243734e-02,  1.29902706e-01,
         2.34013163e-02, -1.05847850e-01,  5.61432168e-02,
         6.71504661e-02],
       [-6.54574633e-02,  1.53685883e-01, -9.75284129e-02,
         7.12753274e-03,  8.82450342e-02,  1.34998679e-01,
         1.54644206e-01,  6.79479688e-02, -1.72668919e-02,
        -9.18642431e-03],
       [-2.10184306e-01,  1.36045277e-01, -8.87301117e-02,
        -1.74187630e-01,  4.86502573e-02,  2.76032418e-01,
         6.05597720e-02, -7.86263272e-02,  1.03168473e-01,
        -1.37950957e-01],
       [-1.14111096e-01,  2.64648199e-02,  3.91238183e-02,
         1.27296187e-02,  1.35679528e-01,  1.77840948e-01,
        -1.35071129e-02,  2.69818418e-02,  3.42153087e-02,
         8.82407650e-02],
       [-3.01946551e-02, -2.40282975e-02, -6.38710558e-02,
        -3.98451164e-02, -2.92418208e-02,  1.03236027e-01,
        -6.42294213e-02, -6.60998672e-02,  8.01952481e-02,
         4.51336950e-02],
       [-7.84322768e-02,  3.11397053e-02, -5.89536875e-02,
         5.54301403e-03,  4.19510715e-02,  1.42857507e-01,
         3.19332965e-02,  8.02233815e-03,  4.79588211e-02,
         3.57780866e-02],
       [ 4.59235013e-02, -5.12123033e-02,  7.95385391e-02,
        -2.55374219e-02,  4.68068533e-02,  4.43704054e-02,
         7.18893707e-02,  1.64316930e-02,  1.46261796e-01,
         1.57745421e-01],
       [-7.41346031e-02, -2.22236216e-02,  5.15953489e-02,
         3.52352336e-02,  3.60111147e-02,  1.15427203e-01,
         2.57897153e-02, -1.20676845e-01,  5.01920506e-02,
         5.08756042e-02],
       [-8.63073617e-02,  3.17134708e-03,  7.40556046e-03,
        -4.21377048e-02, -1.40033364e-02,  1.12993240e-01,
         3.54410820e-02, -6.92673773e-03, -1.53519716e-02,
         1.26333967e-01],
       [-1.94949329e-01,  3.92537005e-02, -3.54437791e-02,
         2.31000036e-03,  3.79432067e-02,  1.76356524e-01,
        -2.87505463e-02, -1.43955618e-01,  5.22137731e-02,
         1.74010634e-01],
       [-7.06951618e-02,  4.39597517e-02, -6.54652417e-02,
         8.82245004e-02,  2.06859671e-02,  2.03700036e-01,
         4.97581288e-02, -9.99335721e-02,  1.11235633e-01,
         1.77167300e-02],
       [-3.53763551e-02,  1.12260692e-02, -1.11665837e-02,
         2.38924790e-02, -2.76037231e-02,  5.51242232e-02,
         1.64862610e-02, -7.48966262e-02,  8.12724680e-02,
         1.43957604e-03],
       [-6.88709617e-02,  8.31975043e-02, -1.29381031e-01,
         3.24329957e-02,  7.93049186e-02,  1.07140720e-01,
        -1.81627274e-03, -6.15758151e-02, -4.92713787e-03,
         1.36734053e-01],
       [-1.40681088e-01, -3.15807723e-02,  2.60454044e-02,
        -4.27230299e-02,  9.93238837e-02,  8.53562057e-02,
        -5.72501905e-02, -4.54951562e-02,  5.40010631e-02,
         1.60534859e-01],
       [-9.79049355e-02,  6.72584400e-03,  8.45115632e-03,
         6.96118921e-04,  6.83328062e-02,  7.55404532e-02,
         4.08902764e-02, -4.68667373e-02, -1.57676004e-02,
         1.63363129e-01],
       [ 8.52064192e-02, -4.55360785e-02, -1.45583451e-02,
         3.04546617e-02,  8.79652798e-04,  4.33428399e-02,
        -1.49330217e-02, -8.25205147e-02,  6.79182187e-02,
         1.38463050e-01],
       [-6.81671500e-03,  2.08939444e-02, -4.36149202e-02,
         4.71097752e-02, -4.12514098e-02,  7.67979622e-02,
        -3.79576795e-02, -5.68123832e-02,  8.52056891e-02,
         8.15685652e-03],
       [-3.65775973e-02, -3.06402445e-02,  6.65950254e-02,
         4.47403751e-02,  1.30895078e-01,  1.50094643e-01,
        -7.96291754e-02, -2.53046192e-02,  1.07343026e-01,
         1.23134412e-01],
       [-7.62542933e-02,  8.05391073e-02, -4.92790192e-02,
         9.18199718e-02, -2.56493650e-02,  1.11088946e-01,
         1.23991318e-01, -7.41875842e-02,  1.47743657e-01,
         1.12662911e-01],
       [-1.30924150e-01, -4.17556763e-02,  9.19794068e-02,
         1.42487824e-01,  5.94666004e-02,  8.11570883e-02,
         8.41576308e-02, -9.05135348e-02,  6.55059814e-02,
         1.40432447e-01],
       [-1.17066503e-02,  2.23872401e-02,  8.04524869e-02,
         4.16423976e-02,  2.44700648e-02,  1.15135796e-01,
        -4.52195890e-02, -3.85435522e-02,  5.72163910e-02,
        -1.70241687e-02],
       [-1.59098431e-01, -2.48743258e-02,  1.15704350e-03,
        -1.94135085e-02,  7.07440227e-02,  1.02726325e-01,
         7.48141706e-02, -4.92161885e-02, -4.82953712e-03,
         7.76504800e-02],
       [-5.23527563e-02, -6.95675313e-02,  6.23273998e-02,
         2.89142895e-02,  5.82050942e-02,  3.41961756e-02,
         5.93537465e-02, -4.88524139e-03,  6.03169389e-02,
         1.83362126e-01],
       [-1.74306586e-01,  8.50927830e-03,  5.59768602e-02,
         6.93600252e-03,  1.01455852e-01,  1.91212326e-01,
         1.01702012e-01, -4.06461619e-02,  7.65661225e-02,
         5.99906780e-02],
       [-1.77252740e-01, -8.87014568e-02,  2.68679634e-02,
        -9.68291797e-03,  3.35638374e-02,  5.05711734e-02,
        -8.12163353e-02, -1.41850561e-01,  1.26373529e-01,
         1.44208223e-01],
       [ 8.87322426e-03,  4.11978438e-02, -1.41141713e-02,
         4.23451513e-02,  1.00925714e-01,  1.82571277e-01,
         1.03720605e-01,  7.12942332e-03,  8.42842646e-03,
         1.58599243e-01],
       [-2.13406682e-02, -1.75910015e-02, -5.12900352e-02,
         2.48414166e-02,  1.27789266e-02,  1.57729238e-01,
        -9.69544426e-03, -8.99763852e-02,  1.03400812e-01,
         1.43311873e-01],
       [-8.07911009e-02, -7.29364436e-03,  8.38935375e-03,
        -2.07424574e-02,  5.76308332e-02,  1.60505801e-01,
         1.78787038e-02, -3.16766389e-02,  9.31937695e-02,
        -1.85308605e-02],
       [-9.09739435e-02,  9.12447274e-02, -2.94412468e-02,
        -4.54112887e-05,  5.61082549e-02,  1.87468201e-01,
         1.19883358e-01, -1.82710923e-02,  6.45193905e-02,
         2.49637775e-02],
       [-1.29598469e-01, -3.25308368e-03, -4.12912332e-02,
        -5.08361273e-02, -1.04823690e-02,  1.82079896e-02,
        -6.16563670e-02, -7.23700672e-02,  1.43699115e-02,
         1.14841349e-01],
       [-6.54844195e-02,  1.32181123e-02, -1.93353333e-02,
         3.21756229e-02,  1.93038825e-02,  6.58847988e-02,
        -8.00077915e-02, -8.77422094e-02,  6.08410016e-02,
         7.79580325e-02],
       [-1.41706705e-01, -2.03647800e-02, -7.27154315e-03,
         6.24460652e-02,  4.42662910e-02,  3.27514187e-02,
        -1.90984663e-02, -2.80730426e-02,  6.94331005e-02,
         1.08589470e-01],
       [-1.47680700e-01,  1.48287900e-02,  4.51873392e-02,
         9.84134525e-03, -3.12292501e-02,  7.53606334e-02,
        -3.28539237e-02, -8.99436623e-02,  9.34160203e-02,
         1.19600557e-01],
       [-1.06706396e-01, -7.91918635e-02, -3.81190814e-02,
         7.73421675e-03,  5.89197949e-02,  1.48285478e-01,
         7.33993948e-04, -7.37952963e-02,  7.03667626e-02,
         1.29369870e-02],
       [-1.18764386e-01,  6.67049065e-02, -1.84375793e-03,
         9.65175256e-02,  2.87358947e-02,  1.37036711e-01,
        -1.75398681e-02, -3.47253568e-02, -1.17472522e-02,
         1.64805949e-01],
       [-2.96032578e-02,  7.04797730e-02, -5.30908853e-02,
         3.15117575e-02,  1.10758804e-02,  1.52998209e-01,
        -4.00629006e-02, -8.53683874e-02,  8.19639489e-02,
         8.44380781e-02],
       [ 5.15967757e-02, -7.42607191e-03, -5.11791557e-03,
         6.67313561e-02,  5.96174449e-02, -7.76399858e-03,
         1.19835325e-01,  2.59960741e-02,  3.64723615e-02,
         2.43333012e-01],
       [-1.81520402e-01, -2.21860819e-02, -9.37249511e-04,
        -4.36494425e-02,  1.48944080e-01,  1.20623425e-01,
         9.27821100e-02,  8.10635090e-03,  1.05029367e-01,
         5.49212396e-02],
       [-1.62024647e-01, -7.55626708e-03,  8.90679806e-02,
         1.09557875e-01,  4.00451683e-02,  5.88794537e-02,
        -5.02957255e-02, -9.45830047e-02,  6.43425062e-02,
         6.01378679e-02],
       [ 2.44281888e-02, -2.44084001e-03,  7.41914511e-02,
         1.29617304e-01,  7.07155839e-03,  6.66829422e-02,
         2.52056420e-02, -9.07487422e-02,  5.63489906e-02,
         1.43779904e-01],
       [ 4.61201221e-02, -6.86585605e-02, -7.71630462e-03,
        -2.53528468e-02,  1.99609250e-02,  9.59918946e-02,
         1.04020424e-02, -5.57698309e-02,  2.24557221e-02,
         1.73324734e-01],
       [-8.96764472e-02, -2.28214562e-02,  1.30368788e-02,
         6.01692796e-02,  1.63236298e-02,  9.95141268e-02,
        -1.29274517e-01, -2.46923715e-02,  3.26206349e-02,
         8.57910439e-02]], dtype=float32)>}
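
To see how keys map to functions on a loaded object, it can help to look at a much smaller SavedModel. The sketch below is an illustration under assumptions, not the tutorial's model: the Scaler module, its scale method, and the "scaled" output name are hypothetical. It saves a tf.Module with one explicit signature, lists the available keys, and calls the signature with a keyword argument named after the Python parameter.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for a model: doubles its input."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def scale(self, x):
        # Signature outputs are returned as a dict keyed by output name.
        return {"scaled": x * 2.0}


export_dir = tempfile.mkdtemp()
module = Scaler()
# Register the function explicitly under the "serving_default" key.
tf.saved_model.save(
    module, export_dir, signatures={"serving_default": module.scale})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))

inference_func = loaded.signatures["serving_default"]
print(inference_func.structured_outputs)

# Signature functions accept their inputs as keyword arguments.
result = inference_func(x=tf.constant([1.0, 2.0]))
print(result["scaled"].numpy())
```

The returned value is a dictionary, which is also why the Keras example above prints a dictionary keyed by the name of the model's last layer.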

You can also load the model and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
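
The warning above suggests wrapping run inside a tf.function. The following is a minimal sketch of that pattern, under stated assumptions: it pins MirroredStrategy to a single CPU device so it can run anywhere, and it uses a hypothetical inference_fn (a row sum over synthetic data) in place of the loaded signature function.

```python
import tensorflow as tf

# Single-CPU strategy so the sketch does not depend on available GPUs.
strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])


@tf.function
def inference_fn(batch):
    # Hypothetical stand-in for a loaded inference function.
    return tf.reduce_sum(batch, axis=1)


@tf.function
def predict_step(batch):
    # Running strategy.run inside a tf.function avoids the eager overhead.
    return strategy.run(inference_fn, args=(batch,))


# Synthetic data: 8 rows of ones, batched into 2 global batches of 4.
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

results = []
for batch in dist_dataset:
    per_replica_output = predict_step(batch)
    # Unpack the per-replica values into a tuple of local tensors.
    for output in strategy.experimental_local_results(per_replica_output):
        results.append(output.numpy().tolist())

print(results)
```

Collecting the outputs with experimental_local_results keeps the loop independent of how many replicas the strategy actually uses.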

Calling the restored function is just a forward pass on the saved model (predict). What if you want to continue training the loaded function? Or embed the loaded function into a bigger model? A common practice is to wrap the loaded object in a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.1955 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9815

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
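
For comparison, a loaded object can also be fine-tuned directly with tf.GradientTape, without a Keras wrapper. This is a lower-level alternative and not the approach this tutorial uses. The sketch assumes a hypothetical one-variable Linear module and a plain gradient-descent update; the learning rate and target value are arbitrary.

```python
import tempfile

import tensorflow as tf


class Linear(tf.Module):
    """Hypothetical one-parameter model: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Linear(), export_dir)
loaded = tf.saved_model.load(export_dir)

x = tf.constant(2.0)
target = tf.constant(10.0)
initial_loss = float((target - loaded(x)) ** 2)

# The restored function is differentiable with respect to restored variables.
for _ in range(20):
    with tf.GradientTape() as tape:
        loss = (target - loaded(x)) ** 2
    grad = tape.gradient(loss, loaded.w)
    loaded.w.assign_sub(0.05 * grad)  # plain gradient-descent step

final_loss = float((target - loaded(x)) ** 2)
print(initial_loss, final_loss, float(loaded.w.numpy()))
```

The Keras wrapper remains the more convenient route when you want compile() and fit(), or when the loaded function is one piece of a larger model.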

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/loading from a local device

When saving and loading from a local I/O device while running remotely, for example when using a Cloud TPU, the option experimental_io_device must be used to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
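
The same option objects can also be passed to the lower-level tf.saved_model.save() and tf.saved_model.load() functions through their options argument. The sketch below assumes a hypothetical Counter module with a single variable and a temporary export directory; run on a local machine, it only shows where the options go.

```python
import tempfile

import tensorflow as tf


class Counter(tf.Module):
    """Hypothetical module holding one variable to round-trip."""

    def __init__(self):
        super().__init__()
        self.count = tf.Variable(7, dtype=tf.int64)


export_dir = tempfile.mkdtemp()

# Save through the I/O device on localhost.
save_options = tf.saved_model.SaveOptions(
    experimental_io_device="/job:localhost")
tf.saved_model.save(Counter(), export_dir, options=save_options)

# Load through the same local I/O device.
load_options = tf.saved_model.LoadOptions(
    experimental_io_device="/job:localhost")
restored = tf.saved_model.load(export_dir, options=load_options)

print(int(restored.count.numpy()))
```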

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading, otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all set. Input shapes are automatically defined when the model is used in .fit, .evaluate, or .predict, or when calling the model (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f07bc363410>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f076c5f2750>, because it is not built.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets