Cette page a été traduite par l'API Cloud Translation.
Switch to English

Enregistrer et charger des modèles

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le carnet

La progression du modèle peut être enregistrée pendant et après l'entraînement. Cela signifie qu'un modèle peut reprendre là où il s'était arrêté et éviter de longues périodes d'entraînement. L'enregistrement signifie également que vous pouvez partager votre modèle et que d'autres peuvent recréer votre travail. Lors de la publication de modèles et de techniques de recherche, la plupart des praticiens de l'apprentissage automatique partagent:

  • code pour créer le modèle, et
  • les poids entraînés, ou paramètres, pour le modèle

Le partage de ces données aide les autres à comprendre le fonctionnement du modèle et à l'essayer eux-mêmes avec de nouvelles données.

Options

Il existe différentes manières d'enregistrer les modèles TensorFlow, selon l'API que vous utilisez. Ce guide utilise tf.keras , une API de haut niveau pour créer et entraîner des modèles dans TensorFlow. Pour d'autres approches, reportez-vous au guide TensorFlow Save and Restore ou Enregistrement en hâte .

Installer

Installe et importe

Installez et importez TensorFlow et ses dépendances:

pip install -q pyyaml h5py  # Required to save models in HDF5 format
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.3.0

Obtenir un exemple de jeu de données

Pour montrer comment enregistrer et charger des poids, vous utiliserez le jeu de données MNIST . Pour accélérer ces exécutions, utilisez les 1000 premiers exemples:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

Définir un modèle

Commencez par créer un modèle séquentiel simple:

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Enregistrer les points de contrôle pendant la formation

Vous pouvez utiliser un modèle entraîné sans avoir à le recycler ou reprendre la formation là où vous vous étiez arrêté, au cas où le processus de formation serait interrompu. Le callback tf.keras.callbacks.ModelCheckpoint permet de sauvegarder en permanence le modèle pendant et à la fin de la formation.

Utilisation du rappel de point de contrôle

Créez un rappel tf.keras.callbacks.ModelCheckpoint qui enregistre les poids uniquement pendant l'entraînement:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
27/32 [========================>.....] - ETA: 0s - loss: 1.2283 - sparse_categorical_accuracy: 0.6574
Epoch 00001: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 8ms/step - loss: 1.1420 - sparse_categorical_accuracy: 0.6810 - val_loss: 0.7102 - val_sparse_categorical_accuracy: 0.7740
Epoch 2/10
28/32 [=========================>....] - ETA: 0s - loss: 0.4317 - sparse_categorical_accuracy: 0.8795
Epoch 00002: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4257 - sparse_categorical_accuracy: 0.8780 - val_loss: 0.5635 - val_sparse_categorical_accuracy: 0.8300
Epoch 3/10
28/32 [=========================>....] - ETA: 0s - loss: 0.2886 - sparse_categorical_accuracy: 0.9241
Epoch 00003: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2908 - sparse_categorical_accuracy: 0.9210 - val_loss: 0.4967 - val_sparse_categorical_accuracy: 0.8370
Epoch 4/10
28/32 [=========================>....] - ETA: 0s - loss: 0.2128 - sparse_categorical_accuracy: 0.9453
Epoch 00004: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.2104 - sparse_categorical_accuracy: 0.9460 - val_loss: 0.4355 - val_sparse_categorical_accuracy: 0.8580
Epoch 5/10
28/32 [=========================>....] - ETA: 0s - loss: 0.1396 - sparse_categorical_accuracy: 0.9654
Epoch 00005: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.1474 - sparse_categorical_accuracy: 0.9650 - val_loss: 0.4435 - val_sparse_categorical_accuracy: 0.8530
Epoch 6/10
29/32 [==========================>...] - ETA: 0s - loss: 0.1060 - sparse_categorical_accuracy: 0.9860
Epoch 00006: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.1096 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4177 - val_sparse_categorical_accuracy: 0.8650
Epoch 7/10
28/32 [=========================>....] - ETA: 0s - loss: 0.0966 - sparse_categorical_accuracy: 0.9844
Epoch 00007: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0932 - sparse_categorical_accuracy: 0.9850 - val_loss: 0.4236 - val_sparse_categorical_accuracy: 0.8610
Epoch 8/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0644 - sparse_categorical_accuracy: 0.9931
Epoch 00008: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0658 - sparse_categorical_accuracy: 0.9920 - val_loss: 0.4141 - val_sparse_categorical_accuracy: 0.8680
Epoch 9/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0478 - sparse_categorical_accuracy: 1.0000
Epoch 00009: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0499 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4156 - val_sparse_categorical_accuracy: 0.8680
Epoch 10/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0369 - sparse_categorical_accuracy: 0.9988
Epoch 00010: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0368 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4239 - val_sparse_categorical_accuracy: 0.8680

<tensorflow.python.keras.callbacks.History at 0x7fef1c745c18>

Cela crée une seule collection de fichiers de point de contrôle TensorFlow qui sont mis à jour à la fin de chaque époque:

ls {checkpoint_dir}
checkpoint  cp.ckpt.data-00000-of-00001  cp.ckpt.index

Créez un nouveau modèle non entraîné. Lors de la restauration d'un modèle à partir de poids uniquement, vous devez disposer d'un modèle avec la même architecture que le modèle d'origine. Comme il s'agit de la même architecture de modèle, vous pouvez partager des pondérations même s'il s'agit d'une instance différente du modèle.

Maintenant, reconstruisez un nouveau modèle non entraîné et évaluez-le sur l'ensemble de test. Un modèle non entraîné fonctionnera à des niveaux aléatoires (précision d'environ 10%):

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 2.3319 - sparse_categorical_accuracy: 0.1410
Untrained model, accuracy: 14.10%

Ensuite, chargez les poids du point de contrôle et réévaluez:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss,acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.4239 - sparse_categorical_accuracy: 0.8680
Restored model, accuracy: 86.80%

Options de rappel de point de contrôle

Le rappel fournit plusieurs options pour fournir des noms uniques pour les points de contrôle et ajuster la fréquence des points de contrôle.

Entraînez un nouveau modèle et enregistrez des points de contrôle avec un nom unique une fois toutes les cinq époques:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    period=5)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          callbacks=[cp_callback],
          validation_data=(test_images,test_labels),
          verbose=0)
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

Epoch 00005: saving model to training_2/cp-0005.ckpt

Epoch 00010: saving model to training_2/cp-0010.ckpt

Epoch 00015: saving model to training_2/cp-0015.ckpt

Epoch 00020: saving model to training_2/cp-0020.ckpt

Epoch 00025: saving model to training_2/cp-0025.ckpt

Epoch 00030: saving model to training_2/cp-0030.ckpt

Epoch 00035: saving model to training_2/cp-0035.ckpt

Epoch 00040: saving model to training_2/cp-0040.ckpt

Epoch 00045: saving model to training_2/cp-0045.ckpt

Epoch 00050: saving model to training_2/cp-0050.ckpt

<tensorflow.python.keras.callbacks.History at 0x7fef6ff62898>

Maintenant, regardez les points de contrôle résultants et choisissez le dernier:

ls {checkpoint_dir}
checkpoint            cp-0025.ckpt.index
cp-0000.ckpt.data-00000-of-00001  cp-0030.ckpt.data-00000-of-00001
cp-0000.ckpt.index        cp-0030.ckpt.index
cp-0005.ckpt.data-00000-of-00001  cp-0035.ckpt.data-00000-of-00001
cp-0005.ckpt.index        cp-0035.ckpt.index
cp-0010.ckpt.data-00000-of-00001  cp-0040.ckpt.data-00000-of-00001
cp-0010.ckpt.index        cp-0040.ckpt.index
cp-0015.ckpt.data-00000-of-00001  cp-0045.ckpt.data-00000-of-00001
cp-0015.ckpt.index        cp-0045.ckpt.index
cp-0020.ckpt.data-00000-of-00001  cp-0050.ckpt.data-00000-of-00001
cp-0020.ckpt.index        cp-0050.ckpt.index
cp-0025.ckpt.data-00000-of-00001

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

Pour tester, réinitialisez le modèle et chargez le dernier point de contrôle:

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.5075 - sparse_categorical_accuracy: 0.8730
Restored model, accuracy: 87.30%

Quels sont ces fichiers?

Le code ci-dessus stocke les pondérations dans une collection de fichiers au format de point de contrôle qui contiennent uniquement les pondérations entraînées dans un format binaire. Les points de contrôle contiennent:

  • Un ou plusieurs fragments contenant les pondérations de votre modèle.
  • Un fichier d'index qui indique quels poids sont stockés dans quel fragment.

Si vous .data-00000-of-00001 un modèle que sur une seule machine, vous aurez un fragment avec le suffixe: .data-00000-of-00001

Enregistrer manuellement les poids

Vous avez vu comment charger les poids dans un modèle. Les enregistrer manuellement est tout aussi simple avec la méthode Model.save_weights . Par défaut, tf.keras - et save_weights en particulier - utilise le format de point de contrôle TensorFlow avec une extension .ckpt (l'enregistrement en HDF5 avec une extension .h5 est traité dans le guide Enregistrer et sérialiser les modèles ):

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss,acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.5075 - sparse_categorical_accuracy: 0.8730
Restored model, accuracy: 87.30%

Enregistrez le modèle entier

Appelez model.save pour enregistrer l'architecture, les poids et la configuration d'entraînement d'un modèle dans un seul fichier / dossier. Cela vous permet d'exporter un modèle afin qu'il puisse être utilisé sans accéder au code Python d'origine *. Étant donné que l'état de l'optimiseur est récupéré, vous pouvez reprendre l'entraînement exactement là où vous l'avez laissé.

Le modèle entier peut être enregistré dans deux formats de fichier différents ( SavedModel et HDF5 ). Il est à noter que le format TensorFlow SavedModel est le format de fichier par défaut dans TF2.x. Cependant, le modèle peut être enregistré au format HDF5 . Plus de détails sur l'enregistrement du modèle entier dans les deux formats de fichier sont décrits ci-dessous.

L'enregistrement d'un modèle entièrement fonctionnel est très utile: vous pouvez les charger dans TensorFlow.js ( modèle enregistré , HDF5 ), puis les entraîner et les exécuter dans des navigateurs Web, ou les convertir pour qu'ils s'exécutent sur des appareils mobiles à l'aide de TensorFlow Lite ( modèle enregistré , HDF5) )

* Les objets personnalisés (par exemple les modèles ou calques sous-classés) nécessitent une attention particulière lors de la sauvegarde et du chargement. Voir la section Enregistrer des objets personnalisés ci-dessous

Format du modèle enregistré

Le format SavedModel est une autre façon de sérialiser des modèles. Les modèles enregistrés dans ce format peuvent être restaurés à l'aide de tf.keras.models.load_model et sont compatibles avec TensorFlow Serving. Le guide SavedModel explique en détail comment servir / inspecter le SavedModel. La section ci-dessous illustre les étapes d'enregistrement et de restauration du modèle.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model') 
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1973 - sparse_categorical_accuracy: 0.6330
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4247 - sparse_categorical_accuracy: 0.8740
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2784 - sparse_categorical_accuracy: 0.9380
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2029 - sparse_categorical_accuracy: 0.9490
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1441 - sparse_categorical_accuracy: 0.9730
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/my_model/assets

Le format SavedModel est un répertoire contenant un binaire protobuf et un point de contrôle Tensorflow. Inspectez le répertoire du modèle enregistré:

# my_model directory
!ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
!ls saved_model/my_model
my_model
assets  saved_model.pb  variables

Rechargez un nouveau modèle Keras à partir du modèle enregistré:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_10 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Le modèle restauré est compilé avec les mêmes arguments que le modèle d'origine. Essayez d'exécuter évaluer et prédire avec le modèle chargé:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images,  test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4378 - sparse_categorical_accuracy: 0.8570
Restored model, accuracy: 85.70%
(1000, 10)

Format HDF5

Keras fournit un format d'enregistrement de base utilisant la norme HDF5 .

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5') 
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1785 - sparse_categorical_accuracy: 0.6550
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4266 - sparse_categorical_accuracy: 0.8740
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2987 - sparse_categorical_accuracy: 0.9170
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2228 - sparse_categorical_accuracy: 0.9400
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1612 - sparse_categorical_accuracy: 0.9660

Maintenant, recréez le modèle à partir de ce fichier:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_12 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_6 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Vérifiez sa précision:

loss, acc = new_model.evaluate(test_images,  test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))
32/32 - 0s - loss: 0.4458 - sparse_categorical_accuracy: 0.8510
Restored model, accuracy: 85.10%

Keras enregistre les modèles en inspectant l'architecture. Cette technique sauve tout:

  • Les valeurs de poids
  • L'architecture du modèle
  • La configuration de la formation du modèle (ce que vous avez passé à la compilation)
  • L'optimiseur et son état, le cas échéant (cela vous permet de redémarrer l'entraînement là où vous l'avez laissé)

Keras ne peut pas enregistrer les optimiseurs v1.x (à partir de tf.compat.v1.train ) car ils ne sont pas compatibles avec les points de contrôle. Pour les optimiseurs v1.x, vous devez recompiler le modèle après le chargement, en perdant l'état de l'optimiseur.

Enregistrer des objets personnalisés

Si vous utilisez le format SavedModel, vous pouvez ignorer cette section. La principale différence entre HDF5 et SavedModel est que HDF5 utilise des configurations d'objet pour enregistrer l'architecture du modèle, tandis que SavedModel enregistre le graphique d'exécution. Ainsi, SavedModels peut enregistrer des objets personnalisés tels que des modèles sous-classés et des couches personnalisées sans avoir besoin du code original.

Pour enregistrer des objets personnalisés dans HDF5, vous devez effectuer les opérations suivantes:

  1. Définissez une méthode get_config dans votre objet, et éventuellement une from_config classe from_config.
    • get_config(self) renvoie un dictionnaire sérialisable JSON des paramètres nécessaires pour recréer l'objet.
    • from_config(cls, config) utilise la configuration renvoyée par get_config pour créer un nouvel objet. Par défaut, cette fonction utilisera la configuration comme kwargs d'initialisation ( return cls(**config) ).
  2. Passez l'objet à l'argument custom_objects lors du chargement du modèle. L'argument doit être un dictionnaire mappant le nom de la classe de chaîne à la classe Python. Par exemple, tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

Consultez le didacticiel sur l' écriture de calques et de modèles à partir de zéro pour des exemples d'objets personnalisés et de get_config .


#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.