Salva e carica i modelli Keras

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica taccuino

introduzione

Un modello Keras è costituito da più componenti:

  • L'architettura, o configurazione, che specifica quali layer contiene il modello e come sono collegati.
  • Un insieme di valori di pesi (lo "stato del modello").
  • Un ottimizzatore (definito compilando il modello).
  • Un insieme di perdite e metriche (definite compilando il modello o chiamando add_loss() o add_metric() ).

L'API di Keras consente di salvare tutti questi pezzi su disco contemporaneamente o di salvarne solo alcuni selettivamente:

  • Salvataggio di tutto in un unico archivio nel formato TensorFlow SavedModel (o nel vecchio formato Keras H5). Questa è la pratica standard.
  • Salvataggio solo dell'architettura / configurazione, in genere come file JSON.
  • Salvataggio solo dei valori dei pesi. Viene generalmente utilizzato durante l'addestramento del modello.

Diamo un'occhiata a ciascuna di queste opzioni. Quando useresti l'uno o l'altro e come funzionano?

Come salvare e caricare un modello

Se hai solo 10 secondi per leggere questa guida, ecco cosa devi sapere.

Salvataggio di un modello Keras:

model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')

Ricaricare il modello:

from tensorflow import keras
model = keras.models.load_model('path/to/location')

Ora, esaminiamo i dettagli.

Impostare

import numpy as np
import tensorflow as tf
from tensorflow import keras

Salvataggio e caricamento dell'intero modello

Puoi salvare un intero modello su un singolo artefatto. Comprenderà:

  • L'architettura / configurazione del modello
  • I valori di peso del modello (che sono stati appresi durante l'allenamento)
  • Le informazioni di compilazione del modello (se è stata chiamata compile() )
  • L'ottimizzatore e il suo stato, se presente (questo ti consente di ricominciare l'allenamento da dove eri rimasto)

API

Esistono due formati che puoi utilizzare per salvare un intero modello su disco: il formato TensorFlow SavedModel e il vecchio formato Keras H5 . Il formato consigliato è SavedModel. È l'impostazione predefinita quando si utilizza model.save() .

Puoi passare al formato H5:

  • Passando save_format='h5' a save() .
  • Passaggio di un nome file che termina con .h5 o .keras per save() .

Formato SavedModel

SavedModel è il formato di salvataggio più completo che salva l'architettura del modello, i pesi e i sottografi Tensorflow tracciati delle funzioni di chiamata. Ciò consente a Keras di ripristinare sia i livelli incorporati che gli oggetti personalizzati.

Esempio:

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 1s 2ms/step - loss: 0.9237
INFO:tensorflow:Assets written to: my_model/assets
4/4 [==============================] - 0s 1ms/step - loss: 0.7730
<tensorflow.python.keras.callbacks.History at 0x7fd0a032a390>

Cosa contiene il SavedModel

La chiamata model.save('my_model') crea una cartella denominata my_model , contenente quanto segue:

ls my_model
assets  saved_model.pb  variables

