Cette page a été traduite par l'API Cloud Translation.
Switch to English

Points de contrôle de la formation

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le cahier

L'expression "Enregistrement d'un modèle TensorFlow" signifie généralement l'une des deux choses suivantes:

  1. Points de contrôle, OU
  2. SavedModel.

Les points de contrôle capturent la valeur exacte de tous les paramètres (objets tf.Variable ) utilisés par un modèle. Les points de contrôle ne contiennent aucune description du calcul défini par le modèle et ne sont donc généralement utiles que lorsque le code source qui utilisera les valeurs de paramètres enregistrées est disponible.

Le format SavedModel, quant à lui, comprend une description sérialisée du calcul défini par le modèle en plus des valeurs des paramètres (point de contrôle). Les modèles dans ce format sont indépendants du code source qui a créé le modèle. Ils sont ainsi adaptés au déploiement via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, ou des programmes dans d'autres langages de programmation (les API TensorFlow C, C ++, Java, Go, Rust, C # etc.).

Ce guide couvre les API pour l'écriture et la lecture des points de contrôle.

Installer

 import tensorflow as tf
 
 class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
 
 net = Net()
 

Enregistrement à partir des API de formation tf.keras

Consultez le guide tf.keras sur l'enregistrement et la restauration.

tf.keras.Model.save_weights enregistre un point de contrôle TensorFlow.

 net.save_weights('easy_checkpoint')
 

Écriture des points de contrôle

L'état persistant d'un modèle TensorFlow est stocké dans les objets tf.Variable . Ceux-ci peuvent être construits directement, mais sont souvent créés via des API de haut niveau comme tf.keras.layers ou tf.keras.Model .

Le moyen le plus simple de gérer les variables consiste à les attacher à des objets Python, puis à référencer ces objets.

Les sous-classes de tf.train.Checkpoint , tf.keras.layers.Layer et tf.keras.Model automatiquement les variables affectées à leurs attributs. L'exemple suivant construit un modèle linéaire simple, puis écrit des points de contrôle qui contiennent des valeurs pour toutes les variables du modèle.

Vous pouvez facilement enregistrer un point de contrôle de modèle avec Model.save_weights

Point de contrôle manuel

Installer

Pour aider à démontrer toutes les fonctionnalités de tf.train.Checkpoint définissez un jeu de données de jouet et une étape d'optimisation:

 def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
 
 def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss
 

Créer les objets de point de contrôle

Pour créer manuellement un point de contrôle, vous aurez besoin d'un objet tf.train.Checkpoint . Où les objets que vous souhaitez contrôler sont définis en tant qu'attributs sur l'objet.

Un tf.train.CheckpointManager peut également être utile pour gérer plusieurs points de contrôle.

 opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
 

Former et contrôler le modèle

La boucle d'apprentissage suivante crée une instance du modèle et d'un optimiseur, puis les rassemble dans un objet tf.train.Checkpoint . Il appelle l'étape d'entraînement en boucle sur chaque lot de données et écrit périodiquement des points de contrôle sur le disque.

 def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
 
 train_and_checkpoint(net, manager)
 
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.06
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.47
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 14.93
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.50
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.70

Restaurer et continuer la formation

Après le premier, vous pouvez passer un nouveau modèle et un nouveau gestionnaire, mais reprendre la formation exactement là où vous vous êtes arrêté:

 opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
 
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.05
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.85
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.49
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.33
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.15

L'objet tf.train.CheckpointManager supprime les anciens points de contrôle. Au-dessus, il est configuré pour ne conserver que les trois points de contrôle les plus récents.

 print(manager.checkpoints)  # List the three remaining checkpoints
 
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Ces chemins, par exemple './tf_ckpts/ckpt-10' , ne sont pas des fichiers sur le disque. Au lieu de cela, ce sont des préfixes pour un fichier d' index et un ou plusieurs fichiers de données contenant les valeurs de variable. Ces préfixes sont regroupés dans un seul fichier de checkpoint ( './tf_ckpts/checkpoint' ) dans lequel CheckpointManager enregistre son état.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00001-of-00002
ckpt-10.data-00000-of-00002  ckpt-8.index
ckpt-10.data-00001-of-00002  ckpt-9.data-00000-of-00002
ckpt-10.index            ckpt-9.data-00001-of-00002
ckpt-8.data-00000-of-00002   ckpt-9.index

Mécanique de chargement

TensorFlow fait correspondre les variables aux valeurs de points de contrôle en parcourant un graphe orienté avec des arêtes nommées, en commençant par l'objet en cours de chargement. Les noms d'arêtes proviennent généralement des noms d'attributs dans les objets, par exemple le "l1" dans self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilise ses noms d'argument de mot-clé, comme dans "step" dans tf.train.Checkpoint(step=...) .

Le graphe de dépendances de l'exemple ci-dessus ressemble à ceci:

Visualisation du graphe de dépendances pour l'exemple de boucle d'apprentissage

Avec l'optimiseur en rouge, les variables régulières en bleu et les variables d'emplacement de l'optimiseur en orange. Les autres nœuds, par exemple représentant le tf.train.Checkpoint , sont noirs.

Les variables d'emplacement font partie de l'état de l'optimiseur, mais sont créées pour une variable spécifique. Par exemple, les arêtes 'm' ci-dessus correspondent à l'élan, que l'optimiseur Adam suit pour chaque variable. Les variables d'emplacement ne sont enregistrées dans un point de contrôle que si la variable et l'optimiseur sont tous deux enregistrés, donc les bords en pointillés.

L'appel de restore() sur un objet tf.train.Checkpoint file d'attente les restaurations demandées, en restaurant les valeurs des variables dès qu'il existe un chemin correspondant à partir de l'objet Checkpoint . Par exemple, nous pouvons charger uniquement le biais du modèle que nous avons défini ci-dessus en reconstruisant un chemin vers celui-ci à travers le réseau et la couche.

 to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
 
[0. 0. 0. 0. 0.]
[1.5749372 3.2779367 2.3600516 4.620399  3.9607847]

Le graphe de dépendance pour ces nouveaux objets est un sous-graphe beaucoup plus petit du plus grand point de contrôle que nous avons écrit ci-dessus. Il comprend uniquement le biais et un compteur de sauvegarde que tf.train.Checkpoint utilise pour numéroter les points de contrôle.

Visualisation d'un sous-graphe pour la variable de biais

restore() renvoie un objet status, qui a des assertions facultatives. Tous les objets que nous avons créés dans notre nouveau Checkpoint ont été restaurés, donc status.assert_existing_objects_matched() passe.

 status.assert_existing_objects_matched()
 
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fe44ad109b0>

Il y a de nombreux objets dans le point de contrôle qui ne correspondent pas, y compris le noyau de la couche et les variables de l'optimiseur. status.assert_consumed() ne passe que si le point de contrôle et le programme correspondent exactement, et lèverait une exception ici.

Restaurations retardées

Layer objets de Layer dans TensorFlow peuvent retarder la création des variables jusqu'à leur premier appel, lorsque des formes d'entrée sont disponibles. Par exemple, la forme du noyau d'un calque Dense dépend à la fois des formes d'entrée et de sortie du calque, et donc la forme de sortie requise en tant qu'argument de constructeur n'est pas assez d'informations pour créer la variable seule. Puisque l'appel d'un Layer lit également la valeur de la variable, une restauration doit avoir lieu entre la création de la variable et sa première utilisation.

Pour prendre en charge cet idiome, tf.train.Checkpoint restaure les files d'attente qui n'ont pas encore de variable correspondante.

 delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
 
[[0. 0. 0. 0. 0.]]
[[4.733721  4.6473064 4.9289813 4.7712607 4.958914 ]]

Inspection manuelle des points de contrôle

tf.train.list_variables répertorie les clés de point de contrôle et les formes des variables dans un point de contrôle. Les clés de point de contrôle sont des chemins dans le graphique affiché ci-dessus.

 tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
 
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Suivi des listes et des dictionnaires

Comme pour les attributions directes d'attributs comme self.l1 = tf.keras.layers.Dense(5) , l'affectation de listes et de dictionnaires à des attributs suivra leur contenu.

 save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
 

Vous remarquerez peut-être des objets wrapper pour les listes et les dictionnaires. Ces wrappers sont des versions vérifiables des structures de données sous-jacentes. Tout comme le chargement basé sur les attributs, ces wrappers restaurent la valeur d'une variable dès qu'elle est ajoutée au conteneur.

 restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
 
ListWrapper([])

Le même suivi est automatiquement appliqué aux sous-classes de tf.keras.Model , et peut être utilisé par exemple pour suivre des listes de couches.

Enregistrement de points de contrôle basés sur des objets avec Estimator

Consultez le guide Estimator .

Les estimateurs enregistrent par défaut les points de contrôle avec des noms de variables plutôt que le graphe d'objets décrit dans les sections précédentes. tf.train.Checkpoint acceptera les points de contrôle basés sur les noms, mais les noms de variables peuvent changer lors du déplacement de parties d'un modèle en dehors de model_fn de l'estimateur. L'enregistrement de points de contrôle basés sur des objets facilite l'apprentissage d'un modèle dans un Estimator, puis son utilisation en dehors de celui-ci.

 import tensorflow.compat.v1 as tf_compat
 
 def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
 
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1666: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.5144663, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.124985.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fe44a8d10b8>

tf.train.Checkpoint peut alors charger les points de contrôle de l'estimateur à partir de son model_dir .

 opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
 
10

Résumé

Les objets TensorFlow fournissent un mécanisme automatique simple pour enregistrer et restaurer les valeurs des variables qu'ils utilisent.