Trang này được dịch bởi Cloud Translation API.
Switch to English

Lưu và tải các mô hình Keras

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ tay

Giới thiệu

Một mô hình Keras bao gồm nhiều thành phần:

  • Một kiến ​​trúc hoặc cấu hình, chỉ định các lớp mà mô hình chứa và cách chúng được kết nối.
  • Tập hợp các giá trị trọng số ("trạng thái của mô hình").
  • Trình tối ưu hóa (được xác định bằng cách biên dịch mô hình).
  • Một tập hợp các tổn thất và chỉ số (được xác định bằng cách biên dịch mô hình hoặc gọi add_loss() hoặc add_metric() ).

API Keras giúp bạn có thể lưu tất cả những phần này vào đĩa cùng một lúc hoặc chỉ lưu một cách chọn lọc một số trong số chúng:

  • Lưu mọi thứ vào một kho lưu trữ duy nhất ở định dạng TensorFlow SavedModel (hoặc ở định dạng Keras H5 cũ hơn). Đây là thông lệ tiêu chuẩn.
  • Chỉ lưu kiến ​​trúc / cấu hình, thường dưới dạng tệp JSON.
  • Chỉ lưu các giá trị trọng số. Điều này thường được sử dụng khi đào tạo mô hình.

Hãy xem xét từng tùy chọn sau: khi nào bạn sử dụng tùy chọn này hay tùy chọn khác? Họ làm việc như thế nào?

Câu trả lời ngắn gọn để lưu và tải

Nếu bạn chỉ có 10 giây để đọc hướng dẫn này, đây là những gì bạn cần biết.

Lưu mô hình Keras:

model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')

Đang tải lại mô hình:

from tensorflow import keras
model = keras.models.load_model('path/to/location')

Bây giờ, chúng ta hãy xem xét chi tiết.

Thiết lập

import numpy as np
import tensorflow as tf
from tensorflow import keras

Lưu và tải toàn bộ mô hình

Bạn có thể lưu toàn bộ mô hình vào một hiện vật duy nhất. Nó sẽ bao gồm:

  • Cấu hình / kiến ​​trúc của mô hình
  • Giá trị trọng lượng của mô hình (được học trong quá trình đào tạo)
  • Thông tin biên dịch của mô hình (if compile() ) được gọi
  • Trình tối ưu hóa và trạng thái của nó, nếu có (điều này cho phép bạn bắt đầu lại quá trình đào tạo ở nơi bạn đã rời đi)

API

Có hai định dạng bạn có thể sử dụng để lưu toàn bộ mô hình vào đĩa: định dạng TensorFlow SavedModel và định dạng Keras H5 cũ hơn . Định dạng được đề xuất là SavedModel. Nó là mặc định khi bạn sử dụng model.save() .

Bạn có thể chuyển sang định dạng H5 bằng cách:

  • Chuyển save_format='h5' để save() .
  • Chuyển tên tệp kết thúc bằng .h5 hoặc .keras để save() .

Định dạng SavedModel

Thí dụ:

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 1ms/step - loss: 0.9931
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: my_model/assets
4/4 [==============================] - 0s 995us/step - loss: 0.8820

<tensorflow.python.keras.callbacks.History at 0x7fb18c7b4748>

SavedModel chứa những gì

Gọi model.save('my_model') tạo một thư mục có tên là my_model , chứa thông tin sau:

ls my_model
assets  saved_model.pb  variables

Kiến trúc mô hình và cấu hình đào tạo (bao gồm trình tối ưu hóa, tổn thất và số liệu) được lưu trữ trong saved_model.pb . Các trọng số được lưu trong thư mục variables/ .

Để biết thông tin chi tiết về định dạng SavedModel, hãy xem hướng dẫn SavedModel ( Định dạng SavedModel trên đĩa ) .

Cách SavedModel xử lý các đối tượng tùy chỉnh

Khi lưu mô hình và các lớp của nó, định dạng SavedModel sẽ lưu trữ tên lớp, hàm gọi , tổn thất và trọng số (và cấu hình, nếu được triển khai). Hàm gọi xác định đồ thị tính toán của mô hình / lớp.

Trong trường hợp không có cấu hình mô hình / lớp, hàm gọi được sử dụng để tạo mô hình tồn tại giống như mô hình ban đầu có thể được huấn luyện, đánh giá và sử dụng để suy luận.

Tuy nhiên, nó luôn luôn là một thực hành tốt để xác định get_configfrom_config phương pháp khi viết một mô hình tùy chỉnh hoặc lớp lớp. Điều này cho phép bạn dễ dàng cập nhật tính toán sau này nếu cần. Xem phần về Đối tượng tùy chỉnh để biết thêm thông tin.

Dưới đây là ví dụ về những gì sẽ xảy ra khi tải các lớp tùy chỉnh từ định dạng SavedModel mà không ghi đè các phương thức cấu hình.

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x


model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel

loaded = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded(input_arr), outputs)

print("Original model:", model)
print("Loaded model:", loaded)
INFO:tensorflow:Assets written to: my_model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Original model: <__main__.CustomModel object at 0x7fb210749080>
Loaded model: <tensorflow.python.keras.saving.saved_model.load.CustomModel object at 0x7fb1f9ee84e0>

Như đã thấy trong ví dụ trên, trình nạp động tạo ra một lớp mô hình mới hoạt động giống như mô hình ban đầu.

Định dạng Keras H5

Keras cũng hỗ trợ lưu một tệp HDF5 duy nhất chứa kiến ​​trúc của mô hình, các giá trị trọng số và thông tin compile() . Nó là một giải pháp thay thế trọng lượng nhẹ cho SavedModel.

Thí dụ:

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 1ms/step - loss: 0.5105
4/4 [==============================] - 0s 1ms/step - loss: 0.4567

<tensorflow.python.keras.callbacks.History at 0x7fb1f954dcc0>

Hạn chế

So với định dạng SavedModel, có hai thứ không được đưa vào tệp H5:

  • Tổn thất & chỉ số bên ngoài được thêm qua model.add_loss() & model.add_metric() không được lưu (không giống như SavedModel). Nếu bạn có những tổn thất & số liệu như vậy trên mô hình của mình và bạn muốn tiếp tục đào tạo, bạn cần phải tự cộng những tổn thất này sau khi tải mô hình. Lưu ý rằng điều này không áp dụng cho các khoản lỗ / chỉ số được tạo bên trong các lớp thông qua self.add_loss() & self.add_metric() . Miễn là lớp được tải, những tổn thất & số liệu này sẽ được giữ lại, vì chúng là một phần của phương thức call của lớp.
  • Biểu đồ tính toán của các đối tượng tùy chỉnh như các lớp tùy chỉnh không được bao gồm trong tệp đã lưu. Tại thời điểm tải, Keras sẽ cần quyền truy cập vào các lớp / hàm Python của các đối tượng này để xây dựng lại mô hình. Xem Đối tượng tùy chỉnh .

Lưu kiến ​​trúc

Cấu hình (hoặc kiến ​​trúc) của mô hình chỉ định các lớp mà mô hình chứa và cách các lớp này được kết nối *. Nếu bạn có cấu hình của một mô hình, thì mô hình có thể được tạo với trạng thái mới khởi tạo cho các trọng số và không có thông tin biên dịch.

* Lưu ý rằng điều này chỉ áp dụng cho các mô hình được xác định bằng cách sử dụng ứng dụng chức năng hoặc Tuần tự không phải là mô hình phân lớp.

Cấu hình của mô hình tuần tự hoặc mô hình API chức năng

Các loại mô hình này là biểu đồ rõ ràng của các lớp: cấu hình của chúng luôn có sẵn ở dạng có cấu trúc.

API

get_config()from_config()

Gọi config = model.get_config() sẽ trả về một lệnh Python chứa cấu hình của mô hình. Sau đó, mô hình tương tự có thể được tạo lại thông qua Sequential.from_config(config) (đối với mô hình Sequential ) hoặc Model.from_config(config) (đối với mô hình API chức năng).

Quy trình làm việc tương tự cũng hoạt động cho bất kỳ lớp nào có thể tuần tự hóa.

Ví dụ về lớp:

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
new_layer = keras.layers.Dense.from_config(layer_config)

Ví dụ về mô hình tuần tự:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

Ví dụ mô hình chức năng:

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

to_json()tf.keras.models.model_from_json()

Điều này tương tự với get_config / from_config , ngoại trừ nó biến mô hình thành một chuỗi JSON, sau đó có thể được tải mà không cần lớp mô hình ban đầu. Nó cũng dành riêng cho các mô hình, nó không dành cho các lớp.

Thí dụ:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

Đối tượng tùy chỉnh

Mô hình và lớp

Kiến trúc của các mô hình và lớp phân lớp con được định nghĩa trong các phương thức __init__call . Chúng được coi là mã bytecode của Python, không thể được tuần tự hóa thành cấu hình tương thích với JSON - bạn có thể thử tuần tự hóa mã bytecode (ví dụ: thông qua pickle ), nhưng nó hoàn toàn không an toàn và có nghĩa là mô hình của bạn không thể được tải trên một hệ thống khác.

Để lưu / tải một mô hình với các lớp được xác định tùy chỉnh hoặc một mô hình phân lớp, bạn nên ghi đè các phương thức get_configfrom_config tùy chọn. Ngoài ra, bạn nên sử dụng đăng ký đối tượng tùy chỉnh để Keras biết về nó.

Các chức năng tùy chỉnh

Các hàm do tùy chỉnh xác định (ví dụ: mất kích hoạt hoặc khởi tạo) không cần phương thức get_config . Tên hàm đủ để tải miễn là nó được đăng ký như một đối tượng tùy chỉnh.

Chỉ tải biểu đồ TensorFlow

Có thể tải biểu đồ TensorFlow do Keras tạo ra. Nếu bạn làm như vậy, bạn sẽ không cần cung cấp bất kỳ custom_objects nào. Bạn có thể làm như thế này:

model.save("my_model")
tensorflow_graph = tf.saved_model.load("my_model")
x = np.random.uniform(size=(4, 32)).astype(np.float32)
predicted = tensorflow_graph(x).numpy()
INFO:tensorflow:Assets written to: my_model/assets

Lưu ý rằng phương pháp này có một số nhược điểm:

  • Vì lý do truy xuất nguồn gốc, bạn phải luôn có quyền truy cập vào các đối tượng tùy chỉnh đã được sử dụng. Bạn sẽ không muốn đưa vào sản xuất một mô hình mà bạn không thể tạo lại.
  • Đối tượng được trả về bởi tf.saved_model.load không phải là mô hình Keras. Vì vậy, nó không dễ sử dụng. Ví dụ: bạn sẽ không có quyền truy cập vào .predict() hoặc .fit()

Ngay cả khi việc sử dụng nó không được khuyến khích, nó có thể giúp bạn nếu bạn đang gặp khó khăn, ví dụ: nếu bạn mất mã của các đối tượng tùy chỉnh của mình hoặc gặp sự cố khi tải mô hình với tf.keras.models.load_model() .

Bạn có thể tìm hiểu thêm trong trang về tf.saved_model.load

Xác định các phương thức cấu hình

Thông số kỹ thuật:

  • get_config phải trả về một từ điển JSON-serializable để tương thích với các API tiết kiệm mô hình và kiến ​​trúc Keras.
  • from_config(config) ( classmethod ) sẽ trả về một lớp hoặc đối tượng mô hình mới được tạo từ cấu hình. Việc triển khai mặc định trả về cls(**config) .

Thí dụ:

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")

    def call(self, inputs, training=False):
        if training:
            return inputs * self.var
        else:
            return inputs

    def get_config(self):
        return {"a": self.var.numpy()}

    # There's actually no need to define `from_config` here, since returning
    # `cls(**config)` is the default behavior.
    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = CustomLayer(5)
layer.var.assign(2)

serialized_layer = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
    serialized_layer, custom_objects={"CustomLayer": CustomLayer}
)

Đăng ký đối tượng tùy chỉnh

Keras giữ một ghi chú về lớp nào đã tạo cấu hình. Từ ví dụ trên, tf.keras.layers.serialize tạo một dạng tuần tự hóa của lớp tùy chỉnh:

{'class_name': 'CustomLayer', 'config': {'a': 2} }

Keras giữ một danh sách tổng thể của tất cả các lớp, mô hình, trình tối ưu hóa và chỉ số được tích hợp sẵn, được sử dụng để tìm lớp chính xác để gọi from_config . Nếu không thể tìm thấy lớp, thì một lỗi sẽ xuất hiện ( Value Error: Unknown layer ). Có một số cách để đăng ký các lớp tùy chỉnh vào danh sách này:

  1. Đặt đối số custom_objects trong hàm tải. (xem ví dụ trong phần trên "Định nghĩa các phương pháp cấu hình")
  2. tf.keras.utils.custom_object_scope hoặc tf.keras.utils.CustomObjectScope
  3. tf.keras.utils.register_keras_serializable

Ví dụ về lớp và chức năng tùy chỉnh