L'architettura del modello e la configurazione dell'addestramento (inclusi l'ottimizzatore, le perdite e le metriche) sono archiviate in saved_model.pb . I pesi vengono salvati nella directory variables/ .

Per informazioni dettagliate sul formato SavedModel, vedere la guida SavedModel ( Il formato SavedModel su disco ) .

In che modo SavedModel gestisce gli oggetti personalizzati

Durante il salvataggio del modello e dei suoi livelli, il formato SavedModel memorizza il nome della classe, la funzione di chiamata , le perdite e i pesi (e la configurazione, se implementata). La funzione call definisce il grafico di calcolo del modello / livello.

In assenza della configurazione del modello / livello, la funzione di chiamata viene utilizzata per creare un modello che esiste come il modello originale che può essere addestrato, valutato e utilizzato per l'inferenza.

Tuttavia, è sempre una buona pratica definire i metodi get_config e from_config quando si scrive un modello personalizzato o una classe layer. Ciò consente di aggiornare facilmente il calcolo in un secondo momento, se necessario. Vedere la sezione sugli oggetti personalizzati per ulteriori informazioni.

Esempio:

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.hidden_units = hidden_units
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

    def get_config(self):
        return {"hidden_units": self.hidden_units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Option 1: Load with the custom_object argument.
loaded_1 = keras.models.load_model(
    "my_model", custom_objects={"CustomModel": CustomModel}
)

# Option 2: Load without the CustomModel class.

# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel

loaded_2 = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded_1(input_arr), outputs)
np.testing.assert_allclose(loaded_2(input_arr), outputs)

print("Original model:", model)
print("Model Loaded with custom objects:", loaded_1)
print("Model loaded without the custom object class:", loaded_2)
INFO:tensorflow:Assets written to: my_model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Original model: <__main__.CustomModel object at 0x7fd0a035bcf8>
Model Loaded with custom objects: <__main__.CustomModel object at 0x7fd1455d04e0>
Model loaded without the custom object class: <tensorflow.python.keras.saving.saved_model.load.CustomModel object at 0x7fd14553af98>

Il primo modello caricato viene caricato utilizzando la classe config e CustomModel . Il secondo modello viene caricato creando dinamicamente la classe del modello che agisce come il modello originale.

Configurazione di SavedModel

Novità in TensoFlow 2.4 L'argomento save_traces è stato aggiunto a model.save , che consente di attivare o disattivare la funzione di tracciamento di SavedModel. Le funzioni vengono salvate per consentire a Kera di ricaricare gli oggetti personalizzati senza le definizioni della classe originale, quindi quando save_traces=False , tutti gli oggetti personalizzati devono avere i metodi get_config / from_config definiti. Durante il caricamento, gli oggetti personalizzati devono essere passati all'argomento custom_objects . save_traces=False riduce lo spazio su disco utilizzato da SavedModel e fa risparmiare tempo.

Formato Keras H5

Keras supporta anche il salvataggio di un singolo file HDF5 contenente l'architettura del modello, i valori dei pesi e le informazioni compile() . È un'alternativa leggera a SavedModel.

Esempio:

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)
4/4 [==============================] - 0s 1ms/step - loss: 1.0153
4/4 [==============================] - 0s 1ms/step - loss: 0.9104
<tensorflow.python.keras.callbacks.History at 0x7fd1455c66a0>

Limitazioni

Rispetto al formato SavedModel, ci sono due cose che non vengono incluse nel file H5:

  • Le perdite esterne e le metriche aggiunte tramite model.add_loss() e model.add_metric() non vengono salvate (a differenza di SavedModel). Se si dispone di tali perdite e metriche sul modello e si desidera riprendere l'allenamento, è necessario aggiungere nuovamente queste perdite dopo aver caricato il modello. Nota che questo non si applica alle perdite / metriche create all'interno dei livelli tramite self.add_loss() e self.add_metric() . Finché il livello viene caricato, queste perdite e metriche vengono mantenute, poiché fanno parte del metodo di call del livello.
  • Il grafico di calcolo degli oggetti personalizzati come i layer personalizzati non è incluso nel file salvato. Al momento del caricamento, Keras avrà bisogno di accedere alle classi / funzioni Python di questi oggetti per ricostruire il modello. Vedi Oggetti personalizzati .

Salvare l'architettura

La configurazione (o architettura) del modello specifica quali layer contiene il modello e come questi layer sono collegati *. Se si dispone della configurazione di un modello, è possibile creare il modello con uno stato appena inizializzato per i pesi e senza informazioni di compilazione.

* Nota che questo si applica solo ai modelli definiti utilizzando le API funzionali o sequenziali non sottoclassate.

Configurazione di un modello sequenziale o di un modello API funzionale

Questi tipi di modelli sono grafici espliciti di livelli: la loro configurazione è sempre disponibile in una forma strutturata.

API

get_config() e from_config()

Chiamare config = model.get_config() restituirà un dict Python contenente la configurazione del modello. Lo stesso modello può quindi essere ricostruito tramite Sequential.from_config(config) (per un modello Sequential ) o Model.from_config(config) (per un modello API funzionale).

Lo stesso flusso di lavoro funziona anche per qualsiasi livello serializzabile.

Esempio di livello:

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
new_layer = keras.layers.Dense.from_config(layer_config)

Esempio di modello sequenziale:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

Esempio di modello funzionale:

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

to_json() e tf.keras.models.model_from_json()

È simile a get_config / from_config , tranne per il fatto che trasforma il modello in una stringa JSON, che può quindi essere caricata senza la classe del modello originale. È anche specifico per i modelli, non è pensato per i livelli.

Esempio:

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

Oggetti personalizzati

Modelli e strati

L'architettura dei modelli e dei livelli sottoclasse è definita nei metodi __init__ e call . Sono considerati bytecode Python, che non può essere serializzato in una configurazione compatibile con JSON: potresti provare a serializzare il bytecode (ad esempio tramite pickle ), ma è completamente pericoloso e significa che il tuo modello non può essere caricato su un sistema diverso.

Per salvare / caricare un modello con layer definiti dall'utente o un modello sottoclasse, è necessario sovrascrivere i get_config e facoltativamente from_config . Inoltre, dovresti usare registra l'oggetto personalizzato in modo che Keras lo sappia.

Funzioni personalizzate

Le funzioni definite dall'utente (ad esempio, perdita di attivazione o inizializzazione) non richiedono un metodo get_config . Il nome della funzione è sufficiente per il caricamento purché sia ​​registrato come oggetto personalizzato.

Caricamento solo del grafico TensorFlow

È possibile caricare il grafico TensorFlow generato da Keras. In tal caso, non sarà necessario fornire alcun custom_objects . Puoi farlo in questo modo:

model.save("my_model")
tensorflow_graph = tf.saved_model.load("my_model")
x = np.random.uniform(size=(4, 32)).astype(np.float32)
predicted = tensorflow_graph(x).numpy()
INFO:tensorflow:Assets written to: my_model/assets

Nota che questo metodo ha diversi inconvenienti:

  • Per motivi di tracciabilità, dovresti sempre avere accesso agli oggetti personalizzati che sono stati utilizzati. Non vorresti mettere in produzione un modello che non puoi ricreare.
  • L'oggetto restituito da tf.saved_model.load non è un modello Keras. Quindi non è così facile da usare. Ad esempio, non avrai accesso a .predict() o .fit()

Anche se il suo utilizzo è scoraggiato, può aiutarti se sei in una situazione difficile, ad esempio, se hai perso il codice dei tuoi oggetti personalizzati o hai problemi a caricare il modello con tf.keras.models.load_model() .

Puoi saperne di più nella pagina su tf.saved_model.load

Definizione dei metodi di configurazione

Specifiche:

  • get_config dovrebbe restituire un dizionario serializzabile JSON per essere compatibile con le API di Keras per il salvataggio dell'architettura e del modello.
  • from_config(config) ( classmethod ) dovrebbe restituire un nuovo oggetto layer o modello creato dalla configurazione. L'implementazione predefinita restituisce cls(**config) .

Esempio:

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")

    def call(self, inputs, training=False):
        if training:
            return inputs * self.var
        else:
            return inputs

    def get_config(self):
        return {"a": self.var.numpy()}

    # There's actually no need to define `from_config` here, since returning
    # `cls(**config)` is the default behavior.
    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = CustomLayer(5)
layer.var.assign(2)

serialized_layer = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
    serialized_layer, custom_objects={"CustomLayer": CustomLayer}
)

Registrazione dell'oggetto personalizzato

Keras tiene nota di quale classe ha generato la configurazione. tf.keras.layers.serialize sopra, tf.keras.layers.serialize genera un modulo serializzato del livello personalizzato:

{'class_name': 'CustomLayer', 'config': {'a': 2} }

Keras mantiene un elenco principale di tutte le classi incorporate di layer, modelli, ottimizzatori e metriche, che viene utilizzato per trovare la classe corretta da chiamare from_config . Se la classe non può essere trovata, viene generato un errore ( Value Error: Unknown layer ). Esistono alcuni modi per registrare classi personalizzate in questo elenco:

  1. Impostazione dell'argomento custom_objects nella funzione di caricamento. (vedere l'esempio nella sezione precedente "Definizione dei metodi di configurazione")
  2. tf.keras.utils.custom_object_scope o tf.keras.utils.CustomObjectScope
  3. tf.keras.utils.register_keras_serializable

Livello personalizzato e esempio di funzione

class CustomLayer(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(CustomLayer, self).get_config()
        config.update({"units": self.units})
        return config


def custom_activation(x):
    return tf.nn.tanh(x) ** 2


# Make a model with the CustomLayer and custom_activation
inputs = keras.Input((32,))
x = CustomLayer(32)(inputs)
outputs = keras.layers.Activation(custom_activation)(x)
model = keras.Model(inputs, outputs)

# Retrieve the config
config = model.get_config()

# At loading time, register the custom objects with a `custom_object_scope`:
custom_objects = {"CustomLayer": CustomLayer, "custom_activation": custom_activation}
with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.Model.from_config(config)

Clonazione del modello in memoria

Puoi anche eseguire la clonazione in memoria di un modello tramite tf.keras.models.clone_model() . Ciò equivale a ottenere la configurazione e quindi a ricreare il modello dalla sua configurazione (quindi non conserva le informazioni di compilazione oi valori dei pesi dei livelli).

Esempio:

with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.models.clone_model(model)

Salvataggio e caricamento dei soli valori di peso del modello

Puoi scegliere di salvare e caricare solo i pesi di un modello. Questo può essere utile se:

  • Hai solo bisogno del modello per l'inferenza: in questo caso non avrai bisogno di riavviare l'addestramento, quindi non hai bisogno delle informazioni di compilazione o dello stato dell'ottimizzatore.
  • Stai facendo transfer learning: in questo caso addestrerai un nuovo modello riutilizzando lo stato di un modello precedente, quindi non avrai bisogno delle informazioni di compilazione del modello precedente.

API per il trasferimento del peso in memoria

I pesi possono essere copiati tra diversi oggetti utilizzando get_weights e set_weights :

Esempi di seguito.

Trasferimento di pesi da uno strato all'altro, in memoria

def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 1 to layer 2
layer_2.set_weights(layer_1.get_weights())

Trasferimento di pesi da un modello a un altro modello con un'architettura compatibile, in memoria

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super(SubclassedModel, self).__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(tf.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

Il caso degli apolidi

Poiché i livelli senza stato non cambiano l'ordine o il numero di pesi, i modelli possono avere architetture compatibili anche se sono presenti livelli senza stato aggiuntivi / mancanti.

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

API per salvare i pesi su disco e caricarli nuovamente

I pesi possono essere salvati su disco chiamando model.save_weights nei seguenti formati:

  • TensorFlow Checkpoint
  • HDF5

Il formato predefinito per model.save_weights è il checkpoint TensorFlow. Esistono due modi per specificare il formato di salvataggio:

  1. argomento save_format : save_format il valore su save_format="tf" o save_format="h5" .
  2. argomento del path : se il percorso termina con .h5 o .hdf5 , viene utilizzato il formato HDF5. Altri suffissi risulteranno in un checkpoint TensorFlow a meno che non sia impostato save_format .

C'è anche un'opzione per recuperare i pesi come array numpy in memoria. Ogni API ha i suoi pro e contro che sono descritti di seguito.

Formato Checkpoint TF

Esempio:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("ckpt")
load_status = sequential_model.load_weights("ckpt")

# `assert_consumed` can be used as validation that all variable values have been
# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other
# methods in the Status object.
load_status.assert_consumed()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd0a065f128>

Dettagli del formato

Il formato TensorFlow Checkpoint salva e ripristina i pesi utilizzando i nomi degli attributi degli oggetti. Ad esempio, considera il livello tf.keras.layers.Dense . Il livello contiene due pesi: dense.kernel e dense.bias . Quando il livello viene salvato nel formato tf , il checkpoint risultante contiene le chiavi "kernel" e "bias" e i valori di peso corrispondenti. Per ulteriori informazioni vedere "Meccaniche di caricamento" nella guida TF Checkpoint .

Notare che l'attributo / il bordo del grafico prende il nome dal nome utilizzato nell'oggetto genitore, non dal nome della variabile . Considera il CustomLayer nell'esempio seguente. La variabile CustomLayer.var viene salvata con "var" come parte della chiave, non "var_a" .

class CustomLayer(keras.layers.Layer):
    def __init__(self, a):
        self.var = tf.Variable(a, name="var_a")


layer = CustomLayer(5)
layer_ckpt = tf.train.Checkpoint(layer=layer).save("custom_layer")

ckpt_reader = tf.train.load_checkpoint(layer_ckpt)

ckpt_reader.get_variable_to_dtype_map()
{'save_counter/.ATTRIBUTES/VARIABLE_VALUE': tf.int64,
 '_CHECKPOINTABLE_OBJECT_GRAPH': tf.string,
 'layer/var/.ATTRIBUTES/VARIABLE_VALUE': tf.int32}

Trasferimento di esempio di apprendimento

In sostanza, fintanto che due modelli hanno la stessa architettura, sono in grado di condividere lo stesso checkpoint.

Esempio:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Extract a portion of the functional model defined in the Setup section.
# The following lines produce a new model that excludes the final output
# layer of the functional model.
pretrained = keras.Model(
    functional_model.inputs, functional_model.layers[-1].input, name="pretrained_model"
)
# Randomly assign "trained" weights.
for w in pretrained.weights:
    w.assign(tf.random.normal(w.shape))
pretrained.save_weights("pretrained_ckpt")
pretrained.summary()

# Assume this is a separate program where only 'pretrained_ckpt' exists.
# Create a new functional model with a different output dimension.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(5, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="new_model")

# Load the weights from pretrained_ckpt into model.
model.load_weights("pretrained_ckpt")

# Check that all of the pretrained weights have been loaded.
for a, b in zip(pretrained.weights, model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

print("\n", "-" * 50)
model.summary()

# Example 2: Sequential model
# Recreate the pretrained model, and load the saved weights.
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
pretrained_model = keras.Model(inputs=inputs, outputs=x, name="pretrained")

# Sequential example:
model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name="predictions")])
model.summary()

pretrained_model.load_weights("pretrained_ckpt")

# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,
# but will *not* work as expected. If you inspect the weights, you'll see that
# none of the weights will have loaded. `pretrained_model.load_weights()` is the
# correct method to call.
Model: "pretrained_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
=================================================================
Total params: 54,400
Trainable params: 54,400
Non-trainable params: 0
_________________________________________________________________

 --------------------------------------------------
Model: "new_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
pretrained (Functional)      (None, 64)                54400     
_________________________________________________________________
predictions (Dense)          (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd144b20b38>

In genere è consigliabile attenersi alla stessa API per la creazione di modelli. Se si passa da sequenziale a funzionale o funzionale e sottoclasse, ecc., Ricostruire sempre il modello pre-addestrato e caricare i pesi pre-addestrati su quel modello.

La domanda successiva è: come possono essere salvati i pesi e caricati su modelli diversi se le architetture del modello sono abbastanza diverse? La soluzione è utilizzare tf.train.Checkpoint per salvare e ripristinare i layer / le variabili esatti.

Esempio:

# Create a subclassed model that essentially uses functional_model's first
# and last layers.
# First, save the weights of functional_model's first and last dense layers.
first_dense = functional_model.layers[1]
last_dense = functional_model.layers[-1]
ckpt_path = tf.train.Checkpoint(
    dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias
).save("ckpt")

# Define the subclassed model.
class ContrivedModel(keras.Model):
    def __init__(self):
        super(ContrivedModel, self).__init__()
        self.first_dense = keras.layers.Dense(64)
        self.kernel = self.add_variable("kernel", shape=(64, 10))
        self.bias = self.add_variable("bias", shape=(10,))

    def call(self, inputs):
        x = self.first_dense(inputs)
        return tf.matmul(x, self.kernel) + self.bias


model = ContrivedModel()
# Call model on inputs to create the variables of the dense layer.
_ = model(tf.ones((1, 784)))

# Create a Checkpoint with the same structure as before, and load the weights.
tf.train.Checkpoint(
    dense=model.first_dense, kernel=model.kernel, bias=model.bias
).restore(ckpt_path).assert_consumed()
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:2281: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd1455c6cc0>

Formato HDF5

Il formato HDF5 contiene pesi raggruppati in base ai nomi dei livelli. I pesi sono elenchi ordinati concatenando l'elenco dei pesi addestrabili all'elenco dei pesi non layer.weights (lo stesso di layer.weights ). Pertanto, un modello può utilizzare un checkpoint hdf5 se ha gli stessi layer e stati addestrabili salvati nel checkpoint.

Esempio:

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("weights.h5")
sequential_model.load_weights("weights.h5")

Si noti che la modifica di layer.trainable può comportare un diverso ordine layer.weights quando il modello contiene layer nidificati.

class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super(NestedDenseLayer, self).__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)
variables: ['nested/dense_1/kernel:0', 'nested/dense_1/bias:0', 'nested/dense_2/kernel:0', 'nested/dense_2/bias:0']

Changing trainable status of one of the nested layers...

variables: ['nested/dense_2/kernel:0', 'nested/dense_2/bias:0', 'nested/dense_1/kernel:0', 'nested/dense_1/bias:0']
variable ordering changed: True

Trasferimento di esempio di apprendimento

Quando si caricano pesi pre-addestrati da HDF5, si consiglia di caricare i pesi nel modello con punto di controllo originale, quindi estrarre i pesi / strati desiderati in un nuovo modello.

Esempio:

def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained_weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained_weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 325       
=================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
_________________________________________________________________