Saving and loading a model with a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, we recommend reading the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import the required packages:

!pip install -q tf-nightly
import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the model and data using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:606: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

938/938 [==============================] - 9s 10ms/step - loss: 0.4029 - accuracy: 0.8828
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0726 - accuracy: 0.9794

<tensorflow.python.keras.callbacks.History at 0x7f5030597860>

Saving and loading the model

Now that you have a model to work with, let's look at how to save and load it. Two sets of APIs are available:

The Keras API

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() must be called outside the strategy scope
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0500 - accuracy: 0.0988
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0350 - accuracy: 0.0989

<tensorflow.python.keras.callbacks.History at 0x7f50897717f0>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to SavedModel format.

It is important to call model.save() outside the scope of tf.distribute.Strategy. Calling it inside the scope is not supported.
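As a minimal sketch of the correct placement (a toy model and a hypothetical path, not the tutorial's MNIST model):

```python
import tensorflow as tf

# Build and compile a toy model inside the strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  toy = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  toy.compile(loss='mse', optimizer='adam')

# Correct: save() is called OUTSIDE the strategy scope.
toy.save('/tmp/toy_outside_scope.keras')

# Not supported: calling toy.save(...) inside `with strategy.scope():`.
```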

Now, load the model and train it using tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0503 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0351 - accuracy: 0.0988

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the one used before saving.

The tf.saved_model APIs

Now let's look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-3.05308700e-02,  9.42159221e-02,  1.39148593e-01,
         2.59458065e-01,  9.99578238e-02,  1.13621376e-01,
         1.22837469e-01, -9.05251950e-02, -5.75954393e-02,
        -1.37231629e-02],
       ...,
       [-1.40070543e-01,  4.70167361e-02,  1.99136883e-03,
         1.45174414e-01, -1.94830671e-02,  8.08222890e-02,
         6.18007667e-02,  5.51381111e-02,  9.80618689e-03,
        -5.74187301e-02]], dtype=float32)>}
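The signature keys mentioned above can be inspected directly on any loaded object. Here is a minimal, self-contained sketch (using a plain tf.Module with an explicitly exported signature and a hypothetical path, rather than the tutorial's Keras model):

```python
import tensorflow as tf

class Scaler(tf.Module):
  """Tiny module that multiplies its input by a variable."""

  def __init__(self):
    super().__init__()
    self.w = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def __call__(self, x):
    return self.w * x

module = Scaler()
# Export with an explicit "serving_default" signature.
tf.saved_model.save(module, '/tmp/sig_demo',
                    signatures={'serving_default': module.__call__})

loaded = tf.saved_model.load('/tmp/sig_demo')
print(list(loaded.signatures.keys()))  # ['serving_default']

# Signatures take and return structures of tensors.
inference_func = loaded.signatures['serving_default']
print(inference_func(tf.constant([1.0, 3.0])))
```

A Keras model saved with well-defined inputs gets a "serving_default" signature automatically; the explicit signatures argument here just makes the exported key visible in a tiny example.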

You can also load and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Call the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object into a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what is loaded in a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
938/938 [==============================] - 5s 5ms/step - loss: 0.4261 - accuracy: 0.8801
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0762 - accuracy: 0.9783

As you can see, hub.KerasLayer wraps the result loaded from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving: if you are working with a Keras model, it is recommended to use Keras's model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading: which API to use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Save the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Load the model using the low-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case arises when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after instantiation. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all set. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called directly (model(inputs)).
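The check above can be sketched as follows (a toy Sequential model; the exact build behavior is version-dependent, so treat this as illustrative):

```python
import tensorflow as tf

# A Sequential model created without an input shape has no defined inputs.
model_no_shape = tf.keras.Sequential([tf.keras.layers.Dense(2)])
try:
  inputs_before = model_no_shape.inputs  # expected: None (unbuilt model)
except Exception:
  inputs_before = None  # some versions raise instead of returning None

# Calling the model on data defines the input shape.
model_no_shape(tf.zeros((1, 3)))
inputs_after = model_no_shape.inputs

print(inputs_before, inputs_after)
```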

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f5088eb83c8>, because it is not built.

WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f5088eb8fd0>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets