Save and load a model using a distribution strategy

Overview

It's common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
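get_data() multiplies BATCH_SIZE_PER_REPLICA by num_replicas_in_sync because a distributed input pipeline splits each global batch across the replicas. The sketch below makes that split visible on a toy tf.data.Dataset.range(32) pipeline. The batch size of 4 and the explicit experimental_distribute_dataset call are illustrative choices (in the tutorial, model.fit distributes the dataset for you). On a machine without GPUs, MirroredStrategy uses a single CPU replica, so the global and per-replica batch sizes coincide.

```python
import tensorflow as tf

# On a machine without GPUs this strategy has a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 4  # illustrative value, smaller than the tutorial's 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# A toy dataset of 32 integer "examples", batched with the *global* batch size.
dataset = tf.data.Dataset.range(32).batch(GLOBAL_BATCH_SIZE)

# Distributing the dataset splits every global batch across the replicas.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

per_replica_sizes = []
for dist_batch in dist_dataset:
    # One tensor per local replica (a plain tensor becomes a 1-tuple).
    local_batches = strategy.experimental_local_results(dist_batch)
    per_replica_sizes.append([int(tf.shape(b)[0]) for b in local_batches])

print("global batch:", GLOBAL_BATCH_SIZE, "first split:", per_replica_sizes[0])
```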

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 4s 4ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9386
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9787

<tensorflow.python.keras.callbacks.History at 0x7f7470042b38>

Save and load the model

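Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save and tf.keras.models.load_model) and the low-level tf.saved_model APIs (tf.saved_model.save and tf.saved_model.load).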

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9838
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9884

<tensorflow.python.keras.callbacks.History at 0x7f74d333f780>

After restoring the model, you can continue training on it, even without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the SavedModel format.
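To see what the SavedModel format looks like on disk, here is a minimal sketch that saves a hypothetical one-variable tf.Module (used instead of a Keras model to keep the example small and independent of Keras-specific saving behaviour) and lists the export directory. A SavedModel directory contains a saved_model.pb file describing the serialized functions and a variables/ subdirectory holding a checkpoint of the variable values; depending on the TensorFlow version, other entries such as assets/ may also appear.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical minimal module: multiplies its input by a trainable scalar."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(Scaler(), export_dir)

# The serialized program lives in saved_model.pb; variable values in variables/.
contents = sorted(os.listdir(export_dir))
variable_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
print(contents)
print(variable_files)
```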

Now, to load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0530 - sparse_categorical_accuracy: 0.9844
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9882

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
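The same independence from the saving strategy holds for the low-level API. In this sketch, a hypothetical Scaler module stands in for the tutorial's model: its variable is created under MirroredStrategy, saved, and then restored under OneDeviceStrategy("/cpu:0"), and the restored function still produces the expected values.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for the tutorial's model: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")

# Create the variable under one strategy...
save_strategy = tf.distribute.MirroredStrategy()
with save_strategy.scope():
    module = Scaler()
tf.saved_model.save(module, export_dir)

# ...and restore it under a different strategy pinned to the CPU.
load_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with load_strategy.scope():
    restored = tf.saved_model.load(export_dir)

result = restored(tf.constant([1.0, 2.0]))
print(result.numpy())  # [3. 6.]
```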

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
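A minimal sketch of the point above, using a hypothetical Doubler tf.Module rather than a Keras model: the object returned by tf.saved_model.load is not a tf.keras.Model, but the restored tf.function can be called directly, and exported signatures are reachable through the signatures mapping. Passing signatures explicitly is a choice made for this sketch; in the tutorial's setup, the saved Keras model registers "serving_default" itself.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module that multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.factor * x


module = Doubler()
export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.__call__})

loaded = tf.saved_model.load(export_dir)
print(type(loaded))  # a generic restored object, not a Keras model

# The restored tf.function can be called like the original method...
direct = loaded(tf.constant([1.0, 2.0]))
# ...or through the signature, which takes keyword arguments and returns a dict.
via_signature = loaded.signatures["serving_default"](x=tf.constant([1.0, 2.0]))
print(direct.numpy(), via_signature)
```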

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do an inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.46878400e-01, -2.84028575e-02,  4.34195548e-02,
         8.65758881e-02, -5.50181568e-02, -2.26117969e-02,
        -8.18806365e-02,  1.60868585e-01,  7.05277026e-02,
        -2.11526364e-01],
       [-2.04405725e-01, -2.38965377e-02,  1.06097549e-01,
         1.15776211e-02, -5.68305999e-02,  7.61558264e-02,
        -2.36685127e-02,  6.12710230e-02,  6.85455352e-02,
        -2.04084530e-01],
       [-1.70060426e-01,  6.82905912e-02, -2.54967008e-02,
         1.27377272e-01, -4.24135383e-03, -1.15118716e-02,
         1.65115029e-01,  1.64797649e-01,  8.41001868e-02,
        -2.60865986e-01],
       [-1.24608956e-01,  7.05861971e-02,  4.76837084e-02,
         9.51382518e-02, -1.36017501e-02,  9.53883678e-02,
        -2.60323286e-04,  1.26946449e-01, -9.98851806e-02,
         6.01550192e-02],
       [-8.42214674e-02, -4.93131615e-02, -5.85474074e-04,
        -3.79234888e-02, -6.78482801e-02,  9.56373289e-02,
         4.69041206e-02,  8.55031833e-02,  9.31831449e-02,
        -1.40825540e-01],
       [-1.46941900e-01,  1.22972876e-02,  5.79140112e-02,
        -7.50405565e-02,  6.13511279e-02,  1.14746153e-01,
         3.54535617e-02,  2.55915433e-01,  7.26796240e-02,
        -1.99857190e-01],
       [-2.07879156e-01,  1.83034241e-02,  1.57775074e-01,
         6.06807172e-02, -1.75382420e-02,  1.33817732e-01,
         1.36331618e-01,  2.02472329e-01,  3.72610986e-02,
        -1.31865010e-01],
       [-9.93705392e-02,  6.03869818e-02, -4.28698361e-02,
         6.31842762e-04,  8.84034038e-02,  6.72685653e-02,
        -2.09506359e-02,  1.97081745e-01,  7.39021823e-02,
        -1.64300233e-01],
       [-9.71228778e-02,  5.48233166e-02,  1.38393641e-02,
        -7.14895800e-02, -3.87909710e-02,  8.45830888e-04,
        -3.62640694e-02,  1.64835989e-01,  5.04231751e-02,
        -2.07461655e-01],
       [-2.92240772e-02,  1.45425312e-02,  5.74428178e-02,
        -1.34241190e-02, -1.80013701e-02,  7.78546855e-02,
        -8.48746449e-02,  9.98296142e-02,  6.38790280e-02,
        -5.32845445e-02],
       [-1.76605240e-01, -1.42511949e-01,  1.39559209e-01,
        -2.00123414e-02, -6.44349307e-02, -4.56911251e-02,
         2.01093405e-03,  1.59898788e-01,  1.95391588e-02,
        -1.61375850e-01],
       [-1.58091724e-01,  6.25609234e-03,  2.12391287e-01,
        -1.39106885e-01, -4.78955358e-02,  7.36434534e-02,
         7.29984716e-02,  2.28351891e-01,  1.23042218e-01,
        -2.22285807e-01],
       [-6.63312748e-02, -5.25613949e-02,  3.88407931e-02,
         4.74876724e-02, -3.56937200e-02,  1.11578718e-01,
        -8.47167745e-02,  1.54049486e-01,  8.42248723e-02,
        -9.11155120e-02],
       [-1.49975002e-01, -1.69416200e-02,  2.03275681e-03,
         3.08024809e-02, -1.28081590e-02,  1.18468963e-01,
        -7.31947795e-02,  2.10938901e-01,  5.79604283e-02,
        -1.06384277e-01],
       [-2.44300172e-01,  6.77020177e-02,  1.61827058e-02,
         9.77846682e-02, -2.14450657e-02,  8.76296014e-02,
         1.55660659e-02,  2.56645411e-01, -6.94077387e-02,
         1.82542913e-02],
       [-3.24441910e-01,  2.83106230e-02,  1.15296148e-01,
        -6.49778843e-02, -3.93164232e-02,  2.09751099e-01,
         1.58456087e-01,  2.03075439e-01,  1.45919517e-01,
        -8.07187557e-02],
       [-1.77742794e-01, -3.47406045e-02,  6.37909994e-02,
         5.72632812e-02, -1.67798519e-01, -9.77907851e-02,
        -6.33480251e-02,  5.98776974e-02, -1.48319647e-01,
        -3.26665044e-02],
       [-1.92516297e-02, -4.32192907e-02,  9.45950896e-02,
        -1.24730960e-01,  3.15439701e-03,  7.49434829e-02,
         1.42610222e-01,  1.64739519e-01,  1.35794416e-01,
        -2.33872890e-01],
       [-9.74408463e-02, -4.51198146e-02, -7.16688111e-02,
         1.52820855e-01,  3.08901221e-02, -8.07915181e-02,
        -8.59454572e-02,  1.73750147e-01, -4.14928459e-02,
        -1.02175683e-01],
       [-1.79451153e-01,  7.97335058e-02,  6.08496368e-02,
        -8.74251127e-05,  1.40254274e-02,  7.78948367e-02,
         1.22523680e-02,  1.38402849e-01, -2.44962424e-03,
        -8.56248587e-02],
       [-7.16196820e-02, -3.66464853e-02, -1.97902359e-02,
        -3.42466384e-02,  1.01994909e-02,  8.11903924e-02,
         1.02423221e-01,  8.15625191e-02,  9.28392410e-02,
        -1.61639646e-01],
       [-1.29672050e-01, -9.39578265e-02, -3.77402268e-02,
        -5.66408038e-03,  2.01772340e-02, -5.53961843e-04,
         1.12603299e-01,  1.18293904e-01,  7.59286210e-02,
        -1.05032220e-01],
       [ 3.13648432e-02,  2.04140544e-02,  8.68844241e-02,
         8.54840502e-03, -3.24598253e-02,  7.13473856e-02,
         1.01958007e-01,  1.58244759e-01,  4.33884151e-02,
        -1.56489074e-01],
       [-5.69176152e-02, -8.68148059e-02,  5.83150014e-02,
        -6.94776773e-02, -1.14257783e-01,  9.14709717e-02,
        -6.18093796e-02,  4.60445434e-02,  6.21100292e-02,
        -2.56335258e-01],
       [-1.00941956e-03, -9.87592638e-02,  1.59144640e-01,
         2.46649459e-02, -1.47723123e-01,  3.34706903e-03,
        -1.25270292e-01,  7.13937655e-02, -3.65925357e-02,
        -2.86379248e-01],
       [-2.52649784e-01, -1.80219673e-02,  1.53900415e-01,
        -7.60671049e-02, -4.30139415e-02,  6.14799336e-02,
         5.27559966e-02,  3.91793013e-01,  1.10363506e-01,
        -2.21582249e-01],
       [-1.04441456e-02, -5.70102595e-02, -5.45391962e-02,
        -6.66194037e-02,  3.30452994e-02,  4.31669690e-03,
        -1.39387622e-02,  1.50821537e-01,  7.82721266e-02,
        -1.13290384e-01],
       [-1.50469467e-01, -1.50829509e-01,  1.37116134e-01,
        -7.71817416e-02, -1.22132301e-01,  8.29393342e-02,
         7.44771212e-03,  1.10161960e-01,  5.23409843e-02,
        -1.67824954e-01],
       [-1.67705536e-01, -1.61053427e-02,  3.56741399e-02,
        -8.12948644e-02, -2.15860698e-02,  7.68682212e-02,
         3.90296578e-02,  8.14016312e-02,  1.20665669e-01,
        -5.40915243e-02],
       [-1.74987361e-01,  5.39990142e-03,  7.59589747e-02,
         1.13510445e-01, -3.19063663e-02, -5.98092973e-02,
        -4.05801088e-02,  2.37588376e-01, -6.73733801e-02,
        -1.72320567e-02],
       [-1.80301860e-01,  2.00746767e-02, -7.40496814e-03,
         8.36828053e-02,  9.17709470e-02,  1.46025598e-01,
        -2.91051138e-02,  2.14360297e-01, -3.91696244e-02,
        -1.15331344e-01],
       [-7.45102018e-02,  3.96583155e-02,  8.10021013e-02,
         1.56707764e-02, -2.35380158e-02,  1.56681970e-01,
        -1.12800300e-02,  3.64681214e-01,  1.12793013e-01,
        -9.20613408e-02],
       [-1.10700965e-01, -3.84411961e-03,  7.15886354e-02,
        -5.16710430e-03, -2.68637538e-02, -4.64520939e-02,
        -1.02423206e-01,  1.41418934e-01,  1.36580504e-02,
        -2.16841191e-01],
       [-1.03602912e-02, -1.36248600e-02, -8.44807327e-02,
        -3.93018406e-03,  6.54329583e-02, -1.54229663e-02,
        -9.10714716e-02,  1.13576502e-02,  6.24551401e-02,
        -1.10215969e-01],
       [-1.64637700e-01, -4.25843447e-02, -6.63272589e-02,
         1.01544857e-02,  9.00160298e-02,  1.41169682e-01,
         9.43019092e-02,  1.50300652e-01,  1.17022656e-01,
        -2.61101604e-01],
       [-2.96755701e-01,  1.48339659e-01,  5.29592186e-02,
         4.51779664e-02, -6.84008598e-02,  1.29287004e-01,
         1.34066977e-02,  1.68794006e-01, -1.53631158e-02,
        -1.40826374e-01],
       [-2.27824658e-01, -3.58637236e-02,  7.98013210e-02,
        -2.93148141e-02, -1.29889801e-01,  1.07304119e-02,
         6.16377033e-02,  2.38016129e-01,  1.68460131e-01,
        -2.78131723e-01],
       [-1.97686747e-01, -1.20533034e-01,  1.91476271e-02,
        -2.50333622e-02, -1.20231688e-01, -1.43363982e-01,
        -5.45644462e-02,  1.13663480e-01, -9.71207619e-02,
        -7.38224685e-02],
       [-1.21181801e-01, -9.18156952e-02,  1.72619522e-02,
         7.20846877e-02, -5.00237271e-02, -7.88232982e-02,
        -2.75398232e-02,  9.42765027e-02, -8.18064660e-02,
        -4.43772227e-02],
       [-2.12152809e-01, -1.05831539e-02,  1.12541884e-01,
         3.79703306e-02, -4.97136004e-02, -8.26531351e-02,
         4.28089425e-02,  2.72401571e-01, -9.41082910e-02,
        -8.25358368e-03],
       [-2.12490350e-01,  5.10787666e-02, -4.91231680e-03,
         1.71558380e-01,  8.33496898e-02,  8.03120583e-02,
         5.97136915e-02,  2.78716445e-01, -5.66011816e-02,
        -7.99765587e-02],
       [-2.45497763e-01, -5.21367639e-02,  1.77163050e-01,
         8.67958441e-02, -1.33168459e-01,  9.83412005e-03,
        -1.34591311e-01,  1.48744047e-01, -6.65533617e-02,
        -1.07505932e-01],
       [-1.36525869e-01, -5.12802340e-02,  2.54329219e-02,
         8.01228657e-02, -3.24120894e-02, -6.36913255e-03,
        -7.75915161e-02,  1.81387305e-01,  6.72850609e-02,
        -1.06104709e-01],
       [-8.19087848e-02, -6.67821616e-02,  1.09396182e-01,
        -8.99944529e-02, -1.08385280e-01,  6.29347712e-02,
         7.26154894e-02,  1.68957621e-01,  1.90485001e-01,
        -2.60798335e-01],
       [-1.76897705e-01,  4.90825251e-02,  2.94402167e-02,
        -2.41212249e-02,  3.94896790e-02,  1.18754521e-01,
         1.69773921e-02,  1.10196158e-01,  7.08303824e-02,
        -6.86142594e-02],
       [-1.29656106e-01, -8.14089552e-02,  1.14682741e-01,
        -1.32834181e-01, -1.49253279e-01, -2.83164792e-02,
         3.45680863e-04,  2.52322882e-01,  2.89388448e-02,
        -2.79281288e-01],
       [-1.10502213e-01,  1.07094124e-01,  3.24486196e-02,
         7.70951509e-02, -6.27939776e-02,  1.68845624e-01,
        -1.44310594e-01,  1.45337492e-01,  2.03377791e-02,
        -5.04231378e-02],
       [-2.66523331e-01, -7.49082193e-02,  1.91363335e-01,
        -6.39847219e-02, -1.04055285e-01,  8.31385702e-02,
         8.82939398e-02,  1.99207246e-01,  5.35239354e-02,
        -2.60884434e-01],
       [-1.35722771e-01,  3.94147262e-02, -6.39424995e-02,
         1.39283150e-01,  5.37211001e-02, -6.34303223e-03,
        -1.70467123e-01,  2.55692095e-01, -7.66103566e-02,
        -6.90388680e-02],
       [-1.07885860e-01,  2.30858717e-02,  8.21547359e-02,
        -3.12240291e-02, -9.89983678e-02,  7.22398609e-02,
        -4.08478230e-02,  8.69123414e-02,  4.48577479e-02,
        -6.41947538e-02],
       [-2.28321850e-02, -3.88411283e-02,  1.47033811e-01,
        -2.35385150e-01, -9.87000838e-02,  6.44287840e-02,
        -1.87633559e-02,  1.17905587e-01,  9.70625877e-02,
        -2.46781930e-01],
       [-8.77917856e-02, -1.64044406e-02,  7.53755122e-02,
        -8.24043527e-04, -7.77238905e-02,  1.16269790e-01,
        -1.00877963e-01,  8.79124254e-02,  3.39440927e-02,
        -5.94997481e-02],
       [-1.41677827e-01, -1.40151009e-02,  8.84927809e-04,
         1.03166051e-01, -1.66242346e-02,  2.62837298e-02,
        -1.33589238e-01,  1.65735006e-01,  3.65820900e-02,
        -1.46895535e-02],
       [-1.61557034e-01,  5.66626638e-02, -1.61597617e-02,
         2.58595943e-02,  3.39905620e-02,  1.01104185e-01,
        -3.71510983e-02,  1.20341092e-01,  3.26242894e-02,
        -4.07250933e-02],
       [-2.17516154e-01,  7.85727724e-02,  9.79433060e-02,
         6.97179586e-02,  4.95264679e-02,  1.92503840e-01,
        -4.96265218e-02,  1.99431688e-01, -5.32730669e-03,
        -2.50038877e-02],
       [-1.35356426e-01, -6.96291253e-02,  3.92658785e-02,
        -9.86322537e-02, -4.20986377e-02,  9.87840891e-02,
         9.67663303e-02,  1.76262826e-01,  9.44406465e-02,
        -2.23472387e-01],
       [-1.25066608e-01,  7.71146417e-02,  4.02672291e-02,
        -2.05352344e-02,  3.11498251e-02,  9.64582711e-02,
        -5.39951548e-02,  2.29750067e-01,  1.61451437e-02,
        -5.41997403e-02],
       [-1.93750665e-01, -3.56721133e-03, -1.50568932e-02,
         1.78796798e-02,  8.33508372e-03, -1.18013099e-02,
        -5.35021350e-02,  2.02244624e-01,  3.02494057e-02,
        -1.20312274e-01],
       [-2.62067527e-01,  2.36408859e-02,  5.58489896e-02,
         1.75756812e-01, -2.75299139e-02,  3.48872915e-02,
         5.41301072e-03,  3.15880209e-01, -5.74782193e-02,
         7.00992346e-03],
       [-2.76674211e-01, -2.08131559e-02, -1.26259401e-02,
         7.77718723e-02, -1.54706314e-01,  1.31996438e-01,
         2.20355690e-02,  5.61908968e-02,  3.73308063e-02,
        -1.17717944e-01],
       [-1.59806639e-01,  1.20503023e-01, -4.36934829e-03,
         1.16428092e-01,  5.47975339e-02,  1.25162587e-01,
         4.78192419e-02,  1.28253624e-01,  7.34245628e-02,
        -1.80039048e-01],
       [-2.67963678e-01,  6.00077920e-02,  1.13472804e-01,
         7.52071738e-02, -6.40357211e-02,  1.03171021e-01,
         1.48901194e-01,  1.97019696e-01,  3.76104042e-02,
        -1.68720663e-01],
       [-2.01240778e-01,  2.47026011e-02,  3.10055390e-02,
        -8.58910009e-03, -8.49897265e-02, -7.54948407e-02,
        -9.39515531e-02,  1.34306327e-01, -1.71037674e-01,
        -5.76597378e-02],
       [-5.20152375e-02,  6.59879148e-02, -3.30656916e-02,
         9.97125208e-02,  3.56362388e-02,  1.26982957e-01,
        -2.69417539e-02,  1.59046397e-01,  1.10872082e-01,
        -1.84650719e-01]], dtype=float32)>}
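Signature functions return a dictionary of tensors. For the saved Keras model, the key appears to be the output layer's name ("dense_3" above is presumably the auto-generated name of the model's last Dense layer). The sketch below, using a hypothetical TwoFunctions module, exports two functions under different signature keys and shows how output keys are chosen: a function that returns a dict keeps its keys, while a function that returns a bare tensor gets the key "output_0".

```python
import os
import tempfile

import tensorflow as tf


class TwoFunctions(tf.Module):
    """Hypothetical module exporting two functions under different keys."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def scores(self, x):
        return {"logits": 10.0 * x}  # dict output: the key "logits" is kept

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def negate(self, x):
        return -x  # bare tensor output: exported under the key "output_0"


module = TwoFunctions()
export_dir = os.path.join(tempfile.mkdtemp(), "two_functions")
tf.saved_model.save(
    module,
    export_dir,
    signatures={"serving_default": module.scores, "negate": module.negate},
)

loaded = tf.saved_model.load(export_dir)
keys = sorted(loaded.signatures.keys())
default_out = loaded.signatures["serving_default"](x=tf.constant([1.0]))
negate_out = loaded.signatures["negate"](x=tf.constant([1.0]))
print(keys)                 # ['negate', 'serving_default']
print(sorted(default_out))  # ['logits']
print(sorted(negate_out))   # ['output_0']
```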

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

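The warning above recommends wrapping the distributed call in a tf.function. Here is a minimal sketch of that pattern, assuming a hypothetical Doubler module in place of the tutorial's model and a toy dataset of ten floats: the restored function runs on each replica through strategy.run inside a tf.function, and strategy.experimental_local_results unpacks the per-replica outputs so they can be concatenated.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical stand-in for the saved model: doubles its input."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.factor * x


export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
tf.saved_model.save(Doubler(), export_dir)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    loaded = tf.saved_model.load(export_dir)

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10, dtype=tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


def replica_fn(x):
    # Runs once per replica on that replica's slice of the batch.
    return loaded(x)


@tf.function
def distributed_predict(batch):
    # Wrapping strategy.run in tf.function avoids the eager overhead warned about above.
    return strategy.run(replica_fn, args=(batch,))


predictions = []
for batch in dist_dataset:
    per_replica = distributed_predict(batch)
    # Collect the outputs of every replica that lives on this worker.
    predictions.extend(strategy.experimental_local_results(per_replica))

all_predictions = tf.concat(predictions, axis=0)
print(all_predictions.numpy())
```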

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
938/938 [==============================] - 2s 3ms/step - loss: 0.1981 - sparse_categorical_accuracy: 0.9412
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0655 - sparse_categorical_accuracy: 0.9804

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
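If TF Hub is not available, the underlying idea can still be sketched with the low-level API alone: the variables of a restored object remain trainable, and gradients flow through restored functions. The example below, assuming a hypothetical one-parameter Scaler module and toy data, fine-tunes the restored variable with tf.GradientTape and a hand-written SGD update instead of hub.KerasLayer and model.fit.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical one-parameter model: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(Scaler(), export_dir)
restored = tf.saved_model.load(export_dir)

x = tf.constant([1.0, 2.0])
y_true = 3.0 * x  # toy target: the restored model should learn w = 3
learning_rate = 0.1

losses = []
for _ in range(20):
    with tf.GradientTape() as tape:
        # Gradients flow through the restored function into the restored variable.
        loss = tf.reduce_mean(tf.square(restored(x) - y_true))
    grad = tape.gradient(loss, restored.w)
    restored.w.assign_sub(learning_rate * grad)  # plain SGD step
    losses.append(float(loss.numpy()))

print(losses[0], losses[-1], float(restored.w.numpy()))
```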

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it back as a non-Keras object with the low-level API, tf.saved_model.load:

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving to and loading from a local I/O device while running remotely, for example when using a Cloud TPU, you must use the experimental_io_device option to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

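The same options also apply to the low-level API: tf.saved_model.save accepts a SaveOptions object and tf.saved_model.load accepts a LoadOptions object. Here is a minimal sketch with a hypothetical Scaler module. When run on a single local machine, '/job:localhost' is simply the local host, so this only demonstrates how the options are passed, not an actual remote setup.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical minimal module: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(5.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")

# Route the file I/O of saving through the local host.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Scaler(), export_dir, options=save_options)

# Route the file I/O of loading through the local host as well.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
    loaded = tf.saved_model.load(export_dir, options=load_options)

output = loaded(tf.constant([1.0]))
print(output.numpy())  # [5.]
```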

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when you call the model (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.

WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

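A roughly analogous distinction exists at the tf.function level (this is an analogy, not the Keras mechanism itself). A function with an input_signature has a well-defined input spec and can be exported without ever being called, while a function without one only acquires a traced spec once it is called, much like calling model(inputs) defines a Keras model's input shapes. The sketch below, with hypothetical DeclaredInputs and TracedOnCall modules, saves both and calls the restored functions.

```python
import os
import tempfile

import tensorflow as tf


class DeclaredInputs(tf.Module):
    """Input spec declared up front: exportable without being called first."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
    def __call__(self, x):
        return tf.reduce_sum(x, axis=1)


class TracedOnCall(tf.Module):
    """No input_signature: the input spec is only known after a call traces it."""

    @tf.function
    def __call__(self, x):
        return tf.reduce_sum(x, axis=1)


root = tempfile.mkdtemp()

declared_dir = os.path.join(root, "declared")
tf.saved_model.save(DeclaredInputs(), declared_dir)

traced = TracedOnCall()
traced(tf.zeros([2, 3]))  # calling once traces the function for shape [2, 3]
traced_dir = os.path.join(root, "traced")
tf.saved_model.save(traced, traced_dir)

restored_declared = tf.saved_model.load(declared_dir)
restored_traced = tf.saved_model.load(traced_dir)

declared_out = restored_declared(tf.ones([5, 3]))  # any batch size matches [None, 3]
traced_out = restored_traced(tf.ones([2, 3]))      # must match the traced shape
print(declared_out.numpy(), traced_out.numpy())
```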