Ter uma questão? Conecte-se com a comunidade no Fórum TensorFlow Visite o Fórum

Salvar e carregar modelos Keras

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Introdução

Um modelo Keras consiste em vários componentes:

  • A arquitetura, ou configuração, que especifica quais camadas o modelo contém e como elas estão conectadas.
  • Um conjunto de valores de pesos (o "estado do modelo").
  • Um otimizador (definido pela compilação do modelo).
  • Um conjunto de perdas e métricas (definidas pela compilação do modelo ou chamando add_loss() ou add_metric() ).

A API Keras torna possível salvar todas essas peças no disco de uma vez ou apenas salvar seletivamente algumas delas:

  • Salvar tudo em um único arquivo no formato TensorFlow SavedModel (ou no formato Keras H5 mais antigo). Esta é a prática padrão.
  • Salvar a arquitetura / configuração apenas, normalmente como um arquivo JSON.
  • Salvando apenas os valores dos pesos. Isso geralmente é usado ao treinar o modelo.

Vamos dar uma olhada em cada uma dessas opções. Quando você usaria um ou outro e como eles funcionam?

Como salvar e carregar um modelo

Se você tem apenas 10 segundos para ler este guia, aqui está o que você precisa saber.

Salvar um modelo Keras:

model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')

Carregando o modelo de volta:

from tensorflow import keras
model = keras.models.load_model('path/to/location')

Agora, vamos ver os detalhes.

Configurar

import numpy as np
import tensorflow as tf
from tensorflow import keras

Salvar e carregar o modelo inteiro

Você pode salvar um modelo inteiro em um único artefato. Incluirá:

  • A arquitetura / configuração do modelo
  • Os valores de peso do modelo (que foram aprendidos durante o treinamento)
  • As informações de compilação do modelo (se compile() foi chamado)
  • O otimizador e seu estado, se houver (isso permite que você reinicie o treinamento de onde parou)

APIs

Existem dois formatos que você pode usar para salvar um modelo inteiro em disco: o formato TensorFlow SavedModel e o formato Keras H5 mais antigo . O formato recomendado é SavedModel. É o padrão quando você usa model.save() .

Você pode mudar para o formato H5:

  • Passando save_format='h5' para save() .
  • Passando um nome de arquivo que termina em .h5 ou .keras para save() .

Formato SavedModel

SavedModel é o formato de salvamento mais abrangente que salva a arquitetura do modelo, os pesos e os subgráficos rastreados do Tensorflow das funções de chamada. Isso permite que o Keras restaure tanto as camadas integradas quanto os objetos personalizados.

Exemplo:

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 1s 2ms/step - loss: 0.9237
INFO:tensorflow:Assets written to: my_model/assets
4/4 [==============================] - 0s 1ms/step - loss: 0.7730
<tensorflow.python.keras.callbacks.History at 0x7fd0a032a390>

O que o SavedModel contém

Chamar model.save('my_model') cria uma pasta chamada my_model , contendo o seguinte:

ls my_model
assets  saved_model.pb  variables

A arquitetura do modelo e a configuração de treinamento (incluindo o otimizador, perdas e métricas) são armazenados em saved_model.pb . Os pesos são salvos no diretório variables/ .

Para obter informações detalhadas sobre o formato SavedModel, consulte o guia SavedModel ( formato SavedModel em disco ) .

Como SavedModel lida com objetos personalizados

Ao salvar o modelo e suas camadas, o formato SavedModel armazena o nome da classe, função de chamada , perdas e pesos (e a configuração, se implementada). A função de chamada define o gráfico de computação do modelo / camada.

Na ausência da configuração do modelo / camada, a função de chamada é usada para criar um modelo que existe como o modelo original, que pode ser treinado, avaliado e usado para inferência.

No entanto, é sempre uma boa prática para definir a get_config e from_config métodos ao escrever um modelo personalizado ou classe camada. Isso permite que você atualize facilmente o cálculo posteriormente, se necessário. Consulte a seção sobre objetos personalizados para obter mais informações.

Exemplo:

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.hidden_units = hidden_units
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

    def get_config(self):
        return {"hidden_units": self.hidden_units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Option 1: Load with the custom_object argument.
loaded_1 = keras.models.load_model(
    "my_model", custom_objects={"CustomModel": CustomModel}
)

# Option 2: Load without the CustomModel class.

# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel

loaded_2 = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded_1(input_arr), outputs)
np.testing.assert_allclose(loaded_2(input_arr), outputs)

print("Original model:", model)
print("Model Loaded with custom objects:", loaded_1)
print("Model loaded without the custom object class:", loaded_2)
INFO:tensorflow:Assets written to: my_model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Original model: <__main__.CustomModel object at 0x7fd0a035bcf8>
Model Loaded with custom objects: <__main__.CustomModel object at 0x7fd1455d04e0>
Model loaded without the custom object class: <tensorflow.python.keras.saving.saved_model.load.CustomModel object at 0x7fd14553af98>

O primeiro modelo carregado é carregado usando a classe config e CustomModel . O segundo modelo é carregado criando dinamicamente a classe de modelo que atua como o modelo original.

Configurando o SavedModel

Novo no TensoFlow 2.4 O argumento save_traces foi adicionado a model.save , que permite alternar o rastreamento da função SavedModel. As funções são salvas para permitir que o Keras recarregue os objetos personalizados sem as definições da classe original, portanto, quando save_traces=False , todos os objetos personalizados devem ter os métodos from_config / get_config definidos. Ao carregar, os objetos personalizados devem ser passados ​​para o argumento custom_objects . save_traces=False reduz o espaço em disco usado pelo SavedModel e economiza tempo.

Formato Keras H5

Keras também suporta salvar um único arquivo HDF5 contendo a arquitetura do modelo, valores de pesos e informações de compile() . É uma alternativa leve para SavedModel.

Exemplo:

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 1ms/step - loss: 1.0153
4/4 [==============================] - 0s 1ms/step - loss: 0.9104
<tensorflow.python.keras.callbacks.History at 0x7fd1455c66a0>

Limitações

Em comparação com o formato SavedModel, há duas coisas que não são incluídas no arquivo H5:

  • Perdas externas e métricas adicionadas via model.add_loss() e model.add_metric() não são salvas (ao contrário de SavedModel). Se você tiver essas perdas e métricas em seu modelo e quiser retomar o treinamento, será necessário adicionar essas perdas de volta depois de carregar o modelo. Observe que isso não se aplica a perdas / métricas criadas dentro das camadas via self.add_loss() & self.add_metric() . Enquanto a camada é carregada, essas perdas e métricas são mantidas, pois fazem parte do método de call da camada.
  • O gráfico de computação de objetos personalizados , como camadas personalizadas, não está incluído no arquivo salvo. No momento do carregamento, Keras precisará acessar as classes / funções Python desses objetos para reconstruir o modelo. Veja objetos personalizados .

Salvando a arquitetura

A configuração (ou arquitetura) do modelo especifica quais camadas o modelo contém e como essas camadas são conectadas *. Se você tiver a configuração de um modelo, ele poderá ser criado com um estado recém-inicializado para os pesos e sem informações de compilação.

* Observe que isso se aplica apenas a modelos definidos usando os modelos funcionais ou sequenciais apis sem subclasses.

Configuração de um modelo sequencial ou modelo de API funcional

Esses tipos de modelos são gráficos explícitos de camadas: sua configuração está sempre disponível de forma estruturada.

APIs

get_config() e from_config()

Chamar config = model.get_config() retornará um config = model.get_config() Python contendo a configuração do modelo. O mesmo modelo pode então ser reconstruído via Sequential.from_config(config) (para um modelo Sequential ) ou Model.from_config(config) (para um modelo Functional API).

O mesmo fluxo de trabalho também funciona para qualquer camada serializável.

Exemplo de camada:

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
new_layer = keras.layers.Dense.from_config(layer_config)

Exemplo de modelo sequencial:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

Exemplo de modelo funcional:

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

to_json() e tf.keras.models.model_from_json()

Isso é semelhante a get_config / from_config , exceto que transforma o modelo em uma string JSON, que pode então ser carregada sem a classe de modelo original. Também é específico para modelos, não é para camadas.

Exemplo:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

Objetos personalizados

Modelos e camadas

A arquitetura de modelos e camadas com subclasses são definidos nos métodos __init__ e call . Eles são considerados bytecode Python, que não pode ser serializado em uma configuração compatível com JSON - você pode tentar serializar o bytecode (por exemplo, via pickle ), mas é completamente inseguro e significa que seu modelo não pode ser carregado em um sistema diferente.

Para salvar / carregar um modelo com camadas personalizadas ou um modelo com subclasse, você deve sobrescrever os métodos get_config e opcionalmente from_config . Além disso, você deve usar o registro do objeto personalizado para que Keras fique ciente dele.

Funções personalizadas

Funções personalizadas (por exemplo, perda de ativação ou inicialização) não precisam de um método get_config . O nome da função é suficiente para carregar, desde que seja registrado como um objeto personalizado.

Carregando o gráfico TensorFlow apenas

É possível carregar o gráfico TensorFlow gerado pelo Keras. Se você fizer isso, não precisará fornecer nenhum custom_objects . Você pode fazer assim:

model.save("my_model")
tensorflow_graph = tf.saved_model.load("my_model")
x = np.random.uniform(size=(4, 32)).astype(np.float32)
predicted = tensorflow_graph(x).numpy()
INFO:tensorflow:Assets written to: my_model/assets

Observe que esse método tem várias desvantagens:

  • Por motivos de rastreabilidade, você deve sempre ter acesso aos objetos personalizados que foram usados. Você não gostaria de colocar em produção um modelo que não pode recriar.
  • O objeto retornado por tf.saved_model.load não é um modelo Keras. Portanto, não é tão fácil de usar. Por exemplo, você não terá acesso a .predict() ou .fit()

Mesmo que seu uso seja desencorajado, ele pode ajudá-lo se você estiver em uma situação difícil, por exemplo, se você perdeu o código de seus objetos personalizados ou teve problemas para carregar o modelo com tf.keras.models.load_model() .

Você pode descobrir mais na página sobre tf.saved_model.load

Definindo os métodos de configuração

Especificações:

  • get_config deve retornar um dicionário JSON serializável para ser compatível com a arquitetura Keras - e APIs de economia de modelo.
  • from_config(config) ( classmethod ) deve retornar uma nova camada ou objeto de modelo que é criado a partir da configuração. A implementação padrão retorna cls(**config) .

Exemplo:

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")

    def call(self, inputs, training=False):
        if training:
            return inputs * self.var
        else:
            return inputs

    def get_config(self):
        return {"a": self.var.numpy()}

    # There's actually no need to define `from_config` here, since returning
    # `cls(**config)` is the default behavior.
    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = CustomLayer(5)
layer.var.assign(2)

serialized_layer = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
    serialized_layer, custom_objects={"CustomLayer": CustomLayer}
)

Registrando o objeto personalizado

Keras mantém uma nota de qual classe gerou a configuração. No exemplo acima, tf.keras.layers.serialize gera uma forma serializada da camada personalizada:

{'class_name': 'CustomLayer', 'config': {'a': 2} }

Keras mantém uma lista principal de todas as classes integradas de camada, modelo, otimizador e métrica, que é usada para encontrar a classe correta para chamar from_config . Se a classe não puder ser encontrada, um erro será gerado ( Value Error: Unknown layer ). Existem algumas maneiras de registrar classes personalizadas nesta lista:

  1. Configurando o argumento custom_objects na função de carregamento. (veja o exemplo na seção acima "Definindo os métodos de configuração")
  2. tf.keras.utils.custom_object_scope ou tf.keras.utils.CustomObjectScope
  3. tf.keras.utils.register_keras_serializable

Camada personalizada e exemplo de função

class CustomLayer(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(CustomLayer, self).get_config()
        config.update({"units": self.units})
        return config


def custom_activation(x):
    return tf.nn.tanh(x) ** 2


# Make a model with the CustomLayer and custom_activation
inputs = keras.Input((32,))
x = CustomLayer(32)(inputs)
outputs = keras.layers.Activation(custom_activation)(x)
model = keras.Model(inputs, outputs)

# Retrieve the config
config = model.get_config()

# At loading time, register the custom objects with a `custom_object_scope`:
custom_objects = {"CustomLayer": CustomLayer, "custom_activation": custom_activation}
with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.Model.from_config(config)

Clonagem de modelo na memória

Você também pode fazer a clonagem na memória de um modelo via tf.keras.models.clone_model() . Isso é equivalente a obter a configuração e, em seguida, recriar o modelo a partir de sua configuração (portanto, não preserva as informações de compilação ou os valores dos pesos das camadas).

Exemplo:

with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.models.clone_model(model)

Salvar e carregar apenas os valores de peso do modelo

Você pode escolher salvar e carregar apenas os pesos de um modelo. Isso pode ser útil se:

  • Você só precisa do modelo para inferência: neste caso, você não precisa reiniciar o treinamento, portanto, não precisa das informações de compilação ou do estado do otimizador.
  • Você está fazendo transferência de aprendizagem: neste caso, você estará treinando um novo modelo reutilizando o estado de um modelo anterior, portanto, não precisa das informações de compilação do modelo anterior.

APIs para transferência de peso na memória

Pesos podem ser copiados entre objetos diferentes usando get_weights e set_weights :

Exemplos abaixo.

Transferindo pesos de uma camada para outra, na memória

def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 1 to layer 2
layer_2.set_weights(layer_1.get_weights())

Transferir pesos de um modelo para outro modelo com arquitetura compatível, na memória

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super(SubclassedModel, self).__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(tf.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

O caso das camadas sem estado

Como as camadas sem estado não alteram a ordem ou o número de pesos, os modelos podem ter arquiteturas compatíveis, mesmo se houver camadas sem estado extras / ausentes.

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

APIs para salvar pesos no disco e carregá-los de volta

Pesos podem ser salvos em disco chamando model.save_weights nos seguintes formatos:

  • Ponto de verificação do TensorFlow
  • HDF5

O formato padrão para model.save_weights é o ponto de verificação do TensorFlow. Existem duas maneiras de especificar o formato de salvamento:

  1. argumento save_format : defina o valor para save_format="tf" ou save_format="h5" .
  2. argumento de path : se o caminho terminar com .h5 ou .hdf5 , o formato HDF5 será usado. Outros sufixos resultarão em um ponto de verificação do save_format menos que save_format seja definido.

Também há uma opção de recuperar pesos como matrizes numpy na memória. Cada API tem seus prós e contras, que são detalhados a seguir.

Formato de ponto de verificação TF

Exemplo:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("ckpt")
load_status = sequential_model.load_weights("ckpt")

# `assert_consumed` can be used as validation that all variable values have been
# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other
# methods in the Status object.
load_status.assert_consumed()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd0a065f128>

Detalhes de formato

O formato do TensorFlow Checkpoint salva e restaura os pesos usando nomes de atributos de objetos. Por exemplo, considere a camada tf.keras.layers.Dense . A camada contém dois pesos: dense.kernel e dense.bias . Quando a camada é salva no formato tf , o ponto de verificação resultante contém as chaves "kernel" e "bias" e seus valores de peso correspondentes. Para obter mais informações, consulte "Carregando a mecânica" no guia do TF Checkpoint .

Observe que o atributo / borda do gráfico é nomeado após o nome usado no objeto pai, não o nome da variável . Considere o CustomLayer no exemplo abaixo. A variável CustomLayer.var é salva com "var" como parte da chave, não "var_a" .

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")


layer = CustomLayer(5)
layer_ckpt = tf.train.Checkpoint(layer=layer).save("custom_layer")

ckpt_reader = tf.train.load_checkpoint(layer_ckpt)

ckpt_reader.get_variable_to_dtype_map()
{'save_counter/.ATTRIBUTES/VARIABLE_VALUE': tf.int64,
 '_CHECKPOINTABLE_OBJECT_GRAPH': tf.string,
 'layer/var/.ATTRIBUTES/VARIABLE_VALUE': tf.int32}

Exemplo de transferência de aprendizagem

Essencialmente, desde que dois modelos tenham a mesma arquitetura, eles podem compartilhar o mesmo ponto de verificação.

Exemplo:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Extract a portion of the functional model defined in the Setup section.
# The following lines produce a new model that excludes the final output
# layer of the functional model.
pretrained = keras.Model(
    functional_model.inputs, functional_model.layers[-1].input, name="pretrained_model"
)
# Randomly assign "trained" weights.
for w in pretrained.weights:
    w.assign(tf.random.normal(w.shape))
pretrained.save_weights("pretrained_ckpt")
pretrained.summary()

# Assume this is a separate program where only 'pretrained_ckpt' exists.
# Create a new functional model with a different output dimension.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(5, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="new_model")

# Load the weights from pretrained_ckpt into model.
model.load_weights("pretrained_ckpt")

# Check that all of the pretrained weights have been loaded.
for a, b in zip(pretrained.weights, model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

print("\n", "-" * 50)
model.summary()

# Example 2: Sequential model
# Recreate the pretrained model, and load the saved weights.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
pretrained_model = keras.Model(inputs=inputs, outputs=x, name="pretrained")

# Sequential example:
model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name="predictions")])
model.summary()

pretrained_model.load_weights("pretrained_ckpt")

# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,
# but will *not* work as expected. If you inspect the weights, you'll see that
# none of the weights will have loaded. `pretrained_model.load_weights()` is the
# correct method to call.
Model: "pretrained_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
=================================================================
Total params: 54,400
Trainable params: 54,400
Non-trainable params: 0
_________________________________________________________________

 --------------------------------------------------
Model: "new_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
pretrained (Functional)      (None, 64)                54400     
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd144b20b38>

Geralmente, é recomendado manter a mesma API para construir modelos. Se você alternar entre Sequencial e Funcional, ou Funcional e subclasse, etc., sempre reconstrua o modelo pré-treinado e carregue os pesos pré-treinados nesse modelo.

A próxima pergunta é: como os pesos podem ser salvos e carregados em modelos diferentes se as arquiteturas de modelo são bastante diferentes? A solução é usar tf.train.Checkpoint para salvar e restaurar as camadas / variáveis ​​exatas.

Exemplo:

# Create a subclassed model that essentially uses functional_model's first
# and last layers.
# First, save the weights of functional_model's first and last dense layers.
first_dense = functional_model.layers[1]
last_dense = functional_model.layers[-1]
ckpt_path = tf.train.Checkpoint(
    dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias
).save("ckpt")

# Define the subclassed model.
class ContrivedModel(keras.Model):
    def __init__(self):
        super(ContrivedModel, self).__init__()
        self.first_dense = keras.layers.Dense(64)
        self.kernel = self.add_variable("kernel", shape=(64, 10))
        self.bias = self.add_variable("bias", shape=(10,))

    def call(self, inputs):
        x = self.first_dense(inputs)
        return tf.matmul(x, self.kernel) + self.bias


model = ContrivedModel()
# Call model on inputs to create the variables of the dense layer.
_ = model(tf.ones((1, 784)))

# Create a Checkpoint with the same structure as before, and load the weights.
tf.train.Checkpoint(
    dense=model.first_dense, kernel=model.kernel, bias=model.bias
).restore(ckpt_path).assert_consumed()
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:2281: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd1455c6cc0>

Formato HDF5

O formato HDF5 contém pesos agrupados por nomes de camadas. Os pesos são listas ordenadas concatenando a lista de pesos treináveis ​​com a lista de pesos não treináveis ​​(o mesmo que layer.weights ). Assim, um modelo pode usar um ponto de verificação hdf5 se tiver as mesmas camadas e status treináveis ​​salvos no ponto de verificação.

Exemplo:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("weights.h5")
sequential_model.load_weights("weights.h5")

Observe que a alteração de layer.trainable pode resultar em uma ordem diferente de layer.weights quando o modelo contém camadas aninhadas.

class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super(NestedDenseLayer, self).__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)
variables: ['nested/dense_1/kernel:0', 'nested/dense_1/bias:0', 'nested/dense_2/kernel:0', 'nested/dense_2/bias:0']

Changing trainable status of one of the nested layers...

variables: ['nested/dense_2/kernel:0', 'nested/dense_2/bias:0', 'nested/dense_1/kernel:0', 'nested/dense_1/bias:0']
variable ordering changed: True

Exemplo de aprendizagem de transferência

Ao carregar pesos pré-treinados de HDF5, é recomendado carregar os pesos no modelo de checkpoint original e, em seguida, extrair os pesos / camadas desejados em um novo modelo.

Exemplo:

def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained_weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained_weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________