class CustomLayer(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(CustomLayer, self).get_config()
        config.update({"units": self.units})
        return config


def custom_activation(x):
    return tf.nn.tanh(x) ** 2


# Make a model with the CustomLayer and custom_activation
inputs = keras.Input((32,))
x = CustomLayer(32)(inputs)
outputs = keras.layers.Activation(custom_activation)(x)
model = keras.Model(inputs, outputs)

# Retrieve the config
config = model.get_config()

# At loading time, register the custom objects with a `custom_object_scope`:
custom_objects = {"CustomLayer": CustomLayer, "custom_activation": custom_activation}
with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.Model.from_config(config)

Nhân bản mô hình trong bộ nhớ

Bạn cũng có thể sao chép mô hình trong bộ nhớ thông qua tf.keras.models.clone_model() . Điều này tương đương với việc lấy cấu hình sau đó tạo lại mô hình từ cấu hình của nó (vì vậy nó không bảo toàn thông tin biên dịch hoặc các giá trị trọng số lớp).

Thí dụ:

with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.models.clone_model(model)

Chỉ lưu và tải các giá trị trọng lượng của mô hình

Bạn có thể chọn chỉ lưu và tải trọng lượng của mô hình. Điều này có thể hữu ích nếu:

  • Bạn chỉ cần mô hình để suy luận: trong trường hợp này, bạn sẽ không cần phải khởi động lại quá trình đào tạo, vì vậy bạn không cần thông tin biên dịch hoặc trạng thái trình tối ưu hóa.
  • Bạn đang học chuyển giao: trong trường hợp này, bạn sẽ đào tạo một mô hình mới sử dụng lại trạng thái của mô hình trước đó, vì vậy bạn không cần thông tin biên dịch của mô hình trước đó.

API để truyền trọng lượng trong bộ nhớ

Trọng lượng có thể được sao chép giữa các đối tượng khác nhau bằng cách sử dụng get_weightsset_weights :

Ví dụ bên dưới.

Chuyển trọng số từ lớp này sang lớp khác, trong bộ nhớ

def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 2 to layer 1
layer_2.set_weights(layer_1.get_weights())

Chuyển trọng số từ mô hình này sang mô hình khác với kiến ​​trúc tương thích, trong bộ nhớ

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super(SubclassedModel, self).__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(tf.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

Trường hợp của các lớp không trạng thái

Bởi vì các lớp không trạng thái không thay đổi thứ tự hoặc số lượng trọng lượng, các mô hình có thể có kiến ​​trúc tương thích ngay cả khi có thêm / thiếu các lớp không trạng thái.

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

API để lưu trọng lượng vào đĩa và tải chúng trở lại

Trọng lượng có thể được lưu vào đĩa bằng cách gọi model.save_weights ở các định dạng sau:

  • Điểm kiểm tra TensorFlow
  • HDF5

Định dạng mặc định cho model.save_weights là điểm kiểm tra TensorFlow. Có hai cách để chỉ định định dạng lưu:

  1. save_format số save_format : Đặt giá trị thành save_format="tf" hoặc save_format="h5" .
  2. đối số path : Nếu đường dẫn kết thúc bằng .h5 hoặc .hdf5 thì định dạng HDF5 được sử dụng. Các hậu tố khác sẽ dẫn đến một điểm kiểm tra TensorFlow trừ khi save_format được đặt.

Ngoài ra còn có một tùy chọn truy xuất các trọng số dưới dạng mảng numpy trong bộ nhớ. Mỗi API đều có những ưu và nhược điểm được trình bày chi tiết bên dưới.

Định dạng điểm kiểm tra TF

Thí dụ:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("ckpt")
load_status = sequential_model.load_weights("ckpt")

# `assert_consumed` can be used as validation that all variable values have been
# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other
# methods in the Status object.
load_status.assert_consumed()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb1f8be3080>

Định dạng chi tiết

Định dạng TensorFlow Checkpoint lưu và khôi phục các trọng số bằng cách sử dụng tên thuộc tính đối tượng. Ví dụ, hãy xem xét lớp tf.keras.layers.Dense . Lớp chứa hai trọng lượng: dense.kerneldense.bias . Khi lớp được lưu ở định dạng tf , điểm kiểm tra kết quả chứa các khóa "kernel""bias" và các giá trị trọng số tương ứng của chúng. Để biết thêm thông tin, hãy xem "Cơ chế tải" trong hướng dẫn Điểm kiểm tra TF .

Lưu ý rằng thuộc tính / cạnh đồ thị được đặt tên theo tên được sử dụng trong đối tượng mẹ, không phải tên của biến . Hãy xem xét CustomLayer trong ví dụ bên dưới. Biến CustomLayer.var được lưu với "var" như một phần của khóa, không phải "var_a" .

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")


layer = CustomLayer(5)
layer_ckpt = tf.train.Checkpoint(layer=layer).save("custom_layer")

ckpt_reader = tf.train.load_checkpoint(layer_ckpt)

ckpt_reader.get_variable_to_dtype_map()
{'save_counter/.ATTRIBUTES/VARIABLE_VALUE': tf.int64,
 '_CHECKPOINTABLE_OBJECT_GRAPH': tf.string,
 'layer/var/.ATTRIBUTES/VARIABLE_VALUE': tf.int32}

Chuyển giao ví dụ học tập

Về cơ bản, miễn là hai mô hình có cùng kiến ​​trúc, chúng có thể chia sẻ cùng một trạm kiểm soát.

Thí dụ:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Extract a portion of the functional model defined in the Setup section.
# The following lines produce a new model that excludes the final output
# layer of the functional model.
pretrained = keras.Model(
    functional_model.inputs, functional_model.layers[-1].input, name="pretrained_model"
)
# Randomly assign "trained" weights.
for w in pretrained.weights:
    w.assign(tf.random.normal(w.shape))
pretrained.save_weights("pretrained_ckpt")
pretrained.summary()

# Assume this is a separate program where only 'pretrained_ckpt' exists.
# Create a new functional model with a different output dimension.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(5, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="new_model")

# Load the weights from pretrained_ckpt into model.
model.load_weights("pretrained_ckpt")

# Check that all of the pretrained weights have been loaded.
for a, b in zip(pretrained.weights, model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

print("\n", "-" * 50)
model.summary()

# Example 2: Sequential model
# Recreate the pretrained model, and load the saved weights.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
pretrained_model = keras.Model(inputs=inputs, outputs=x, name="pretrained")

# Sequential example:
model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name="predictions")])
model.summary()

pretrained_model.load_weights("pretrained_ckpt")

# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,
# but will *not* work as expected. If you inspect the weights, you'll see that
# none of the weights will have loaded. `pretrained_model.load_weights()` is the
# correct method to call.
Model: "pretrained_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
=================================================================
Total params: 54,400
Trainable params: 54,400
Non-trainable params: 0
_________________________________________________________________

 --------------------------------------------------
Model: "new_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
pretrained (Functional)      (None, 64)                54400     
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb1f8bad080>

Thông thường, bạn nên sử dụng cùng một API để xây dựng mô hình. Nếu bạn chuyển đổi giữa Tuần tự và Chức năng, hoặc Chức năng và lớp con, v.v., thì hãy luôn xây dựng lại mô hình được đào tạo trước và tải các trọng số được đào tạo trước vào mô hình đó.

Câu hỏi tiếp theo là, làm thế nào trọng số có thể được lưu và tải vào các mô hình khác nhau nếu kiến ​​trúc mô hình khá khác nhau? Giải pháp là sử dụng tf.train.Checkpoint để lưu và khôi phục các lớp / biến chính xác.

Thí dụ:

# Create a subclassed model that essentially uses functional_model's first
# and last layers.
# First, save the weights of functional_model's first and last dense layers.
first_dense = functional_model.layers[1]
last_dense = functional_model.layers[-1]
ckpt_path = tf.train.Checkpoint(
    dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias
).save("ckpt")

# Define the subclassed model.
class ContrivedModel(keras.Model):
    def __init__(self):
        super(ContrivedModel, self).__init__()
        self.first_dense = keras.layers.Dense(64)
        self.kernel = self.add_variable("kernel", shape=(64, 10))
        self.bias = self.add_variable("bias", shape=(10,))

    def call(self, inputs):
        x = self.first_dense(inputs)
        return tf.matmul(x, self.kernel) + self.bias


model = ContrivedModel()
# Call model on inputs to create the variables of the dense layer.
_ = model(tf.ones((1, 784)))

# Create a Checkpoint with the same structure as before, and load the weights.
tf.train.Checkpoint(
    dense=model.first_dense, kernel=model.kernel, bias=model.bias
).restore(ckpt_path).assert_consumed()
WARNING:tensorflow:From <ipython-input-1-eec1d28bc826>:15: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb1f8b37ef0>

Định dạng HDF5

Định dạng HDF5 chứa các trọng số được nhóm theo tên lớp. Các trọng lượng là danh sách được sắp xếp theo thứ tự bằng cách nối danh sách các trọng lượng có thể huấn luyện với danh sách các trọng lượng không thể huấn luyện (giống như layer.weights . layer.weights ). Do đó, một mô hình có thể sử dụng điểm kiểm tra hdf5 nếu nó có cùng các lớp và trạng thái có thể đào tạo như đã lưu trong điểm kiểm tra.

Thí dụ:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("weights.h5")
sequential_model.load_weights("weights.h5")

Lưu ý rằng việc thay đổi layer.trainable có thể dẫn đến thứ tự layer.weights khác khi mô hình chứa các lớp lồng nhau.

class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super(NestedDenseLayer, self).__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)
variables: ['nested/dense_1/kernel:0', 'nested/dense_1/bias:0', 'nested/dense_2/kernel:0', 'nested/dense_2/bias:0']

Changing trainable status of one of the nested layers...

variables: ['nested/dense_2/kernel:0', 'nested/dense_2/bias:0', 'nested/dense_1/kernel:0', 'nested/dense_1/bias:0']
variable ordering changed: True

Chuyển giao ví dụ học tập

Khi tải trọng lượng đã được tinh luyện trước từ HDF5, bạn nên tải trọng lượng vào mô hình điểm kiểm tra ban đầu, sau đó trích xuất các trọng lượng / lớp mong muốn sang một mô hình mới.

Thí dụ:

def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained_weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained_weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________