Overview
It is common to save and load a model during and after training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial shows how you can use the SavedModel APIs when using `tf.distribute.Strategy`. To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example:
Import the required packages:
!pip install -q tf-nightly
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Prepare the data and a model using `tf.distribute.Strategy`:
mirrored_strategy = tf.distribute.MirroredStrategy()
def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset
def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Train the model:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 14s 11ms/step - loss: 0.3931 - accuracy: 0.8903
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0708 - accuracy: 0.9794
<tensorflow.python.keras.callbacks.History at 0x7f41006be9b0>
Save and load the model
Now that you have a model to work with, let's take a look at the APIs for saving and loading it. There are two kinds of APIs available:
- High-level Keras `model.save` and `tf.keras.models.load_model`
- Low-level `tf.saved_model.save` and `tf.saved_model.load`
The Keras APIs
Here is an example of saving and loading a model with the Keras APIs:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called outside the strategy scope
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2289: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:1377: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
Restore the model without `tf.distribute.Strategy`:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0469 - accuracy: 0.9858
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0325 - accuracy: 0.9899
<tensorflow.python.keras.callbacks.History at 0x7f40dc32f080>
After restoring the model, you can continue training on it without needing to call `compile()` again, since it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, please refer to the guide to SavedModel format.
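As a quick check (a sketch, not part of the original notebook output), you can list what `model.save()` wrote to disk; a SavedModel export is expected to contain a `saved_model.pb` graph file plus `variables/` and `assets/` subdirectories:
import os

# Sketch: inspect the SavedModel directory produced by model.save() above.
# It is expected to contain saved_model.pb, variables/ (checkpointed weights) and assets/.
print(sorted(os.listdir(keras_model_path)))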
It is important to call `model.save()` outside the scope of `tf.distribute.Strategy`. Calling it within the scope is not supported.
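To make the scoping rule concrete, here is a minimal sketch (the name `model_to_save` is introduced only for this illustration):
# Sketch: build and compile inside the strategy scope, but call save() outside of it.
model_to_save = get_model()           # get_model() builds and compiles inside mirrored_strategy.scope()
model_to_save.save(keras_model_path)  # supported: save() is called outside the scope

# Not supported: calling save() inside the scope, e.g.
# with mirrored_strategy.scope():
#   model_to_save.save(keras_model_path)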
Now, load the model and train it using `tf.distribute.Strategy`:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.9850
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0323 - accuracy: 0.9901
As you can see, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy that was used before saving.
The `tf.saved_model` APIs
Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:
model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
You can load it back with `tf.saved_model.load()`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
The loaded object may contain multiple functions, each associated with a key. `"serving_default"` is the default key for the inference function of a saved Keras model. To do inference with this function:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.31710002e-01,  4.69997153e-02, -5.89058809e-02, -7.89453760e-02,
         5.00574857e-02,  5.67903817e-02, -1.23740345e-01, -1.68794781e-01,
        -9.27029625e-02, -2.51288135e-02],
       ...,
       [-1.02467999e-01,  1.65629908e-01,  2.73847040e-02, -1.95608556e-01,
         4.34939824e-02, -5.75105920e-02,  7.93466158e-03, -2.23078132e-01,
        -1.31911159e-01,  1.82216957e-01]], dtype=float32)>}
You can also load and do inference in a distributed manner:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
Calling the restored function is just a forward pass on the saved model (i.e., predict). What if you want to continue training the loaded function? Or embed the loaded function into a bigger model? A common pattern is to wrap the loaded object in a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:
import tensorflow_hub as hub
def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded in a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Epoch 1/2
938/938 [==============================] - 11s 9ms/step - loss: 0.4051 - accuracy: 0.8846
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0790 - accuracy: 0.9771
As you can see, `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load()` into a Keras layer that can be used to build another model. This is very useful for transfer learning.
Which API should you use?
For saving, if you are working with a Keras model, it is recommended to use the Keras `model.save()` API. If the model you are saving is not a Keras model, then the lower-level API is your only choice.
For loading, which API you use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model back, use `tf.saved_model.load()`. Otherwise, use `tf.keras.models.load_model()`. Note that you can only get a Keras model back if you saved a Keras model.
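To see the difference concretely, here is a small sketch (reusing `keras_model_path` from above; the variable names are introduced only for this illustration). The Keras loader returns a full `tf.keras.Model`, while the low-level loader returns a plain trackable object that only exposes the saved functions and variables.
# Sketch: compare what the two loading APIs return for the same SavedModel.
keras_obj = tf.keras.models.load_model(keras_model_path)
low_level_obj = tf.saved_model.load(keras_model_path)

print(isinstance(keras_obj, tf.keras.Model))      # True: a Keras model you can fit()/evaluate()
print(isinstance(low_level_obj, tf.keras.Model))  # False: a generic trackable object with signatures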
It is possible to mix and match the APIs. You can save a Keras model with `model.save`, and load a non-Keras model with the low-level API, `tf.saved_model.load`.
model = get_model()

# Save the model using Keras's save() API
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Load the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Caveats
A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (`Sequential([Dense(3), ...])`). Subclassed models also do not have well-defined inputs after instantiation. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.
To check whether your model has well-defined inputs, just check if `model.inputs` is `None`. If it is not `None`, you are all good. Input shapes are defined automatically when the model is used in `.fit`, `.evaluate`, `.predict`, or when the model is called (`model(inputs)`).
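For instance, the Sequential model built by `get_model()` above specifies `input_shape` on its first layer, so its inputs are already well defined. A small check (not part of the original notebook; the name `checked_model` is just for illustration):
# Sketch: a model built with an explicit input_shape has well-defined inputs.
checked_model = get_model()
print(checked_model.inputs)  # not None, so the high-level Keras save/load APIs are safe to use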
Here is an example:
class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)
my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f4092ee5160>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f4092ee5550>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
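The Keras `save()` call above is commented out because `my_model` has no defined inputs yet. As a hedged sketch of the behaviour described earlier (the input size of 10 is an arbitrary choice for illustration): once the model has been called on real data, its input shapes are defined and the high-level Keras API should work as well.
# Sketch: calling the model once defines its input shapes and builds the layers,
# after which the high-level Keras save API should also succeed.
my_model(tf.ones((1, 10), dtype=tf.float32))  # builds the model with a concrete input shape
my_model.save(keras_model_path)               # expected to work now that the model is built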