
Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

 import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()
 

Prepare the data and model using tf.distribute.Strategy:

 mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

 model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
 
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 5s 5ms/step - loss: 0.2249 - accuracy: 0.9338
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0759 - accuracy: 0.9779

<tensorflow.python.keras.callbacks.History at 0x7f715007afd0>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs and the lower-level tf.saved_model APIs.

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

 keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

 restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
 
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0534 - accuracy: 0.9836
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - accuracy: 0.9890

<tensorflow.python.keras.callbacks.History at 0x7f710816ae10>

After restoring the model, you can continue training on it, even without needing to call compile() again, since it was already compiled before saving. The model is saved in the standard SavedModel proto format. For more information, refer to the guide to the SavedModel format.
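For instance, because the restored model is already compiled, you can evaluate it right away. A minimal sketch, reusing the eval_dataset returned by get_data() above:

eval_loss, eval_acc = restored_keras_model.evaluate(eval_dataset)
print("Restored model accuracy (no recompile needed):", eval_acc)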

Now load the model and train it using a tf.distribute.Strategy:

 another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
 
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0537 - accuracy: 0.9837
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0377 - accuracy: 0.9890

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras APIs:

 model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
 
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

 DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
 

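If you are not sure which keys are available, you can list them on the loaded object first. A quick sketch:

# loaded.signatures is a mapping from signature key to concrete function.
print(list(loaded.signatures.keys()))  # typically ['serving_default'] for a Keras SavedModel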
The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

 predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
 
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 1.76999211e-01, -2.32242137e-01,  1.23949878e-01,
        -5.40311933e-02, -8.57487693e-03, -3.96087095e-02,
        -6.73415065e-02,  3.72458547e-01,  3.14344093e-03,
         8.67353380e-03],
       [ 8.22340995e-02, -3.74374330e-01,  3.58670026e-01,
         1.97437629e-02,  9.88980681e-02, -1.26803070e-02,
        -1.29937828e-01,  1.92892700e-01, -1.39045879e-01,
         1.50402993e-01],
       [ 2.17920348e-01, -1.89570293e-01,  1.02963611e-01,
        -1.08369023e-01,  9.65830833e-02,  1.60962120e-02,
        -3.61310542e-02,  2.50176281e-01,  5.14535047e-03,
        -5.93278334e-02],
       [ 1.03556961e-01, -1.02985136e-01,  1.06906675e-01,
        -6.00997955e-02,  8.02036971e-02,  1.95559263e-01,
        -2.45742053e-02,  3.24857533e-01,  6.51413798e-02,
         2.52756067e-02],
       [ 2.19353378e-01, -8.12549293e-02,  4.98566926e-02,
        -5.51203117e-02,  5.38498983e-02, -4.69352081e-02,
        -1.54691160e-01,  1.75368428e-01,  6.56833798e-02,
         4.18215767e-02],
       [ 2.09355026e-01, -3.76172066e-01,  2.30161190e-01,
        -1.00428099e-02,  2.19550565e-01,  7.82390013e-02,
         7.25585222e-03,  3.36478919e-01,  4.12123390e-02,
        -9.27922055e-02],
       [ 1.64225578e-01, -2.48810440e-01,  1.83809400e-01,
         3.82241942e-02,  4.43053246e-02, -7.36974552e-02,
        -1.95371702e-01,  2.86158621e-01, -2.47341514e-01,
         2.24358916e-01],
       [ 1.52286947e-01, -2.09801272e-01,  1.34292826e-01,
         2.15000696e-02,  1.96628809e-01,  1.38852596e-01,
         2.55186632e-02,  2.75483429e-01, -1.25698261e-02,
        -3.66938114e-02],
       [ 1.22778162e-01, -2.86725342e-01,  3.14392090e-01,
         4.92419824e-02,  1.87016353e-01,  8.05974901e-02,
        -1.02200568e-01,  3.40026855e-01, -1.26031846e-01,
         4.29191440e-03],
       [ 1.03287153e-01, -1.44445539e-01,  5.00248559e-02,
        -9.31376815e-02,  1.03142031e-01,  4.27858196e-02,
        -3.59843895e-02,  1.14567459e-01,  3.35859917e-02,
        -4.33023348e-02],
       [ 1.69986919e-01, -1.11736760e-01,  2.45079488e-01,
        -8.76985490e-03,  1.51206836e-01,  1.91091523e-02,
         2.80964151e-02,  3.55648905e-01, -6.07046336e-02,
        -2.16798633e-02],
       [ 8.00362825e-02, -3.77350360e-01,  3.69998842e-01,
        -3.13875452e-02,  1.82057709e-01,  9.40219238e-02,
        -2.71553099e-02,  3.80196393e-01,  1.12584956e-01,
         1.36451364e-01],
       [ 2.87981033e-01, -1.23201102e-01,  1.08305112e-01,
         1.80038624e-02,  1.14839718e-01,  1.32204175e-01,
        -7.62101486e-02,  3.05651367e-01, -1.24888837e-01,
        -1.25689805e-02],
       [ 2.00610712e-01, -1.93711758e-01,  2.13161647e-01,
         2.25381907e-02,  1.98810831e-01,  1.69012442e-01,
        -5.44626266e-02,  3.06417197e-01,  2.34507285e-02,
        -2.22197846e-02],
       [ 4.75327447e-02, -1.49071455e-01,  2.33651549e-01,
        -3.81418206e-02,  1.17788404e-01,  2.41317973e-01,
        -2.84423828e-02,  2.77228653e-01,  1.38170183e-01,
         6.81239665e-02],
       [ 1.47295043e-01, -3.18719327e-01,  1.71222433e-01,
        -2.24875748e-01,  1.64474305e-02,  7.05442131e-02,
        -1.66249961e-01,  2.15151966e-01,  8.68368298e-02,
         5.87985590e-02],
       [ 1.65438741e-01, -1.82458863e-01,  9.23786461e-02,
        -5.69865778e-02, -5.11027575e-02,  1.46423757e-01,
        -1.58534750e-01,  3.62929910e-01, -7.23021924e-02,
         1.98327243e-01],
       [ 1.19980998e-01, -2.52491891e-01,  1.95529804e-01,
        -7.65774697e-02,  2.52251357e-01,  1.13275275e-01,
         5.81757277e-02,  3.78789663e-01,  4.37599421e-02,
         3.16330940e-02],
       [ 1.96890891e-01, -1.64882153e-01,  1.81369573e-01,
         6.73182011e-02,  1.72123373e-01,  2.44809717e-01,
         1.38047934e-02,  3.21171880e-01, -1.63566023e-01,
         5.79664856e-03],
       [ 1.02479883e-01, -1.89321429e-01,  2.06177354e-01,
        -1.80038869e-01,  1.01825431e-01,  1.80727765e-01,
        -5.14207557e-02,  3.10453296e-01,  1.24920391e-01,
         1.31855384e-01],
       [ 1.57777905e-01, -2.01210588e-01,  7.66303241e-02,
        -1.33306794e-02,  5.02551310e-02, -1.82525814e-03,
        -8.92005116e-03,  2.13960752e-01,  8.09747502e-02,
        -8.86394531e-02],
       [ 1.73568279e-01, -1.81509018e-01,  6.94923550e-02,
         2.25923881e-02,  1.25156790e-01,  9.78153944e-02,
         7.18498528e-02,  1.72782511e-01, -1.73949450e-03,
        -3.14355493e-02],
       [ 2.12782592e-01, -1.08876266e-01,  2.08515041e-02,
        -1.01711378e-02,  1.19640127e-01, -9.69958492e-05,
        -6.28693178e-02,  3.33201408e-01, -1.50406286e-01,
        -1.02956995e-01],
       [ 1.26922369e-01, -1.56035721e-01,  2.98747718e-01,
         9.26929712e-02,  8.95737335e-02, -1.29059374e-01,
        -1.43345833e-01,  4.68937457e-02, -4.68213186e-02,
         5.15396222e-02],
       [ 3.52239460e-02, -1.91484004e-01,  2.20030546e-01,
         3.59544829e-02,  1.36512592e-01,  1.21223092e-01,
        -1.34089381e-01,  1.43567622e-01, -3.34661454e-02,
        -5.57777807e-02],
       [ 9.43918601e-02, -1.68835118e-01,  1.62019014e-01,
        -4.19624634e-02,  2.15631202e-01,  1.55934438e-01,
         7.35237896e-02,  5.96011400e-01, -1.44222230e-02,
         1.49046838e-01],
       [ 1.18885040e-01, -1.79240763e-01,  5.50882965e-02,
        -2.13452429e-03,  8.48153755e-02,  3.77807766e-02,
        -1.99006870e-02,  1.78332657e-01,  3.67514491e-02,
        -9.89044830e-03],
       [ 2.14178562e-01, -3.29575062e-01,  1.76141411e-01,
         1.57899737e-01,  4.17742953e-02,  9.08357576e-02,
        -6.54141530e-02,  1.98811680e-01,  4.14663889e-02,
        -7.10944682e-02],
       [ 1.57736331e-01, -1.77458584e-01,  2.13310421e-02,
        -1.05644345e-01,  2.59520113e-02, -1.12300254e-02,
        -1.00170426e-01,  2.14069009e-01,  7.21996576e-02,
        -4.23572361e-02],
       [ 1.21565871e-01, -2.72702247e-01,  2.60294855e-01,
        -5.32292761e-02,  1.61093295e-01,  6.04610592e-02,
         1.44727528e-03,  2.32661918e-01,  1.43471181e-01,
         2.82512531e-02],
       [ 1.60653844e-01, -1.99236870e-01,  1.87010586e-01,
         1.44635066e-02,  1.60880506e-01,  1.58263028e-01,
        -1.78262889e-02,  1.65476620e-01,  4.06849198e-03,
        -5.98017126e-04],
       [ 2.26345137e-01, -1.26628011e-01,  2.68359870e-01,
        -3.30391079e-02,  2.25388691e-01,  2.37583414e-01,
        -6.02742583e-02,  3.29941541e-01, -9.90881920e-02,
        -1.04484018e-02],
       [ 5.37611023e-02, -2.79302955e-01,  2.97504812e-01,
         3.89320739e-02,  1.83170334e-01,  8.92940313e-02,
         8.59095156e-03,  4.23207164e-01, -1.42400682e-01,
         5.53032309e-02],
       [ 1.66490138e-01, -1.48050025e-01,  6.63823634e-02,
         6.58773631e-02,  6.00216947e-02,  8.30844566e-02,
        -1.96889043e-02,  1.42090917e-01, -1.13095082e-01,
        -2.02040598e-02],
       [ 1.97650671e-01, -2.30621904e-01,  2.36031875e-01,
         2.80922949e-02,  1.33660197e-01,  1.84945911e-02,
        -9.49448869e-02,  2.76657969e-01, -1.28546327e-01,
        -4.32698876e-02],
       [ 7.05717355e-02, -2.37958223e-01,  1.16201587e-01,
        -1.77804694e-01, -4.79214638e-02,  5.41470461e-02,
        -8.95375609e-02,  2.57950187e-01,  1.37926370e-01,
         2.06745639e-02],
       [ 1.77722394e-01,  2.01802328e-02,  4.79169115e-02,
         6.72927964e-03,  4.02879566e-02, -1.48054510e-02,
        -3.88961360e-02,  4.56949055e-01, -1.55007973e-01,
         1.00450680e-01],
       [ 9.65217501e-02, -1.62284642e-01,  1.29988074e-01,
         1.91643238e-02,  2.32364126e-02,  1.56593755e-01,
        -2.18545049e-02,  3.09364825e-01,  3.18836495e-02,
         1.42394826e-01],
       [ 1.02766201e-01, -1.55567259e-01,  4.84935120e-02,
        -4.07454185e-02,  5.89885861e-02,  1.24289118e-01,
        -6.47858977e-02,  2.67237425e-01, -7.60506094e-03,
         7.41993338e-02],
       [-1.35931037e-02, -5.91658354e-02,  1.98065460e-01,
        -4.42577340e-02,  2.52601765e-02,  1.36235446e-01,
         5.66409379e-02,  2.78763175e-01,  5.51992878e-02,
         1.38397262e-01],
       [ 2.50516385e-01, -2.32722104e-01,  1.76233470e-01,
        -8.09478611e-02,  2.46379495e-01,  2.21691221e-01,
        -9.12290215e-02,  3.98776174e-01, -2.20650896e-01,
         4.45762612e-02],
       [ 1.62855685e-01, -1.87853783e-01,  1.31980121e-01,
         8.18412453e-02,  8.28907192e-02,  9.36585367e-02,
        -8.11574757e-02,  3.49823534e-01, -6.82802200e-02,
        -6.30170703e-02],
       [ 1.03175372e-01, -2.19012693e-01,  2.57625461e-01,
         6.46688566e-02,  1.53727233e-01,  9.04035568e-02,
         1.08267888e-02,  2.79330760e-01,  1.36587638e-02,
        -1.14830226e-01],
       [ 1.29259855e-01, -1.06586844e-01,  1.17274150e-01,
        -9.82576609e-02,  6.38086647e-02,  6.26218542e-02,
        -2.43740529e-02,  2.77324855e-01,  1.53417945e-01,
        -3.33826877e-02],
       [ 1.77898556e-01, -2.16685712e-01,  1.16354965e-01,
        -1.50649816e-01, -9.64425504e-04, -5.53831831e-03,
        -8.66907164e-02,  1.37190893e-01,  1.07019186e-01,
        -4.00551036e-03],
       [ 6.08430281e-02, -1.38696223e-01,  2.33810246e-01,
        -1.65045038e-02,  1.38037711e-01,  2.39608437e-01,
        -8.14822465e-02,  3.14134359e-01,  2.38902476e-02,
        -3.01111899e-02],
       [ 1.96939170e-01, -2.01352507e-01,  1.55616254e-01,
        -5.67203388e-02,  1.23455018e-01,  9.50125605e-02,
         2.18000263e-02,  2.17068255e-01,  2.51469687e-02,
        -8.46648142e-02],
       [ 1.36755779e-01, -3.81088495e-01,  2.39190102e-01,
        -5.69812208e-03,  1.36373177e-01, -3.45885605e-02,
        -1.08770639e-01,  3.76067132e-01, -1.42801881e-01,
         2.76676923e-01],
       [ 3.44190627e-01, -2.70785093e-01,  1.49080247e-01,
        -4.75045890e-02,  1.41333640e-01,  2.05828235e-01,
        -6.37085736e-03,  3.20268542e-01, -1.26899615e-01,
         1.27643533e-02],
       [ 1.16769515e-01, -1.47749200e-01,  8.63129124e-02,
        -1.42030150e-01,  1.20826051e-01,  1.00516334e-01,
        -7.79214650e-02,  2.25615129e-01,  6.21651560e-02,
        -2.68865749e-02],
       [ 1.37768567e-01, -1.84770569e-01,  2.96809196e-01,
        -1.26516819e-02,  1.80388749e-01, -1.23625688e-01,
        -2.03429088e-02,  1.90774500e-01, -2.17798315e-02,
         9.11172330e-02],
       [ 2.10927024e-01, -1.06398672e-01,  1.27080232e-01,
        -3.35444324e-02,  1.22689918e-01,  8.73885378e-02,
        -4.87845764e-02,  2.51738250e-01, -6.59314170e-02,
         1.95445642e-02],
       [ 1.66006550e-01, -2.38883421e-01,  1.83094591e-01,
        -2.37043053e-02,  1.31914422e-01,  1.39186114e-01,
         3.24706510e-02,  1.58816218e-01, -5.51936850e-02,
        -2.49109045e-02],
       [ 3.82875204e-02, -2.69983470e-01,  1.19026020e-01,
        -3.35979909e-02,  3.90749462e-02,  4.94342819e-02,
         1.83725357e-03,  2.90575564e-01,  1.52164757e-01,
         3.98447663e-02],
       [ 1.50449306e-01, -3.91827106e-01,  1.79179877e-01,
         3.60627249e-02,  3.08906138e-02,  1.09589309e-01,
        -3.67296413e-02,  2.96367109e-01,  1.03148624e-01,
         1.24703869e-02],
       [ 3.67352158e-01, -3.57859850e-01,  2.03363240e-01,
         6.01007305e-02,  1.52023941e-01,  7.79192895e-02,
        -3.06828022e-02,  1.23643860e-01, -7.34634921e-02,
        -1.44566476e-01],
       [ 3.57449591e-01, -2.85736620e-01,  1.44515842e-01,
        -1.21986501e-01,  1.24785170e-01,  4.77472097e-02,
        -2.27646530e-03,  2.92566240e-01, -1.67795032e-01,
         1.31449588e-02],
       [ 1.31427974e-01, -2.56287813e-01,  2.07864910e-01,
        -1.04528628e-02,  1.82650745e-01,  1.84715092e-01,
         2.42921412e-02,  3.46312106e-01, -5.82515635e-02,
        -1.41756684e-02],
       [ 8.51723850e-02, -1.69619843e-01,  1.22124314e-01,
        -4.48531061e-02,  2.77296305e-02,  2.09352121e-01,
        -6.71434253e-02,  3.89958143e-01,  1.05427369e-01,
         5.55033013e-02],
       [ 8.86772648e-02, -2.53054589e-01,  3.05435508e-01,
         8.11745003e-02,  1.35163411e-01,  2.08062962e-01,
        -1.95492104e-01,  1.00858353e-01,  1.84168704e-02,
        -3.98231335e-02],
       [ 2.64127553e-01, -2.30212435e-01,  1.67602122e-01,
         1.05311945e-02,  1.37493700e-01,  2.51933187e-02,
        -2.54113972e-02,  1.19677894e-01, -1.49544716e-01,
        -1.41805455e-01],
       [ 1.17797561e-01, -3.48414183e-01,  2.05935702e-01,
        -2.22041830e-03,  4.93199229e-02,  2.50954702e-02,
        -8.61552358e-02,  3.57540011e-01, -4.01390344e-02,
         7.78207853e-02],
       [ 7.93778002e-02, -1.88805759e-01,  2.48739868e-02,
        -1.04736775e-01, -4.65201214e-02,  4.11739945e-03,
        -4.44301814e-02,  2.69331247e-01, -3.60771306e-02,
         1.67976707e-01],
       [ 2.36172438e-01, -1.68412283e-01,  2.24516869e-01,
         4.47485968e-02,  1.71495676e-01,  2.23281235e-01,
        -3.75738144e-02,  1.78356975e-01, -1.33141026e-01,
        -4.87820171e-02]], dtype=float32)>}

You can also load and do inference in a distributed manner:

 another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
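As the warning above suggests, you can avoid the eager overhead by wrapping the per-replica call in a tf.function. A minimal sketch, reusing another_strategy, inference_func, and dist_predict_dataset from the cell above:

@tf.function
def distributed_inference(batch):
  # Run the restored serving function on each replica.
  return another_strategy.run(inference_func, args=(batch,))

for batch in dist_predict_dataset:
  distributed_inference(batch)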

Calling the restored function is just a forward pass on the saved model (predict). What if you want to continue training the loaded function? Or embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

 import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 3s 3ms/step - loss: 0.1986 - accuracy: 0.9420
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0655 - accuracy: 0.9797

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
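For example, if you only want to reuse the loaded model as a frozen feature extractor and train a new head on top, you could pass trainable=False to hub.KerasLayer. A minimal sketch with a hypothetical build_frozen_model helper:

def build_frozen_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Keep the restored weights fixed; only the new Dense head below is trained.
  features = hub.KerasLayer(loaded, trainable=False)(x)
  outputs = tf.keras.layers.Dense(10)(features)
  return tf.keras.Model(x, outputs)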

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load it as a non-Keras model with the low-level API, tf.saved_model.load.

 model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
 
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading, otherwise you will get an error.

To check whether your model has well-defined inputs, just check if model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, .predict, or when the model is called directly (model(inputs)).
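A quick way to see this check in action is a Sequential model created without an input shape. A minimal sketch, assuming model.inputs behaves as described above:

seq = tf.keras.Sequential([tf.keras.layers.Dense(3)])
print(seq.inputs)      # None -- input shape not defined yet
seq(tf.ones((1, 4)))   # calling the model defines the input shape
print(seq.inputs)      # now populated, so the Keras saving API would work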

Here is an example:

 class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
 
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f70b00ef898>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f70b00a7470>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets