
Save and load a model using a distribution strategy


Overview

It's common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

 mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

 model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
 
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 5s 5ms/step - loss: 0.2249 - accuracy: 0.9338
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0759 - accuracy: 0.9779

<tensorflow.python.keras.callbacks.History at 0x7f715007afd0>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

Keras API

Here is an example of saving and loading a model with the Keras API:

 keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

 restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
 
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0534 - accuracy: 0.9836
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - accuracy: 0.9890

<tensorflow.python.keras.callbacks.History at 0x7f710816ae10>

After restoring the model, you can continue training on it, without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the SavedModel format.
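Because the compile state (loss and optimizer) is serialized with the model, a restored model can be evaluated or trained right away. A minimal self-contained sketch of this, using a tiny synthetic model and a hypothetical path rather than the MNIST model above:

```python
import numpy as np
import tensorflow as tf

# Tiny illustrative model; the path below is hypothetical.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss='mse', optimizer='sgd')

x = np.random.rand(8, 4).astype('float32')
y = np.random.rand(8, 1).astype('float32')
model.fit(x, y, epochs=1, verbose=0)

path = '/tmp/compiled_save_demo.keras'
model.save(path)

# The restored model keeps its loss/optimizer, so no compile() is needed
# before calling evaluate() or fit().
restored = tf.keras.models.load_model(path)
loss = restored.evaluate(x, y, verbose=0)
```

If the model had been saved before compiling, the restored model would instead require a compile() call before training or evaluation.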

Now, load the model and train it using tf.distribute.Strategy:

 another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
 
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0537 - accuracy: 0.9837
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0377 - accuracy: 0.9890

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used at load time does not have to be the same as the one used before saving.

tf.saved_model API

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

 model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
 
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

 DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
 
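To see how such signature keys come about, here is a self-contained toy sketch, independent of the model above, that saves one explicitly named signature on a plain tf.Module (the path and names are illustrative):

```python
import tensorflow as tf

class Adder(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def add_one(self, x):
    # Returning a dict makes 'y' the name of the signature output.
    return {'y': x + 1.0}

module = Adder()
path = '/tmp/signatures_demo'  # hypothetical path
tf.saved_model.save(module, path,
                    signatures={'serving_default': module.add_one})

loaded = tf.saved_model.load(path)
print(list(loaded.signatures.keys()))  # ['serving_default']

# Calling the signature returns a dict keyed by the output name:
# result['y'] holds x + 1, i.e. [2., 3.] here.
result = loaded.signatures['serving_default'](tf.constant([1.0, 2.0]))
```

Passing an explicit signatures dict to tf.saved_model.save is how you control which functions are exported and under which keys.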

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function of a saved Keras model. To do inference with this function:

 predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
 
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 1.76999211e-01, -2.32242137e-01,  1.23949878e-01,
        -5.40311933e-02, -8.57487693e-03, -3.96087095e-02,
        -6.73415065e-02,  3.72458547e-01,  3.14344093e-03,
         8.67353380e-03],
       ...
       [ 2.36172438e-01, -1.68412283e-01,  2.24516869e-01,
         4.47485968e-02,  1.71495676e-01,  2.23281235e-01,
        -3.75738144e-02,  1.78356975e-01, -1.33141026e-01,
        -4.87820171e-02]], dtype=float32)>}

You can also load and do inference in a distributed manner:

 another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (i.e., predict). What if you want to continue training the loaded function, or embed the loaded function into a bigger model? A common practice is to wrap the loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose:

 import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 3s 3ms/step - loss: 0.1986 - accuracy: 0.9420
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0655 - accuracy: 0.9797

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, it is recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API to use depends on what you want to get back from it. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is also possible to mix and match the APIs. You can save a Keras model with model.save, and load the result as a non-Keras model with the low-level tf.saved_model.load API:

 model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
 
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs on both saving and loading; otherwise, you will get an error.

To check if your model has well-defined inputs, just check if model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called directly (model(inputs)).
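That check can be sketched with a tiny standalone model (illustrative only): calling the model once on a batch builds it, after which model.inputs is defined and the Keras save API works.

```python
import tensorflow as tf

# Illustrative model: no input shape is given at construction time,
# so its inputs are not yet well-defined.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])

# Calling the model on a batch defines the input shape...
model(tf.zeros((1, 4)))

# ...so model.inputs is no longer None, and model.save() would succeed.
print(model.inputs)
```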

Here is an example:

 class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
 
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f70b00ef898>, because it is not built.

WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f70b00a7470>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets