
Save and load a model using a distribution strategy


Overview

Saving and loading a model during training is a common task. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. For general information about SavedModel and serialization, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example.

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
938/938 [==============================] - 10s 4ms/step - loss: 0.4205 - accuracy: 0.8744
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0789 - accuracy: 0.9769
<tensorflow.python.keras.callbacks.History at 0x7fa340042828>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving and loading APIs. There are two sets of APIs available:

Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 2ms/step - loss: 0.0510 - accuracy: 0.9843
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0355 - accuracy: 0.9893
<tensorflow.python.keras.callbacks.History at 0x7fa3c3374fd0>

After restoring the model, you can continue training on it, even without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the saved_model format.
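As a quick sanity check (not part of the original notebook), tf.saved_model.contains_saved_model can confirm that a directory holds a SavedModel. A minimal sketch using a toy model and a temporary directory:

```python
import tempfile

import tensorflow as tf

# Build a tiny stand-in model with well-defined inputs.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
toy_model = tf.keras.Model(inputs, outputs)

# Export it in the SavedModel format (the same on-disk format that the
# Keras save shown above produces) and verify the directory.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(toy_model, export_dir)
print(tf.saved_model.contains_saved_model(export_dir))  # True
```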

Now load the model and train it using tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0503 - accuracy: 0.9845
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0355 - accuracy: 0.9892

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the one used before saving.

tf.saved_model API

Now let's take a look at the lower-level API. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.80148745e-01,  1.33514032e-01, -2.87131630e-02,
         7.16523156e-02,  5.96208051e-02, -1.41107485e-01,
        -1.91312388e-01, -2.37503886e-01, -3.35927419e-02,
        -5.19809276e-02],
       ...
       [-2.96473689e-02,  6.17601499e-02, -5.22884279e-02,
         1.47381663e-01,  1.22329667e-01,  6.64027259e-02,
        -1.27000526e-01,  7.84885287e-02,  3.67109478e-03,
         1.11098588e-03]], dtype=float32)>}
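The signature returns a dictionary of logits keyed by the output layer name ('dense_3' here). To turn those logits into class predictions you can apply tf.argmax; a small sketch (not from the original notebook) using random logits as a stand-in for inference_func(batch)['dense_3']:

```python
import tensorflow as tf

# Stand-in for the logits returned by the serving signature,
# shaped (batch_size, num_classes) just like the output above.
logits = tf.random.normal((64, 10))

# The predicted class for each example is the index of the largest logit.
predicted_classes = tf.argmax(logits, axis=-1)
print(predicted_classes.shape)  # (64,)
```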

You can also load and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
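The warnings above suggest wrapping the per-replica call in a tf.function to avoid the eager dispatch overhead. A minimal sketch of that pattern, with a toy Dense layer standing in for the loaded inference_func:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Toy computation standing in for the loaded serving signature.
  dense = tf.keras.layers.Dense(10)
  dense.build((None, 4))  # create variables up front, in cross-replica context

@tf.function  # tracing once avoids the per-step eager overhead in strategy.run
def distributed_inference(batch):
  return strategy.run(dense, args=(batch,))

dataset = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 4))).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for batch in dist_dataset:
  per_replica_out = strategy.experimental_local_results(
      distributed_inference(batch))
```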

Calling the restored function is just a forward pass on the saved model (predict). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common pattern is to wrap the loaded object into a Keras layer. Luckily, TF Hub provides hub.KerasLayer for exactly this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 6s 3ms/step - loss: 0.4257 - accuracy: 0.8715
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0818 - accuracy: 0.9762

As you can see, hub.KerasLayer wraps the result loaded from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back from it. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(); otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix the APIs. You can save a Keras model with model.save, and load it as a non-Keras model with the low-level tf.saved_model.load API.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving and loading from a local device

When saving to and loading from a local I/O device while running remotely, for example when using a Cloud TPU, you must use the option experimental_io_device to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shape (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise, you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called directly (model(inputs)).
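The check described above can be sketched as follows (a small illustration, not from the original notebook): a Sequential model created without an input shape is not built until it is called on data.

```python
import tensorflow as tf

# No input shape given: the model cannot know its input signature yet.
seq = tf.keras.Sequential([tf.keras.layers.Dense(3)])
print(seq.built)  # False

# Calling the model (as .fit/.evaluate/.predict would) defines the inputs.
_ = seq(tf.zeros((1, 4)))
print(seq.built)  # True
```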

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa3c2b3ea58>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7fa3b1b572e8>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets