Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial shows how to use the SavedModel APIs together with tf.distribute.Strategy. To learn about SavedModel and serialization in general, read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
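
The batch-size arithmetic in get_data() can be checked by hand. The sketch below is plain Python and not part of the original tutorial. It assumes the standard MNIST training split of 60,000 examples and reuses BATCH_SIZE_PER_REPLICA = 64; the helper name global_batch_and_steps is hypothetical. With one replica it gives 938 steps per epoch, which matches the 938/938 progress bars in this tutorial's training logs.

```python
import math

MNIST_TRAIN_EXAMPLES = 60_000  # assumed size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64    # same value as in get_data()


def global_batch_and_steps(num_replicas, num_examples=MNIST_TRAIN_EXAMPLES):
    """Return (global batch size, steps per epoch) for a replica count."""
    global_batch = BATCH_SIZE_PER_REPLICA * num_replicas
    # batch() keeps the final partial batch by default, so round up.
    steps = math.ceil(num_examples / global_batch)
    return global_batch, steps


for replicas in (1, 2, 4):
    batch, steps = global_batch_and_steps(replicas)
    print(f"{replicas} replica(s): global batch {batch}, {steps} steps per epoch")
```

Doubling the number of replicas doubles the global batch and roughly halves the number of steps per epoch, while each replica still processes 64 examples per step.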

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7f3b900396d0>

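Save and load the model

Now that you have a simple model to work with, take a look at the saving/loading APIs. Two sets of APIs are available: the high-level Keras API (model.save and tf.keras.models.load_model) and the lower-level tf.saved_model API (tf.saved_model.save and tf.saved_model.load).

The Keras API

Here is an example of saving and loading a model with the Keras API: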

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895
<keras.callbacks.History at 0x7f3b187b7150>

After restoring the model, you can continue training it without calling compile() again, because it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to the SavedModel format.
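
To see what the SavedModel format looks like on disk, a minimal sketch follows. It is an illustration under assumptions, not the tutorial's own code: it saves a small hypothetical tf.Module named Counter with the lower-level tf.saved_model.save rather than the Keras model above, and it writes to a temporary directory instead of /tmp/keras_save. Keras may add files of its own next to the ones checked here.

```python
import os
import tempfile

import tensorflow as tf


class Counter(tf.Module):
    """Hypothetical stand-in for a model: one variable, no Keras."""

    def __init__(self):
        super().__init__()
        self.count = tf.Variable(0, dtype=tf.int64)


export_dir = os.path.join(tempfile.mkdtemp(), "demo_saved_model")
tf.saved_model.save(Counter(), export_dir)

# The directory holds the serialized graph (saved_model.pb) and a
# checkpoint of the variables; the exact listing varies by version.
entries = sorted(os.listdir(export_dir))
print(entries)
print(tf.saved_model.contains_saved_model(export_dir))
```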

Now load the model and train it using tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860
Epoch 2/2
938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.

The tf.saved_model API

Now take a look at the lower-level API. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, because it is a lower-level API (and therefore covers a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions you can use to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.18789300e-01, -1.78404614e-01,  4.92432676e-02,
        -9.37875658e-02,  1.14302970e-01, -8.99422392e-02,
         9.47709680e-02, -7.75382966e-02,  4.04430032e-02,
         2.41404288e-02],
       [-2.35370561e-01, -3.39397341e-02,  2.73427293e-02,
        -1.08200148e-01,  5.10682352e-02,  1.36142194e-01,
         9.28785652e-02, -5.35808355e-02,  2.56292164e-01,
         1.05301209e-01],
       [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02,
        -1.99329913e-01, -7.45072216e-02,  2.42738128e-02,
         2.07733169e-01, -3.15396488e-03,  4.95976806e-02,
         2.14848563e-01],
       [-9.82482210e-02, -6.13910556e-02,  1.00815810e-01,
        -1.87558904e-01,  1.14685424e-01,  1.53835595e-01,
         1.85714245e-01, -8.74890238e-02,  1.07493028e-01,
         1.57510787e-02],
       [-8.56257528e-02,  3.23683321e-02, -3.66768315e-02,
        -1.47201523e-01, -5.31517603e-02,  1.52744055e-02,
         1.69184029e-01, -5.42814359e-02,  1.11524366e-01,
         5.65215349e-02],
       [-1.50604844e-01, -7.87255913e-03,  1.26651973e-01,
        -1.24476865e-01,  6.94983900e-02,  4.27672639e-03,
         1.86136231e-01, -4.54714149e-03,  9.12746191e-02,
         6.12779632e-02],
       [-2.79157639e-01, -4.61089313e-02,  2.51544192e-02,
        -1.79003477e-01,  3.83432880e-02,  2.05054253e-01,
        -8.25636461e-03, -8.25546682e-03,  2.41342247e-01,
         8.24805871e-02],
       [-1.42795354e-01,  6.54597580e-02,  2.05058958e-02,
        -1.28471941e-01,  1.10977650e-01,  4.51317504e-02,
         2.44124904e-01,  1.90523565e-02,  3.11958641e-02,
         6.49511665e-02],
       [-1.33037239e-01, -2.72594951e-02,  8.09026062e-02,
        -1.95883229e-01,  1.84634060e-01,  1.00822970e-01,
         4.40884084e-02, -6.43826872e-02,  1.47807434e-01,
        -1.92791894e-02],
       [-1.43770471e-01, -2.53150351e-02,  4.18904647e-02,
        -1.02573663e-01,  6.15917407e-02,  7.95702711e-02,
         9.27314460e-02, -4.31537181e-02,  4.59018350e-02,
         1.02965936e-01],
       [-1.90395206e-01,  2.93233991e-03,  1.48900077e-02,
        -1.15877971e-01,  1.06598288e-02,  1.40121073e-01,
         6.86443001e-02, -4.61921766e-02,  1.27470195e-01,
         6.73005953e-02],
       [-2.60747373e-01, -1.45188004e-01,  7.10044056e-04,
        -1.04602516e-01,  5.00324890e-02,  2.96664417e-01,
         8.57191086e-02,  6.65097907e-02,  1.31302923e-01,
        -1.84605196e-02],
       [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01,
        -1.34576231e-01, -8.19503814e-02,  1.30840242e-02,
         6.16783127e-02, -3.64837795e-02,  3.18005830e-02,
         1.98420882e-01],
       [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02,
        -6.30265176e-02,  8.36028308e-02,  2.96559408e-02,
         2.19864860e-01, -7.08417147e-02,  4.76131588e-02,
         1.15781695e-01],
       [-1.55139655e-01, -1.27863720e-01,  9.67459157e-02,
        -1.48635745e-01,  1.25129193e-01,  4.04443927e-02,
         2.94884086e-01, -7.66484886e-02,  1.18753463e-01,
         2.93397382e-02],
       [-1.59221828e-01, -9.30457860e-02,  9.18259323e-02,
        -1.72857821e-01,  8.09611157e-02,  1.11391053e-01,
         1.66679412e-01,  3.52456123e-02,  9.05358568e-02,
         9.89414975e-02],
       [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02,
        -9.73629057e-02,  1.36456266e-01,  1.30628154e-01,
         1.53577864e-01, -6.73157908e-03,  9.31103677e-02,
         1.50734074e-02],
       [-1.29348308e-01, -3.03804129e-03,  2.82487050e-02,
        -2.02886015e-01,  7.09105879e-02,  1.74542382e-01,
         2.57992335e-02, -1.63579211e-02,  2.30892301e-02,
         6.69767857e-02],
       [-1.56857669e-01,  5.46110943e-02, -5.93251809e-02,
        -1.04585059e-01,  2.61763521e-02,  1.43062070e-01,
         1.57771498e-01, -6.19823262e-02,  3.59585434e-02,
         6.62322640e-02],
       [-8.64257440e-02, -1.33483298e-03,  7.46414512e-02,
        -1.82848468e-01,  1.21074423e-01,  1.55276239e-01,
         1.46483868e-01, -6.22515939e-03,  1.91641584e-01,
        -9.95825827e-02],
       [-2.52117336e-01, -6.92471862e-02,  1.09911412e-01,
        -3.73112522e-02,  3.76211852e-03,  5.23591004e-02,
         9.16506499e-02,  6.80204183e-02, -4.27842364e-02,
         7.91264027e-02],
       [-2.11018056e-01,  5.97522780e-03,  8.47486481e-02,
        -7.27925971e-02,  9.36664082e-03,  1.62506998e-01,
         5.32426499e-02,  1.78599171e-02, -2.30420940e-02,
         4.07365486e-02],
       [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02,
        -1.64699793e-01,  8.35808069e-02,  7.68100768e-02,
        -7.14773983e-02, -3.43702435e-02,  9.47649628e-02,
         9.36352089e-02],
       [-1.20486066e-01,  3.77080180e-02,  1.14158325e-01,
        -6.50681928e-02,  1.03382617e-02,  1.17891498e-01,
         1.13154747e-01, -1.49052702e-02,  1.28893867e-01,
         1.12219512e-01],
       [-2.23867983e-01, -9.79400948e-02,  7.37103820e-02,
        -1.05197895e-02,  3.75595838e-02,  1.80490598e-01,
         6.83145374e-02, -3.09509300e-02,  1.42565176e-01,
         8.05927664e-02],
       [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02,
        -1.75458089e-01,  1.46448284e-01,  1.80426955e-01,
         1.52164772e-01, -2.57370695e-02,  1.26812875e-01,
         1.22049123e-01],
       [-9.45013613e-02,  5.85526973e-02,  1.47456676e-02,
        -4.40606587e-02,  4.86647561e-02,  6.28624633e-02,
         3.69989276e-02, -3.68277319e-02,  3.56127135e-02,
         3.10502797e-02],
       [-1.02712311e-01,  3.16979140e-02,  1.88253060e-01,
        -5.99608906e-02,  3.73450294e-02,  6.38176724e-02,
         1.12240583e-01,  2.42183693e-02,  1.45670772e-02,
        -9.52028483e-03],
       [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02,
        -1.01807326e-01, -7.93362781e-03, -7.22003728e-02,
         1.49934232e-01, -1.19943202e-01,  9.22369361e-02,
         1.46321565e-01],
       [-1.32534593e-01,  1.18380897e-02,  2.23980099e-03,
        -9.28303748e-02, -2.20538303e-02,  7.68908709e-02,
         5.29715866e-02, -3.43324393e-02, -1.27909705e-02,
        -7.04141408e-02],
       [-8.10261145e-02, -8.95578321e-03,  3.96864787e-02,
        -1.21861629e-01,  7.98310041e-02,  1.56087667e-01,
         9.11872089e-02, -2.29295418e-02,  5.64432219e-02,
        -3.55931222e-02],
       [-1.76416740e-01,  1.12043694e-02, -1.80068091e-02,
        -1.88012689e-01,  8.68914276e-02,  1.57958359e-01,
         5.77907935e-02, -2.12088451e-02,  5.33877537e-02,
         2.19271183e-02],
       [-2.70012528e-01, -1.26611829e-01,  3.10387388e-02,
        -7.24840909e-02,  1.03253610e-01,  8.91268626e-02,
         1.38662308e-01, -6.25240132e-02,  2.36210316e-01,
         1.40534222e-01],
       [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02,
        -2.01282576e-02,  5.43357767e-02,  7.14191943e-02,
         3.46604213e-02, -6.00920171e-02,  5.11362031e-02,
         3.58160883e-02],
       [-1.63262367e-01,  2.44849995e-02,  3.81964818e-02,
        -3.93010303e-02,  3.95263731e-03,  9.11088511e-02,
         3.88236046e-02,  1.33745335e-02,  1.00076631e-01,
         6.05135933e-02],
       [-3.01809371e-01, -1.58440098e-01,  4.65333983e-02,
        -1.63946241e-01, -6.42775744e-02,  3.93286347e-04,
         2.82839835e-01, -8.93663988e-02,  1.97781295e-01,
         2.87044942e-01],
       [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03,
        -1.01776704e-01,  1.43144801e-02,  1.82002857e-02,
         2.76539754e-02, -1.94141679e-02,  8.87098238e-02,
         6.60644472e-02],
       [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02,
        -4.82957587e-02,  1.28858402e-01,  1.30042464e-01,
         1.32807568e-01, -7.52742141e-02,  9.51702446e-02,
         3.10119465e-02],
       [-1.09407350e-01, -5.27948700e-03,  1.29588693e-03,
        -2.61662379e-02,  3.01920641e-02,  1.13487415e-01,
         8.23267922e-02,  1.92574020e-02,  2.31986474e-02,
         4.13139611e-02],
       [-2.12277412e-01, -1.35507256e-01,  4.22930568e-02,
        -1.34565741e-01,  1.17879853e-01,  1.30573064e-01,
         1.81054786e-01, -1.70722306e-01,  1.05854876e-01,
         7.36362934e-02],
       [-1.78249478e-01, -7.55607188e-02,  7.75147527e-02,
        -2.14659080e-01,  3.26948166e-02,  7.76198730e-02,
         1.08791113e-01, -2.38809325e-02,  1.79410487e-01,
         1.94452941e-01],
       [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02,
        -1.40473023e-02,  3.60646360e-02, -9.39090401e-02,
         1.83859855e-01, -1.09493822e-01, -3.09051797e-02,
         1.36017531e-01],
       [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02,
        -9.68495384e-02,  2.35293470e-02,  2.53665410e-02,
         1.79999322e-01, -7.10204691e-02, -7.29817525e-02,
         4.50368747e-02],
       [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03,
        -1.03088826e-01, -7.38603100e-02,  1.84892826e-02,
         9.76646394e-02, -3.29037756e-02, -1.77134499e-02,
         1.62288889e-01],
       [-6.78652674e-02, -1.08500615e-01,  5.66991530e-02,
        -9.52370912e-02,  5.28126955e-02,  1.05176866e-02,
         1.73085481e-01, -1.37753151e-02,  1.95556954e-02,
         1.38068855e-01],
       [-2.02808753e-01, -3.39423120e-02,  1.82233751e-03,
        -5.71424365e-02,  3.40205729e-02,  8.74454305e-02,
         8.47227685e-03, -2.52498202e-02,  4.66104299e-02,
         1.10718749e-01],
       [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01,
        -9.72513855e-02, -3.58061343e-02,  3.04423086e-02,
         5.70362583e-02, -4.03833576e-02, -4.28436548e-02,
         9.73245874e-02],
       [-2.06081957e-01, -1.71493232e-01,  2.52560824e-02,
        -1.55212343e-01, -4.33478206e-02,  2.34177694e-01,
         8.46128762e-02,  1.75322518e-02,  2.04347119e-01,
         1.54971585e-01],
       [-1.95310384e-01,  1.30968075e-02, -9.68117267e-03,
        -7.31432810e-02,  1.02618083e-01,  1.59629256e-01,
         1.66028887e-01, -7.12903216e-03,  1.78021699e-01,
        -2.17130631e-02],
       [-1.59163624e-01, -1.77137554e-05,  1.75410658e-02,
        -9.08103511e-02,  7.25786015e-02,  9.21041369e-02,
         1.24915361e-01, -6.55939505e-02, -1.13440230e-02,
         1.03661232e-01],
       [-1.93366870e-01, -4.36344892e-02,  1.37750164e-01,
        -1.91939399e-01, -1.50268525e-03,  8.03942382e-02,
         2.15812266e-01,  5.38492575e-02,  1.36685073e-01,
         2.22119391e-01],
       [-1.65946245e-01,  7.89588690e-03, -1.65037125e-01,
        -1.23690292e-01, -8.57629776e-02, -2.55736727e-02,
         1.67541012e-01, -6.63827211e-02,  2.98694819e-02,
         1.71927184e-01],
       [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02,
        -2.98387632e-02,  2.80477256e-02,  4.94132042e-02,
         4.89805043e-02,  1.96998678e-02, -4.14144360e-02,
        -5.05549274e-02],
       [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02,
        -3.78398523e-02,  7.60737807e-03, -2.70657167e-02,
         1.11277811e-01,  6.37479573e-02, -2.39458829e-02,
         1.22067556e-01],
       [-1.92323536e-01, -1.43002480e-01,  5.29062748e-03,
        -1.70663983e-01,  8.39572400e-03,  6.37906119e-02,
         1.24084033e-01,  6.02792688e-02,  7.18353763e-02,
         5.03963791e-03],
       [-1.70977920e-01,  1.04207098e-02,  1.18544906e-01,
        -4.29532528e-02, -3.53983864e-02,  1.80302024e-01,
         8.08775946e-02,  3.19045782e-02,  2.52931342e-02,
         1.29424319e-01],
       [-2.13301033e-01, -6.96119964e-02,  2.32847631e-02,
        -7.73920864e-02,  1.10387571e-01,  1.13307782e-01,
         1.41805351e-01, -5.19381016e-02,  1.15313083e-01,
         1.40049949e-01],
       [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03,
        -1.04376137e-01,  7.78115019e-02,  6.84583709e-02,
         2.51923770e-01, -1.05199262e-01,  1.64517179e-01,
         2.18875334e-01],
       [-2.60777414e-01, -8.93031508e-02,  1.27723843e-01,
        -1.97950065e-01,  1.19145498e-01,  7.30907321e-02,
         2.23771721e-01, -6.83849230e-02,  3.68930906e-01,
         1.86811388e-01],
       [-2.38028213e-01,  1.11199915e-03,  2.25015372e-01,
         8.22724327e-02, -1.14511400e-01,  1.57513067e-01,
         5.22858277e-02,  2.13724375e-03,  3.15639377e-02,
         2.08704025e-01],
       [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02,
        -1.44550815e-01,  2.09794566e-02,  1.47883072e-02,
         3.96856442e-02, -2.15019658e-03, -4.90810722e-02,
         1.34708211e-01],
       [-2.02591017e-01, -2.29728431e-01,  6.73423260e-02,
        -1.24901496e-01, -1.38434023e-02,  8.64367038e-02,
         1.22342721e-01,  1.67826824e-02,  1.65354639e-01,
         1.83434993e-01],
       [-2.25799978e-01, -1.02682747e-01,  9.48531851e-02,
        -9.38871950e-02,  1.03806734e-01,  2.04695478e-01,
         8.09893832e-02, -1.45416632e-02,  1.33486420e-01,
        -6.27665371e-02],
       [-1.19375348e-01,  2.23235339e-02,  1.04302749e-01,
        -1.11149743e-01,  6.12434298e-02,  6.89433664e-02,
         2.08741099e-01, -3.81497070e-02, -1.42122135e-02,
         7.65201449e-03]], dtype=float32)>}
2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
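
The signature mechanism is not specific to Keras. As a minimal sketch, assuming a hypothetical tf.Module named Doubler in place of the trained model, the block below saves one function as a signature, loads it back, and calls it. Passing a single function as signatures registers it under the "serving_default" key, and the call returns a dictionary of named output tensors, just as the Keras model above returns its logits under the key dense_3.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module standing in for a trained model."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return 2.0 * x


module = Doubler()
export_dir = tempfile.mkdtemp()
# A single function passed as `signatures` is stored as "serving_default".
tf.saved_model.save(
    module, export_dir, signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))

infer = loaded.signatures["serving_default"]
outputs = infer(tf.constant([1.0, 2.0, 3.0]))
# Signature functions return a dict that maps output names to tensors.
for name, tensor in outputs.items():
    print(name, tensor.numpy())
```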

You can also load the model and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
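
The warning above suggests wrapping run in a tf.function. A minimal sketch of that pattern follows. It is not the tutorial's code: toy_inference is a hypothetical stand-in for the loaded signature function, and the dataset is a small tensor of ones rather than MNIST, so the block runs on its own.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def toy_inference(batch):
    """Hypothetical stand-in for loaded.signatures[...]: sums each row."""
    return tf.reduce_sum(batch, axis=1)


@tf.function
def distributed_inference(batch):
    # Tracing strategy.run into a graph avoids the eager per-call overhead.
    return strategy.run(toy_inference, args=(batch,))


dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

total_rows = 0
for batch in dist_dataset:
    per_replica = distributed_inference(batch)
    # experimental_local_results yields one tensor per local replica.
    for result in strategy.experimental_local_results(per_replica):
        total_rows += int(result.shape[0])
print(total_rows)
```

Counting rows across all replicas gives the dataset size regardless of how many devices the strategy uses, which is a quick check that no example was dropped.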

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it in a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442
Epoch 2/2
938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, the lower-level API is your only choice.

For loading, the API to use depends on what you want to get back from it. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving to and loading from a local I/O device while running remotely, for example when using a Cloud TPU, the option experimental_io_device must be used to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
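
The same options can also be passed to the lower-level API. The sketch below assumes a hypothetical tf.Module named Weights rather than the Keras model above, and it writes to a temporary directory. It passes SaveOptions to tf.saved_model.save and LoadOptions to tf.saved_model.load, then checks that the variable survives the round trip.

```python
import tempfile

import tensorflow as tf


class Weights(tf.Module):
    """Hypothetical non-Keras object holding one variable."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([1.0, 2.0, 3.0])


export_dir = tempfile.mkdtemp()

# Route filesystem access through the local host when saving.
save_options = tf.saved_model.SaveOptions(
    experimental_io_device="/job:localhost")
tf.saved_model.save(Weights(), export_dir, options=save_options)

# Use the matching option when loading.
load_options = tf.saved_model.LoadOptions(
    experimental_io_device="/job:localhost")
restored = tf.saved_model.load(export_dir, options=load_options)
print(restored.w.numpy())
```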

Caveats

A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super().__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! The model has no well-defined inputs yet.
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets