![]() | ![]() | ![]() | ![]() |
概要概要
トレーニング中にモデルを保存してロードするのが一般的です。 kerasモデルを保存およびロードするためのAPIには、高レベルAPIと低レベルAPIの2つのセットがあります。このチュートリアルでは、使用しているときにSavedModel APIを使用する方法を示しtf.distribute.Strategy
。 SavedModelとシリアル化全般については、保存済みモデルガイドとKerasモデルシリアル化ガイドをお読みください。簡単な例から始めましょう:
依存関係のインポート:
import tensorflow_datasets as tfds
import tensorflow as tf
tf.distribute.Strategy
を使用してデータとモデルを準備します。
mirrored_strategy = tf.distribute.MirroredStrategy()
def get_data():
datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
return train_dataset, eval_dataset
def get_model():
with mirrored_strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
モデルをトレーニングします。
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 938/938 [==============================] - 4s 4ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9386 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 2/2 938/938 [==============================] - 2s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9787 <tensorflow.python.keras.callbacks.History at 0x7f7470042b38>
モデルを保存してロードします
使用する簡単なモデルができたので、APIの保存/読み込みを見てみましょう。利用可能なAPIのセットは2つあります。
- 高レベルの
model.save
およびtf.keras.models.load_model
- 低レベルの
tf.saved_model.save
およびtf.saved_model.load
Keras API
KerasAPIを使用してモデルを保存およびロードする例を次に示します。
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets
tf.distribute.Strategy
なしでモデルを復元しtf.distribute.Strategy
:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 2s 2ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9838 Epoch 2/2 938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9884 <tensorflow.python.keras.callbacks.History at 0x7f74d333f780>
モデルを復元した後は、保存する前にすでにコンパイルされているため、 compile()
再度呼び出す必要がなくても、モデルのトレーニングを続行できます。モデルは、TensorFlowの標準のSavedModel
プロト形式で保存されます。詳細については、 saved_model
形式のガイドを参照してください。
次に、モデルをロードし、 tf.distribute.Strategy
を使用してtf.distribute.Strategy
ます。
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 9s 10ms/step - loss: 0.0530 - sparse_categorical_accuracy: 0.9844 Epoch 2/2 938/938 [==============================] - 9s 9ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9882
ご覧のとおり、ロードはtf.distribute.Strategy
期待どおりにtf.distribute.Strategy
ます。ここで使用される戦略は、保存する前に使用される戦略と同じである必要はありません。
tf.saved_model
API
それでは、低レベルのAPIを見てみましょう。モデルの保存は、kerasAPIに似ています。
model = get_model() # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
ロードはtf.saved_model.load()
でtf.saved_model.load()
できます。ただし、これは下位レベルのAPIであるため(したがって、使用例の範囲が広くなります)、Kerasモデルは返されません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。例えば:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
ロードされたオブジェクトには、それぞれがキーに関連付けられた複数の関数が含まれる場合があります。 "serving_default"
は、保存されたKerasモデルを使用した推論関数のデフォルトキーです。この関数で推論を行うには:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-2.46878400e-01, -2.84028575e-02, 4.34195548e-02, 8.65758881e-02, -5.50181568e-02, -2.26117969e-02, -8.18806365e-02, 1.60868585e-01, 7.05277026e-02, -2.11526364e-01], [-2.04405725e-01, -2.38965377e-02, 1.06097549e-01, 1.15776211e-02, -5.68305999e-02, 7.61558264e-02, -2.36685127e-02, 6.12710230e-02, 6.85455352e-02, -2.04084530e-01], [-1.70060426e-01, 6.82905912e-02, -2.54967008e-02, 1.27377272e-01, -4.24135383e-03, -1.15118716e-02, 1.65115029e-01, 1.64797649e-01, 8.41001868e-02, -2.60865986e-01], [-1.24608956e-01, 7.05861971e-02, 4.76837084e-02, 9.51382518e-02, -1.36017501e-02, 9.53883678e-02, -2.60323286e-04, 1.26946449e-01, -9.98851806e-02, 6.01550192e-02], [-8.42214674e-02, -4.93131615e-02, -5.85474074e-04, -3.79234888e-02, -6.78482801e-02, 9.56373289e-02, 4.69041206e-02, 8.55031833e-02, 9.31831449e-02, -1.40825540e-01], [-1.46941900e-01, 1.22972876e-02, 5.79140112e-02, -7.50405565e-02, 6.13511279e-02, 1.14746153e-01, 3.54535617e-02, 2.55915433e-01, 7.26796240e-02, -1.99857190e-01], [-2.07879156e-01, 1.83034241e-02, 1.57775074e-01, 6.06807172e-02, -1.75382420e-02, 1.33817732e-01, 1.36331618e-01, 2.02472329e-01, 3.72610986e-02, -1.31865010e-01], [-9.93705392e-02, 6.03869818e-02, -4.28698361e-02, 6.31842762e-04, 8.84034038e-02, 6.72685653e-02, -2.09506359e-02, 1.97081745e-01, 7.39021823e-02, -1.64300233e-01], [-9.71228778e-02, 5.48233166e-02, 1.38393641e-02, -7.14895800e-02, -3.87909710e-02, 8.45830888e-04, -3.62640694e-02, 1.64835989e-01, 5.04231751e-02, -2.07461655e-01], [-2.92240772e-02, 1.45425312e-02, 5.74428178e-02, -1.34241190e-02, -1.80013701e-02, 7.78546855e-02, -8.48746449e-02, 9.98296142e-02, 6.38790280e-02, -5.32845445e-02], [-1.76605240e-01, -1.42511949e-01, 1.39559209e-01, -2.00123414e-02, -6.44349307e-02, -4.56911251e-02, 2.01093405e-03, 1.59898788e-01, 1.95391588e-02, -1.61375850e-01], [-1.58091724e-01, 6.25609234e-03, 2.12391287e-01, -1.39106885e-01, -4.78955358e-02, 7.36434534e-02, 7.29984716e-02, 2.28351891e-01, 1.23042218e-01, -2.22285807e-01], [-6.63312748e-02, -5.25613949e-02, 3.88407931e-02, 4.74876724e-02, -3.56937200e-02, 1.11578718e-01, -8.47167745e-02, 1.54049486e-01, 8.42248723e-02, -9.11155120e-02], [-1.49975002e-01, -1.69416200e-02, 2.03275681e-03, 3.08024809e-02, -1.28081590e-02, 1.18468963e-01, -7.31947795e-02, 2.10938901e-01, 5.79604283e-02, -1.06384277e-01], [-2.44300172e-01, 6.77020177e-02, 1.61827058e-02, 9.77846682e-02, -2.14450657e-02, 8.76296014e-02, 1.55660659e-02, 2.56645411e-01, -6.94077387e-02, 1.82542913e-02], [-3.24441910e-01, 2.83106230e-02, 1.15296148e-01, -6.49778843e-02, -3.93164232e-02, 2.09751099e-01, 1.58456087e-01, 2.03075439e-01, 1.45919517e-01, -8.07187557e-02], [-1.77742794e-01, -3.47406045e-02, 6.37909994e-02, 5.72632812e-02, -1.67798519e-01, -9.77907851e-02, -6.33480251e-02, 5.98776974e-02, -1.48319647e-01, -3.26665044e-02], [-1.92516297e-02, -4.32192907e-02, 9.45950896e-02, -1.24730960e-01, 3.15439701e-03, 7.49434829e-02, 1.42610222e-01, 1.64739519e-01, 1.35794416e-01, -2.33872890e-01], [-9.74408463e-02, -4.51198146e-02, -7.16688111e-02, 1.52820855e-01, 3.08901221e-02, -8.07915181e-02, -8.59454572e-02, 1.73750147e-01, -4.14928459e-02, -1.02175683e-01], [-1.79451153e-01, 7.97335058e-02, 6.08496368e-02, -8.74251127e-05, 1.40254274e-02, 7.78948367e-02, 1.22523680e-02, 1.38402849e-01, -2.44962424e-03, -8.56248587e-02], [-7.16196820e-02, -3.66464853e-02, -1.97902359e-02, -3.42466384e-02, 1.01994909e-02, 8.11903924e-02, 1.02423221e-01, 8.15625191e-02, 9.28392410e-02, -1.61639646e-01], [-1.29672050e-01, -9.39578265e-02, -3.77402268e-02, -5.66408038e-03, 2.01772340e-02, -5.53961843e-04, 1.12603299e-01, 1.18293904e-01, 7.59286210e-02, -1.05032220e-01], [ 3.13648432e-02, 2.04140544e-02, 8.68844241e-02, 8.54840502e-03, -3.24598253e-02, 7.13473856e-02, 1.01958007e-01, 1.58244759e-01, 4.33884151e-02, -1.56489074e-01], [-5.69176152e-02, -8.68148059e-02, 5.83150014e-02, -6.94776773e-02, -1.14257783e-01, 9.14709717e-02, -6.18093796e-02, 4.60445434e-02, 6.21100292e-02, -2.56335258e-01], [-1.00941956e-03, -9.87592638e-02, 1.59144640e-01, 2.46649459e-02, -1.47723123e-01, 3.34706903e-03, -1.25270292e-01, 7.13937655e-02, -3.65925357e-02, -2.86379248e-01], [-2.52649784e-01, -1.80219673e-02, 1.53900415e-01, -7.60671049e-02, -4.30139415e-02, 6.14799336e-02, 5.27559966e-02, 3.91793013e-01, 1.10363506e-01, -2.21582249e-01], [-1.04441456e-02, -5.70102595e-02, -5.45391962e-02, -6.66194037e-02, 3.30452994e-02, 4.31669690e-03, -1.39387622e-02, 1.50821537e-01, 7.82721266e-02, -1.13290384e-01], [-1.50469467e-01, -1.50829509e-01, 1.37116134e-01, -7.71817416e-02, -1.22132301e-01, 8.29393342e-02, 7.44771212e-03, 1.10161960e-01, 5.23409843e-02, -1.67824954e-01], [-1.67705536e-01, -1.61053427e-02, 3.56741399e-02, -8.12948644e-02, -2.15860698e-02, 7.68682212e-02, 3.90296578e-02, 8.14016312e-02, 1.20665669e-01, -5.40915243e-02], [-1.74987361e-01, 5.39990142e-03, 7.59589747e-02, 1.13510445e-01, -3.19063663e-02, -5.98092973e-02, -4.05801088e-02, 2.37588376e-01, -6.73733801e-02, -1.72320567e-02], [-1.80301860e-01, 2.00746767e-02, -7.40496814e-03, 8.36828053e-02, 9.17709470e-02, 1.46025598e-01, -2.91051138e-02, 2.14360297e-01, -3.91696244e-02, -1.15331344e-01], [-7.45102018e-02, 3.96583155e-02, 8.10021013e-02, 1.56707764e-02, -2.35380158e-02, 1.56681970e-01, -1.12800300e-02, 3.64681214e-01, 1.12793013e-01, -9.20613408e-02], [-1.10700965e-01, -3.84411961e-03, 7.15886354e-02, -5.16710430e-03, -2.68637538e-02, -4.64520939e-02, -1.02423206e-01, 1.41418934e-01, 1.36580504e-02, -2.16841191e-01], [-1.03602912e-02, -1.36248600e-02, -8.44807327e-02, -3.93018406e-03, 6.54329583e-02, -1.54229663e-02, -9.10714716e-02, 1.13576502e-02, 6.24551401e-02, -1.10215969e-01], [-1.64637700e-01, -4.25843447e-02, -6.63272589e-02, 1.01544857e-02, 9.00160298e-02, 1.41169682e-01, 9.43019092e-02, 1.50300652e-01, 1.17022656e-01, -2.61101604e-01], [-2.96755701e-01, 1.48339659e-01, 5.29592186e-02, 4.51779664e-02, -6.84008598e-02, 1.29287004e-01, 1.34066977e-02, 1.68794006e-01, -1.53631158e-02, -1.40826374e-01], [-2.27824658e-01, -3.58637236e-02, 7.98013210e-02, -2.93148141e-02, -1.29889801e-01, 1.07304119e-02, 6.16377033e-02, 2.38016129e-01, 1.68460131e-01, -2.78131723e-01], [-1.97686747e-01, -1.20533034e-01, 1.91476271e-02, -2.50333622e-02, -1.20231688e-01, -1.43363982e-01, -5.45644462e-02, 1.13663480e-01, -9.71207619e-02, -7.38224685e-02], [-1.21181801e-01, -9.18156952e-02, 1.72619522e-02, 7.20846877e-02, -5.00237271e-02, -7.88232982e-02, -2.75398232e-02, 9.42765027e-02, -8.18064660e-02, -4.43772227e-02], [-2.12152809e-01, -1.05831539e-02, 1.12541884e-01, 3.79703306e-02, -4.97136004e-02, -8.26531351e-02, 4.28089425e-02, 2.72401571e-01, -9.41082910e-02, -8.25358368e-03], [-2.12490350e-01, 5.10787666e-02, -4.91231680e-03, 1.71558380e-01, 8.33496898e-02, 8.03120583e-02, 5.97136915e-02, 2.78716445e-01, -5.66011816e-02, -7.99765587e-02], [-2.45497763e-01, -5.21367639e-02, 1.77163050e-01, 8.67958441e-02, -1.33168459e-01, 9.83412005e-03, -1.34591311e-01, 1.48744047e-01, -6.65533617e-02, -1.07505932e-01], [-1.36525869e-01, -5.12802340e-02, 2.54329219e-02, 8.01228657e-02, -3.24120894e-02, -6.36913255e-03, -7.75915161e-02, 1.81387305e-01, 6.72850609e-02, -1.06104709e-01], [-8.19087848e-02, -6.67821616e-02, 1.09396182e-01, -8.99944529e-02, -1.08385280e-01, 6.29347712e-02, 7.26154894e-02, 1.68957621e-01, 1.90485001e-01, -2.60798335e-01], [-1.76897705e-01, 4.90825251e-02, 2.94402167e-02, -2.41212249e-02, 3.94896790e-02, 1.18754521e-01, 1.69773921e-02, 1.10196158e-01, 7.08303824e-02, -6.86142594e-02], [-1.29656106e-01, -8.14089552e-02, 1.14682741e-01, -1.32834181e-01, -1.49253279e-01, -2.83164792e-02, 3.45680863e-04, 2.52322882e-01, 2.89388448e-02, -2.79281288e-01], [-1.10502213e-01, 1.07094124e-01, 3.24486196e-02, 7.70951509e-02, -6.27939776e-02, 1.68845624e-01, -1.44310594e-01, 1.45337492e-01, 2.03377791e-02, -5.04231378e-02], [-2.66523331e-01, -7.49082193e-02, 1.91363335e-01, -6.39847219e-02, -1.04055285e-01, 8.31385702e-02, 8.82939398e-02, 1.99207246e-01, 5.35239354e-02, -2.60884434e-01], [-1.35722771e-01, 3.94147262e-02, -6.39424995e-02, 1.39283150e-01, 5.37211001e-02, -6.34303223e-03, -1.70467123e-01, 2.55692095e-01, -7.66103566e-02, -6.90388680e-02], [-1.07885860e-01, 2.30858717e-02, 8.21547359e-02, -3.12240291e-02, -9.89983678e-02, 7.22398609e-02, -4.08478230e-02, 8.69123414e-02, 4.48577479e-02, -6.41947538e-02], [-2.28321850e-02, -3.88411283e-02, 1.47033811e-01, -2.35385150e-01, -9.87000838e-02, 6.44287840e-02, -1.87633559e-02, 1.17905587e-01, 9.70625877e-02, -2.46781930e-01], [-8.77917856e-02, -1.64044406e-02, 7.53755122e-02, -8.24043527e-04, -7.77238905e-02, 1.16269790e-01, -1.00877963e-01, 8.79124254e-02, 3.39440927e-02, -5.94997481e-02], [-1.41677827e-01, -1.40151009e-02, 8.84927809e-04, 1.03166051e-01, -1.66242346e-02, 2.62837298e-02, -1.33589238e-01, 1.65735006e-01, 3.65820900e-02, -1.46895535e-02], [-1.61557034e-01, 5.66626638e-02, -1.61597617e-02, 2.58595943e-02, 3.39905620e-02, 1.01104185e-01, -3.71510983e-02, 1.20341092e-01, 3.26242894e-02, -4.07250933e-02], [-2.17516154e-01, 7.85727724e-02, 9.79433060e-02, 6.97179586e-02, 4.95264679e-02, 1.92503840e-01, -4.96265218e-02, 1.99431688e-01, -5.32730669e-03, -2.50038877e-02], [-1.35356426e-01, -6.96291253e-02, 3.92658785e-02, -9.86322537e-02, -4.20986377e-02, 9.87840891e-02, 9.67663303e-02, 1.76262826e-01, 9.44406465e-02, -2.23472387e-01], [-1.25066608e-01, 7.71146417e-02, 4.02672291e-02, -2.05352344e-02, 3.11498251e-02, 9.64582711e-02, -5.39951548e-02, 2.29750067e-01, 1.61451437e-02, -5.41997403e-02], [-1.93750665e-01, -3.56721133e-03, -1.50568932e-02, 1.78796798e-02, 8.33508372e-03, -1.18013099e-02, -5.35021350e-02, 2.02244624e-01, 3.02494057e-02, -1.20312274e-01], [-2.62067527e-01, 2.36408859e-02, 5.58489896e-02, 1.75756812e-01, -2.75299139e-02, 3.48872915e-02, 5.41301072e-03, 3.15880209e-01, -5.74782193e-02, 7.00992346e-03], [-2.76674211e-01, -2.08131559e-02, -1.26259401e-02, 7.77718723e-02, -1.54706314e-01, 1.31996438e-01, 2.20355690e-02, 5.61908968e-02, 3.73308063e-02, -1.17717944e-01], [-1.59806639e-01, 1.20503023e-01, -4.36934829e-03, 1.16428092e-01, 5.47975339e-02, 1.25162587e-01, 4.78192419e-02, 1.28253624e-01, 7.34245628e-02, -1.80039048e-01], [-2.67963678e-01, 6.00077920e-02, 1.13472804e-01, 7.52071738e-02, -6.40357211e-02, 1.03171021e-01, 1.48901194e-01, 1.97019696e-01, 3.76104042e-02, -1.68720663e-01], [-2.01240778e-01, 2.47026011e-02, 3.10055390e-02, -8.58910009e-03, -8.49897265e-02, -7.54948407e-02, -9.39515531e-02, 1.34306327e-01, -1.71037674e-01, -5.76597378e-02], [-5.20152375e-02, 6.59879148e-02, -3.30656916e-02, 9.97125208e-02, 3.56362388e-02, 1.26982957e-01, -2.69417539e-02, 1.59046397e-01, 1.10872082e-01, -1.84650719e-01]], dtype=float32)>}
分散してロードして推論を行うこともできます。
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
dist_predict_dataset = another_strategy.experimental_distribute_dataset(
predict_dataset)
# Calling the function in a distributed manner
for batch in dist_predict_dataset:
another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
復元された関数の呼び出しは、保存されたモデル(予測)のフォワードパスにすぎません。ロードされた関数のトレーニングを続行したい場合はどうなりますか?または、ロードされた関数をより大きなモデルに埋め込みますか?一般的な方法は、このロードされたオブジェクトをKerasレイヤーにラップしてこれを実現することです。幸いなことに、 TFハブにはこの目的のためのhub.KerasLayerがあります。
import tensorflow_hub as hub
def build_model(loaded):
x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
# Wrap what's loaded to a KerasLayer
keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
model = tf.keras.Model(x, keras_layer)
return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
model = build_model(loaded)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) Epoch 1/2 938/938 [==============================] - 2s 3ms/step - loss: 0.1981 - sparse_categorical_accuracy: 0.9412 Epoch 2/2 938/938 [==============================] - 2s 3ms/step - loss: 0.0655 - sparse_categorical_accuracy: 0.9804
ご覧のとおり、 hub.KerasLayer
は、 hub.KerasLayer
tf.saved_model.load()
からロードされた結果を、別のモデルの構築に使用できるhub.KerasLayer
ラップします。これは、転移学習に非常に役立ちます。
どのAPIを使用する必要がありますか?
保存のために、kerasモデルを使用している場合は、ほとんどの場合、 model.save()
のmodel.save()
APIを使用することをお勧めします。保存しているものがKerasモデルでない場合は、低レベルのAPIが唯一の選択肢です。
ロードの場合、使用するAPIは、ロードAPIから取得するものによって異なります。 tf.saved_model.load()
モデルを取得できない(または取得したくない)場合は、 tf.saved_model.load()
使用します。それ以外の場合は、 tf.keras.models.load_model()
使用します。 Kerasモデルを保存した場合にのみ、Kerasモデルを取り戻すことができることに注意してください。
APIを組み合わせて使用することができます。 model.save
してmodel.save
モデルを保存し、低レベルtf.saved_model.load
使用して非model.save
モデルをロードtf.saved_model.load
ます。
model = get_model()
# Saving the model using Keras's save() API
model.save(keras_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
ローカルデバイスからの保存/読み込み
たとえばクラウドTPUを使用して、リモートで実行しているときにローカルioデバイスから保存およびロードする場合、オプションexperimental_io_device
使用してioデバイスをlocalhostに設定する必要があります。
model = get_model()
# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)
# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
警告
特別な場合は、明確に定義された入力がないKerasモデルがある場合です。たとえば、Sequentialモデルは、入力形状なしで作成できます( Sequential([Dense(3), ...]
)。サブクラス化されたモデルにも、初期化後に明確に定義された入力がありません。この場合、保存と読み込みの両方で低レベルのAPIを使用しないと、エラーが発生します。
モデルに明確に定義された入力があるかどうかを確認するには、 model.inputs
がNone
かどうかを確認します。 None
でない場合は、すべて問題ありませNone
。入力形状は、モデルが.fit
、 .evaluate
、 .predict
で使用されるとき、またはモデルを呼び出すときに自動的に定義されます( model(inputs)
)。
次に例を示します。
class SubclassedModel(tf.keras.Model):
output_name = 'output_layer'
def __init__(self):
super(SubclassedModel, self).__init__()
self._dense_layer = tf.keras.layers.Dense(
5, dtype=tf.dtypes.float32, name=self.output_name)
def call(self, inputs):
return self._dense_layer(inputs)
my_model = SubclassedModel()
# my_model.save(keras_model_path) # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built. INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets