Google I / O hoạt động trở lại từ ngày 18 đến 20 tháng 5! Đặt chỗ và xây dựng lịch trình của bạn Đăng ký ngay

Lưu và tải một mô hình bằng cách sử dụng chiến lược phân phối

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tổng quat

Việc lưu và tải một mô hình trong quá trình đào tạo là điều phổ biến. Có hai bộ API để lưu và tải mô hình keras: API cấp cao và API cấp thấp. Hướng dẫn này trình bày cách bạn có thể sử dụng các API SavedModel khi sử dụng tf.distribute.Strategy . Để tìm hiểu về SavedModel và tuần tự hóa nói chung, vui lòng đọc hướng dẫn mô hình đã lưuhướng dẫn tuần tự hóa mô hình Keras . Hãy bắt đầu với một ví dụ đơn giản:

Nhập phụ thuộc:

import tensorflow_datasets as tfds

import tensorflow as tf

Chuẩn bị dữ liệu và mô hình bằng tf.distribute.Strategy :

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Huấn luyện mô hình:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 4s 4ms/step - loss: 0.2095 - sparse_categorical_accuracy: 0.9386
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9787
<tensorflow.python.keras.callbacks.History at 0x7f7470042b38>

Lưu và tải mô hình

Bây giờ bạn đã có một mô hình đơn giản để làm việc, chúng ta hãy xem xét các API lưu / tải. Có hai bộ API có sẵn:

API Keras

Dưới đây là một ví dụ về lưu và tải một mô hình với các API Keras:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Khôi phục mô hình mà không có tf.distribute.Strategy :

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9838
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0381 - sparse_categorical_accuracy: 0.9884
<tensorflow.python.keras.callbacks.History at 0x7f74d333f780>

Sau khi khôi phục mô hình, bạn có thể tiếp tục đào tạo về nó, thậm chí không cần gọi lại compile() , vì nó đã được biên dịch trước khi lưu. Mô hình được lưu ở định dạng proto SavedModel tiêu chuẩn của TensorFlow. Để biết thêm thông tin, vui lòng tham khảo hướng dẫn về định dạng saved_model .

Bây giờ để tải mô hình và đào tạo nó bằng cách sử dụng tf.distribute.Strategy :

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0530 - sparse_categorical_accuracy: 0.9844
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9882

Như bạn có thể thấy, việc tải hoạt động như mong đợi với tf.distribute.Strategy . Chiến lược được sử dụng ở đây không nhất thiết phải giống với chiến lược đã sử dụng trước khi lưu.

API tf.saved_model

Bây giờ chúng ta hãy xem xét các API cấp thấp hơn. Lưu mô hình tương tự như API keras:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Việc tải có thể được thực hiện với tf.saved_model.load() . Tuy nhiên, vì nó là một API ở cấp độ thấp hơn (và do đó có phạm vi sử dụng rộng hơn), nó không trả về mô hình Keras. Thay vào đó, nó trả về một đối tượng có chứa các hàm có thể được sử dụng để suy luận. Ví dụ:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

Đối tượng được tải có thể chứa nhiều chức năng, mỗi chức năng được liên kết với một khóa. "serving_default" là khóa mặc định cho hàm suy luận với mô hình Keras đã lưu. Để thực hiện một suy luận với chức năng này:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.46878400e-01, -2.84028575e-02,  4.34195548e-02,
         8.65758881e-02, -5.50181568e-02, -2.26117969e-02,
        -8.18806365e-02,  1.60868585e-01,  7.05277026e-02,
        -2.11526364e-01],
       [-2.04405725e-01, -2.38965377e-02,  1.06097549e-01,
         1.15776211e-02, -5.68305999e-02,  7.61558264e-02,
        -2.36685127e-02,  6.12710230e-02,  6.85455352e-02,
        -2.04084530e-01],
       [-1.70060426e-01,  6.82905912e-02, -2.54967008e-02,
         1.27377272e-01, -4.24135383e-03, -1.15118716e-02,
         1.65115029e-01,  1.64797649e-01,  8.41001868e-02,
        -2.60865986e-01],
       [-1.24608956e-01,  7.05861971e-02,  4.76837084e-02,
         9.51382518e-02, -1.36017501e-02,  9.53883678e-02,
        -2.60323286e-04,  1.26946449e-01, -9.98851806e-02,
         6.01550192e-02],
       [-8.42214674e-02, -4.93131615e-02, -5.85474074e-04,
        -3.79234888e-02, -6.78482801e-02,  9.56373289e-02,
         4.69041206e-02,  8.55031833e-02,  9.31831449e-02,
        -1.40825540e-01],
       [-1.46941900e-01,  1.22972876e-02,  5.79140112e-02,
        -7.50405565e-02,  6.13511279e-02,  1.14746153e-01,
         3.54535617e-02,  2.55915433e-01,  7.26796240e-02,
        -1.99857190e-01],
       [-2.07879156e-01,  1.83034241e-02,  1.57775074e-01,
         6.06807172e-02, -1.75382420e-02,  1.33817732e-01,
         1.36331618e-01,  2.02472329e-01,  3.72610986e-02,
        -1.31865010e-01],
       [-9.93705392e-02,  6.03869818e-02, -4.28698361e-02,
         6.31842762e-04,  8.84034038e-02,  6.72685653e-02,
        -2.09506359e-02,  1.97081745e-01,  7.39021823e-02,
        -1.64300233e-01],
       [-9.71228778e-02,  5.48233166e-02,  1.38393641e-02,
        -7.14895800e-02, -3.87909710e-02,  8.45830888e-04,
        -3.62640694e-02,  1.64835989e-01,  5.04231751e-02,
        -2.07461655e-01],
       [-2.92240772e-02,  1.45425312e-02,  5.74428178e-02,
        -1.34241190e-02, -1.80013701e-02,  7.78546855e-02,
        -8.48746449e-02,  9.98296142e-02,  6.38790280e-02,
        -5.32845445e-02],
       [-1.76605240e-01, -1.42511949e-01,  1.39559209e-01,
        -2.00123414e-02, -6.44349307e-02, -4.56911251e-02,
         2.01093405e-03,  1.59898788e-01,  1.95391588e-02,
        -1.61375850e-01],
       [-1.58091724e-01,  6.25609234e-03,  2.12391287e-01,
        -1.39106885e-01, -4.78955358e-02,  7.36434534e-02,
         7.29984716e-02,  2.28351891e-01,  1.23042218e-01,
        -2.22285807e-01],
       [-6.63312748e-02, -5.25613949e-02,  3.88407931e-02,
         4.74876724e-02, -3.56937200e-02,  1.11578718e-01,
        -8.47167745e-02,  1.54049486e-01,  8.42248723e-02,
        -9.11155120e-02],
       [-1.49975002e-01, -1.69416200e-02,  2.03275681e-03,
         3.08024809e-02, -1.28081590e-02,  1.18468963e-01,
        -7.31947795e-02,  2.10938901e-01,  5.79604283e-02,
        -1.06384277e-01],
       [-2.44300172e-01,  6.77020177e-02,  1.61827058e-02,
         9.77846682e-02, -2.14450657e-02,  8.76296014e-02,
         1.55660659e-02,  2.56645411e-01, -6.94077387e-02,
         1.82542913e-02],
       [-3.24441910e-01,  2.83106230e-02,  1.15296148e-01,
        -6.49778843e-02, -3.93164232e-02,  2.09751099e-01,
         1.58456087e-01,  2.03075439e-01,  1.45919517e-01,
        -8.07187557e-02],
       [-1.77742794e-01, -3.47406045e-02,  6.37909994e-02,
         5.72632812e-02, -1.67798519e-01, -9.77907851e-02,
        -6.33480251e-02,  5.98776974e-02, -1.48319647e-01,
        -3.26665044e-02],
       [-1.92516297e-02, -4.32192907e-02,  9.45950896e-02,
        -1.24730960e-01,  3.15439701e-03,  7.49434829e-02,
         1.42610222e-01,  1.64739519e-01,  1.35794416e-01,
        -2.33872890e-01],
       [-9.74408463e-02, -4.51198146e-02, -7.16688111e-02,
         1.52820855e-01,  3.08901221e-02, -8.07915181e-02,
        -8.59454572e-02,  1.73750147e-01, -4.14928459e-02,
        -1.02175683e-01],
       [-1.79451153e-01,  7.97335058e-02,  6.08496368e-02,
        -8.74251127e-05,  1.40254274e-02,  7.78948367e-02,
         1.22523680e-02,  1.38402849e-01, -2.44962424e-03,
        -8.56248587e-02],
       [-7.16196820e-02, -3.66464853e-02, -1.97902359e-02,
        -3.42466384e-02,  1.01994909e-02,  8.11903924e-02,
         1.02423221e-01,  8.15625191e-02,  9.28392410e-02,
        -1.61639646e-01],
       [-1.29672050e-01, -9.39578265e-02, -3.77402268e-02,
        -5.66408038e-03,  2.01772340e-02, -5.53961843e-04,
         1.12603299e-01,  1.18293904e-01,  7.59286210e-02,
        -1.05032220e-01],
       [ 3.13648432e-02,  2.04140544e-02,  8.68844241e-02,
         8.54840502e-03, -3.24598253e-02,  7.13473856e-02,
         1.01958007e-01,  1.58244759e-01,  4.33884151e-02,
        -1.56489074e-01],
       [-5.69176152e-02, -8.68148059e-02,  5.83150014e-02,
        -6.94776773e-02, -1.14257783e-01,  9.14709717e-02,
        -6.18093796e-02,  4.60445434e-02,  6.21100292e-02,
        -2.56335258e-01],
       [-1.00941956e-03, -9.87592638e-02,  1.59144640e-01,
         2.46649459e-02, -1.47723123e-01,  3.34706903e-03,
        -1.25270292e-01,  7.13937655e-02, -3.65925357e-02,
        -2.86379248e-01],
       [-2.52649784e-01, -1.80219673e-02,  1.53900415e-01,
        -7.60671049e-02, -4.30139415e-02,  6.14799336e-02,
         5.27559966e-02,  3.91793013e-01,  1.10363506e-01,
        -2.21582249e-01],
       [-1.04441456e-02, -5.70102595e-02, -5.45391962e-02,
        -6.66194037e-02,  3.30452994e-02,  4.31669690e-03,
        -1.39387622e-02,  1.50821537e-01,  7.82721266e-02,
        -1.13290384e-01],
       [-1.50469467e-01, -1.50829509e-01,  1.37116134e-01,
        -7.71817416e-02, -1.22132301e-01,  8.29393342e-02,
         7.44771212e-03,  1.10161960e-01,  5.23409843e-02,
        -1.67824954e-01],
       [-1.67705536e-01, -1.61053427e-02,  3.56741399e-02,
        -8.12948644e-02, -2.15860698e-02,  7.68682212e-02,
         3.90296578e-02,  8.14016312e-02,  1.20665669e-01,
        -5.40915243e-02],
       [-1.74987361e-01,  5.39990142e-03,  7.59589747e-02,
         1.13510445e-01, -3.19063663e-02, -5.98092973e-02,
        -4.05801088e-02,  2.37588376e-01, -6.73733801e-02,
        -1.72320567e-02],
       [-1.80301860e-01,  2.00746767e-02, -7.40496814e-03,
         8.36828053e-02,  9.17709470e-02,  1.46025598e-01,
        -2.91051138e-02,  2.14360297e-01, -3.91696244e-02,
        -1.15331344e-01],
       [-7.45102018e-02,  3.96583155e-02,  8.10021013e-02,
         1.56707764e-02, -2.35380158e-02,  1.56681970e-01,
        -1.12800300e-02,  3.64681214e-01,  1.12793013e-01,
        -9.20613408e-02],
       [-1.10700965e-01, -3.84411961e-03,  7.15886354e-02,
        -5.16710430e-03, -2.68637538e-02, -4.64520939e-02,
        -1.02423206e-01,  1.41418934e-01,  1.36580504e-02,
        -2.16841191e-01],
       [-1.03602912e-02, -1.36248600e-02, -8.44807327e-02,
        -3.93018406e-03,  6.54329583e-02, -1.54229663e-02,
        -9.10714716e-02,  1.13576502e-02,  6.24551401e-02,
        -1.10215969e-01],
       [-1.64637700e-01, -4.25843447e-02, -6.63272589e-02,
         1.01544857e-02,  9.00160298e-02,  1.41169682e-01,
         9.43019092e-02,  1.50300652e-01,  1.17022656e-01,
        -2.61101604e-01],
       [-2.96755701e-01,  1.48339659e-01,  5.29592186e-02,
         4.51779664e-02, -6.84008598e-02,  1.29287004e-01,
         1.34066977e-02,  1.68794006e-01, -1.53631158e-02,
        -1.40826374e-01],
       [-2.27824658e-01, -3.58637236e-02,  7.98013210e-02,
        -2.93148141e-02, -1.29889801e-01,  1.07304119e-02,
         6.16377033e-02,  2.38016129e-01,  1.68460131e-01,
        -2.78131723e-01],
       [-1.97686747e-01, -1.20533034e-01,  1.91476271e-02,
        -2.50333622e-02, -1.20231688e-01, -1.43363982e-01,
        -5.45644462e-02,  1.13663480e-01, -9.71207619e-02,
        -7.38224685e-02],
       [-1.21181801e-01, -9.18156952e-02,  1.72619522e-02,
         7.20846877e-02, -5.00237271e-02, -7.88232982e-02,
        -2.75398232e-02,  9.42765027e-02, -8.18064660e-02,
        -4.43772227e-02],
       [-2.12152809e-01, -1.05831539e-02,  1.12541884e-01,
         3.79703306e-02, -4.97136004e-02, -8.26531351e-02,
         4.28089425e-02,  2.72401571e-01, -9.41082910e-02,
        -8.25358368e-03],
       [-2.12490350e-01,  5.10787666e-02, -4.91231680e-03,
         1.71558380e-01,  8.33496898e-02,  8.03120583e-02,
         5.97136915e-02,  2.78716445e-01, -5.66011816e-02,
        -7.99765587e-02],
       [-2.45497763e-01, -5.21367639e-02,  1.77163050e-01,
         8.67958441e-02, -1.33168459e-01,  9.83412005e-03,
        -1.34591311e-01,  1.48744047e-01, -6.65533617e-02,
        -1.07505932e-01],
       [-1.36525869e-01, -5.12802340e-02,  2.54329219e-02,
         8.01228657e-02, -3.24120894e-02, -6.36913255e-03,
        -7.75915161e-02,  1.81387305e-01,  6.72850609e-02,
        -1.06104709e-01],
       [-8.19087848e-02, -6.67821616e-02,  1.09396182e-01,
        -8.99944529e-02, -1.08385280e-01,  6.29347712e-02,
         7.26154894e-02,  1.68957621e-01,  1.90485001e-01,
        -2.60798335e-01],
       [-1.76897705e-01,  4.90825251e-02,  2.94402167e-02,
        -2.41212249e-02,  3.94896790e-02,  1.18754521e-01,
         1.69773921e-02,  1.10196158e-01,  7.08303824e-02,
        -6.86142594e-02],
       [-1.29656106e-01, -8.14089552e-02,  1.14682741e-01,
        -1.32834181e-01, -1.49253279e-01, -2.83164792e-02,
         3.45680863e-04,  2.52322882e-01,  2.89388448e-02,
        -2.79281288e-01],
       [-1.10502213e-01,  1.07094124e-01,  3.24486196e-02,
         7.70951509e-02, -6.27939776e-02,  1.68845624e-01,
        -1.44310594e-01,  1.45337492e-01,  2.03377791e-02,
        -5.04231378e-02],
       [-2.66523331e-01, -7.49082193e-02,  1.91363335e-01,
        -6.39847219e-02, -1.04055285e-01,  8.31385702e-02,
         8.82939398e-02,  1.99207246e-01,  5.35239354e-02,
        -2.60884434e-01],
       [-1.35722771e-01,  3.94147262e-02, -6.39424995e-02,
         1.39283150e-01,  5.37211001e-02, -6.34303223e-03,
        -1.70467123e-01,  2.55692095e-01, -7.66103566e-02,
        -6.90388680e-02],
       [-1.07885860e-01,  2.30858717e-02,  8.21547359e-02,
        -3.12240291e-02, -9.89983678e-02,  7.22398609e-02,
        -4.08478230e-02,  8.69123414e-02,  4.48577479e-02,
        -6.41947538e-02],
       [-2.28321850e-02, -3.88411283e-02,  1.47033811e-01,
        -2.35385150e-01, -9.87000838e-02,  6.44287840e-02,
        -1.87633559e-02,  1.17905587e-01,  9.70625877e-02,
        -2.46781930e-01],
       [-8.77917856e-02, -1.64044406e-02,  7.53755122e-02,
        -8.24043527e-04, -7.77238905e-02,  1.16269790e-01,
        -1.00877963e-01,  8.79124254e-02,  3.39440927e-02,
        -5.94997481e-02],
       [-1.41677827e-01, -1.40151009e-02,  8.84927809e-04,
         1.03166051e-01, -1.66242346e-02,  2.62837298e-02,
        -1.33589238e-01,  1.65735006e-01,  3.65820900e-02,
        -1.46895535e-02],
       [-1.61557034e-01,  5.66626638e-02, -1.61597617e-02,
         2.58595943e-02,  3.39905620e-02,  1.01104185e-01,
        -3.71510983e-02,  1.20341092e-01,  3.26242894e-02,
        -4.07250933e-02],
       [-2.17516154e-01,  7.85727724e-02,  9.79433060e-02,
         6.97179586e-02,  4.95264679e-02,  1.92503840e-01,
        -4.96265218e-02,  1.99431688e-01, -5.32730669e-03,
        -2.50038877e-02],
       [-1.35356426e-01, -6.96291253e-02,  3.92658785e-02,
        -9.86322537e-02, -4.20986377e-02,  9.87840891e-02,
         9.67663303e-02,  1.76262826e-01,  9.44406465e-02,
        -2.23472387e-01],
       [-1.25066608e-01,  7.71146417e-02,  4.02672291e-02,
        -2.05352344e-02,  3.11498251e-02,  9.64582711e-02,
        -5.39951548e-02,  2.29750067e-01,  1.61451437e-02,
        -5.41997403e-02],
       [-1.93750665e-01, -3.56721133e-03, -1.50568932e-02,
         1.78796798e-02,  8.33508372e-03, -1.18013099e-02,
        -5.35021350e-02,  2.02244624e-01,  3.02494057e-02,
        -1.20312274e-01],
       [-2.62067527e-01,  2.36408859e-02,  5.58489896e-02,
         1.75756812e-01, -2.75299139e-02,  3.48872915e-02,
         5.41301072e-03,  3.15880209e-01, -5.74782193e-02,
         7.00992346e-03],
       [-2.76674211e-01, -2.08131559e-02, -1.26259401e-02,
         7.77718723e-02, -1.54706314e-01,  1.31996438e-01,
         2.20355690e-02,  5.61908968e-02,  3.73308063e-02,
        -1.17717944e-01],
       [-1.59806639e-01,  1.20503023e-01, -4.36934829e-03,
         1.16428092e-01,  5.47975339e-02,  1.25162587e-01,
         4.78192419e-02,  1.28253624e-01,  7.34245628e-02,
        -1.80039048e-01],
       [-2.67963678e-01,  6.00077920e-02,  1.13472804e-01,
         7.52071738e-02, -6.40357211e-02,  1.03171021e-01,
         1.48901194e-01,  1.97019696e-01,  3.76104042e-02,
        -1.68720663e-01],
       [-2.01240778e-01,  2.47026011e-02,  3.10055390e-02,
        -8.58910009e-03, -8.49897265e-02, -7.54948407e-02,
        -9.39515531e-02,  1.34306327e-01, -1.71037674e-01,
        -5.76597378e-02],
       [-5.20152375e-02,  6.59879148e-02, -3.30656916e-02,
         9.97125208e-02,  3.56362388e-02,  1.26982957e-01,
        -2.69417539e-02,  1.59046397e-01,  1.10872082e-01,
        -1.84650719e-01]], dtype=float32)>}

Bạn cũng có thể tải và suy luận theo cách phân tán:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Việc gọi hàm khôi phục chỉ là chuyển tiếp trên mô hình đã lưu (dự đoán). Nếu bạn muốn tiếp tục đào tạo hàm đã tải thì sao? Hoặc nhúng chức năng đã tải vào một mô hình lớn hơn? Một thực tế phổ biến là quấn đối tượng đã tải này vào một lớp Keras để đạt được điều này. May mắn thay, TF Hubhub.KerasLayer cho mục đích này, được hiển thị ở đây:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 2s 3ms/step - loss: 0.1981 - sparse_categorical_accuracy: 0.9412
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0655 - sparse_categorical_accuracy: 0.9804

Như bạn có thể thấy, hub.KerasLayer gói kết quả được tải trở lại từ tf.saved_model.load() vào một lớp Keras có thể được sử dụng để xây dựng một mô hình khác. Điều này rất hữu ích cho việc học chuyển tiếp.

Tôi nên sử dụng API nào?

Để tiết kiệm, nếu bạn đang làm việc với mô hình keras, bạn nên sử dụng API model.save() của model.save() . Nếu những gì bạn đang lưu không phải là mô hình Keras, thì API cấp thấp hơn là sự lựa chọn duy nhất của bạn.

Để tải, bạn sử dụng API nào phụ thuộc vào những gì bạn muốn lấy từ API tải. Nếu bạn không thể (hoặc không muốn) lấy một mô hình Keras, thì hãy sử dụng tf.saved_model.load() . Nếu không, hãy sử dụng tf.keras.models.load_model() . Lưu ý rằng bạn chỉ có thể lấy lại một mô hình Keras nếu bạn đã lưu một mô hình Keras.

Có thể trộn và kết hợp các API. Bạn có thể lưu mô hình Keras bằng model.save và tải mô hình không phải Keras bằng API cấp thấp, tf.saved_model.load .

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Lưu / Đang tải từ thiết bị cục bộ

Khi lưu và tải từ thiết bị io cục bộ trong khi chạy từ xa, chẳng hạn như sử dụng TPU đám mây, tùy chọn experimental_io_device phải được sử dụng để đặt thiết bị io thành localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Cảnh báo

Một trường hợp đặc biệt là khi bạn có một mô hình Keras không có các đầu vào được xác định rõ ràng. Ví dụ: mô hình Tuần tự có thể được tạo mà không có bất kỳ hình dạng đầu vào nào ( Sequential([Dense(3), ...] ). Mô hình lớp con cũng không có đầu vào được xác định rõ sau khi khởi tạo. Trong trường hợp này, bạn nên sử dụng API cấp thấp hơn về cả lưu và tải, nếu không bạn sẽ gặp lỗi.

Để kiểm tra xem mô hình của bạn có các đầu vào được xác định rõ hay không, chỉ cần kiểm tra xem model.inputs có phải là None . Nếu nó không phải là None , tất cả các bạn đều tốt. Hình dạng đầu vào được tự động xác định khi mô hình được sử dụng trong .fit , .evaluate , .predict hoặc khi gọi mô hình ( model(inputs) ).

Đây là một ví dụ:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f74d29fffd0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f74d2b37cc0>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets