Save and load a model using a distribution strategy


Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import required packages:

!pip install -q tf-nightly
import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 14s 11ms/step - loss: 0.3931 - accuracy: 0.8903
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0708 - accuracy: 0.9794

<tensorflow.python.keras.callbacks.History at 0x7f41006be9b0>

Saving and loading the model

Now that you have a model to work with, let's take a look at how to save and load it with the APIs. There are two kinds of APIs available:

The Keras API

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() must be called outside the strategy scope.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2289: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:1377: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0469 - accuracy: 0.9858
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0325 - accuracy: 0.9899

<tensorflow.python.keras.callbacks.History at 0x7f40dc32f080>

After restoring the model, you can continue training it without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to the SavedModel format.
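As a sanity check, you can confirm that the compile() configuration survives the save/load round trip. The sketch below is self-contained (it uses a throwaway model and a hypothetical path, not the MNIST model above):

```python
import tensorflow as tf

# A minimal sketch with a hypothetical path: the compile() configuration
# (optimizer, loss, metrics) is serialized with a Keras model, so the
# restored copy can train immediately without being recompiled.
tiny = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
tiny.compile(optimizer='adam', loss='mse')
tiny.save('/tmp/compile_demo')

restored = tf.keras.models.load_model('/tmp/compile_demo')
print(restored.optimizer is not None)  # the optimizer was restored with the model
restored.fit(tf.ones((4, 2)), tf.zeros((4, 1)), epochs=1, verbose=0)
```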

It is important to call model.save() outside of the tf.distribute.Strategy scope. Calling it within the scope is not supported.

Now load the model and train it using tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.9850
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0323 - accuracy: 0.9901

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the one used before saving.

The tf.saved_model API

Now let's take a look at the lower-level API. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

It can be loaded with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.31710002e-01,  4.69997153e-02, -5.89058809e-02,
        -7.89453760e-02,  5.00574857e-02,  5.67903817e-02,
        -1.23740345e-01, -1.68794781e-01, -9.27029625e-02,
        -2.51288135e-02],
       ...,
       [-1.02467999e-01,  1.65629908e-01,  2.73847040e-02,
        -1.95608556e-01,  4.34939824e-02, -5.75105920e-02,
         7.93466158e-03, -2.23078132e-01, -1.31911159e-01,
         1.82216957e-01]], dtype=float32)>}
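If you are unsure which keys a SavedModel exposes, you can inspect them directly. Here is a minimal, self-contained sketch (using a throwaway model and a hypothetical path, not the MNIST model above):

```python
import tensorflow as tf

# Minimal sketch with a hypothetical path: list the signature keys a
# SavedModel exposes before choosing one for inference.
demo = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
tf.saved_model.save(demo, '/tmp/sig_demo')

sig_loaded = tf.saved_model.load('/tmp/sig_demo')
print(list(sig_loaded.signatures.keys()))  # includes 'serving_default'
```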

You can also load the model and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Call the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (a prediction). What if you want to continue training the loaded function? Or embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap it in a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)

Epoch 1/2
938/938 [==============================] - 11s 9ms/step - loss: 0.4051 - accuracy: 0.8846
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0790 - accuracy: 0.9771

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should you use?

For saving, if you are working with a Keras model, it is recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Save the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Load the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise, you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, the inputs are well-defined. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).
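The check described above can be sketched as follows (a minimal, self-contained example with a throwaway model, assuming the behavior described in this tutorial's TensorFlow version):

```python
import tensorflow as tf

# Sketch: a Sequential model created without an input shape has no
# well-defined inputs until it is called on data.
seq = tf.keras.Sequential([tf.keras.layers.Dense(3)])
before = seq.inputs   # None: inputs not yet defined

seq(tf.ones((1, 4)))  # calling the model defines its input shape
after = seq.inputs    # now defined

print(before, after)
```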

Let's look at an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f4092ee5160>, because it is not built.

WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f4092ee5550>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets