Points de contrôle de la formation

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le cahier

L'expression "Enregistrement d'un modèle TensorFlow" signifie généralement l'une des deux choses suivantes:

  1. Points de contrôle, OU
  2. SavedModel.

Les points de contrôle capturent la valeur exacte de tous les paramètres (objets tf.Variable ) utilisés par un modèle. Les points de contrôle ne contiennent aucune description du calcul défini par le modèle et ne sont donc généralement utiles que lorsque le code source qui utilisera les valeurs de paramètres enregistrées est disponible.

Le format SavedModel, quant à lui, comprend une description sérialisée du calcul défini par le modèle en plus des valeurs des paramètres (point de contrôle). Les modèles dans ce format sont indépendants du code source qui a créé le modèle. Ils sont donc adaptés au déploiement via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, ou des programmes dans d'autres langages de programmation (les API TensorFlow C, C ++, Java, Go, Rust, C # etc.).

Ce guide couvre les API pour l'écriture et la lecture des points de contrôle.

Installer

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Enregistrement à partir des API de formation tf.keras

Consultez le guide tf.keras sur l'enregistrement et la restauration.

tf.keras.Model.save_weights enregistre un point de contrôle TensorFlow.

net.save_weights('easy_checkpoint')

Écriture des points de contrôle

L'état persistant d'un modèle TensorFlow est stocké dans les objets tf.Variable . Ceux-ci peuvent être construits directement, mais sont souvent créés via des API de haut niveau commetf.keras.layers ou tf.keras.Model .

Le moyen le plus simple de gérer les variables consiste à les attacher à des objets Python, puis à référencer ces objets.

Les sous-classes de tf.train.Checkpoint , tf.keras.layers.Layer et tf.keras.Model automatiquement les variables affectées à leurs attributs. L'exemple suivant construit un modèle linéaire simple, puis écrit des points de contrôle qui contiennent des valeurs pour toutes les variables du modèle.

Vous pouvez facilement enregistrer un modèle-point de contrôle avec Model.save_weights .

Point de contrôle manuel

Installer

Pour aider à démontrer toutes les fonctionnalités de tf.train.Checkpoint , définissez un jeu de données jouet et une étape d'optimisation:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Créer les objets de point de contrôle

Utilisez un objet tf.train.Checkpoint pour créer manuellement un point de contrôle, où les objets que vous souhaitez contrôler sont définis comme attributs sur l'objet.

Un tf.train.CheckpointManager peut également être utile pour gérer plusieurs points de contrôle.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Former et contrôler le modèle

La boucle d'apprentissage suivante crée une instance du modèle et d'un optimiseur, puis les rassemble dans un objet tf.train.Checkpoint . Il appelle l'étape d'entraînement en boucle sur chaque lot de données et écrit périodiquement des points de contrôle sur le disque.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Restaurer et continuer la formation

Après le premier cycle de formation, vous pouvez passer un nouveau modèle et un nouveau manager, mais reprendre la formation exactement là où vous vous étiez arrêté:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

L'objet tf.train.CheckpointManager supprime les anciens points de contrôle. Au-dessus, il est configuré pour ne conserver que les trois points de contrôle les plus récents.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Ces chemins, par exemple './tf_ckpts/ckpt-10' , ne sont pas des fichiers sur le disque. Au lieu de cela, ce sont des préfixes pour un fichier d' index et un ou plusieurs fichiers de données qui contiennent les valeurs de variable. Ces préfixes sont regroupés dans un seul fichier de checkpoint ( './tf_ckpts/checkpoint' ) où le CheckpointManager enregistre son état.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mécanique de chargement

TensorFlow fait correspondre les variables aux valeurs de points de contrôle en parcourant un graphe orienté avec des arêtes nommées, en commençant par l'objet en cours de chargement. Les noms d'arêtes proviennent généralement des noms d'attributs dans les objets, par exemple "l1" dans self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilise ses noms d'argument de mot-clé, comme dans "step" dans tf.train.Checkpoint(step=...) .

Le graphe de dépendances de l'exemple ci-dessus ressemble à ceci:

Visualisation du graphe de dépendances pour l'exemple de boucle d'apprentissage

L'optimiseur est en rouge, les variables régulières sont en bleu et les variables d'emplacement de l'optimiseur sont en orange. Les autres nœuds - par exemple, représentant le tf.train.Checkpoint - sont en noir.

Les variables d'emplacement font partie de l'état de l'optimiseur, mais sont créées pour une variable spécifique. Par exemple, les arêtes 'm' ci-dessus correspondent à l'élan, que l'optimiseur Adam suit pour chaque variable. Les variables d'emplacement ne sont enregistrées dans un point de contrôle que si la variable et l'optimiseur sont tous deux enregistrés, donc les bords en pointillés.

L'appel de restore sur un objet tf.train.Checkpoint file d'attente les restaurations demandées, en restaurant les valeurs des variables dès qu'il existe un chemin correspondant à partir de l'objet Checkpoint . Par exemple, vous pouvez charger uniquement le biais du modèle que vous avez défini ci-dessus en reconstruisant un chemin vers celui-ci via le réseau et la couche.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

Le graphe de dépendance pour ces nouveaux objets est un sous-graphe beaucoup plus petit du plus grand point de contrôle que vous avez écrit ci-dessus. Il comprend uniquement le biais et un compteur de sauvegarde que tf.train.Checkpoint utilise pour numéroter les points de contrôle.

Visualisation d'un sous-graphe pour la variable de biais

restore renvoie un objet de statut, qui a des assertions facultatives. Tous les objets créés dans le nouveau Checkpoint ont été restaurés, donc status.assert_existing_objects_matched passe.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

Il y a de nombreux objets dans le point de contrôle qui ne correspondent pas, y compris le noyau de la couche et les variables de l'optimiseur. status.assert_consumed ne réussit que si le point de contrôle et le programme correspondent exactement, et lèverait une exception ici.

Restaurations retardées

Layer objets de Layer dans TensorFlow peuvent retarder la création des variables jusqu'à leur premier appel, lorsque des formes d'entrée sont disponibles. Par exemple, la forme du noyau d'un calque Dense dépend à la fois des formes d'entrée et de sortie du calque, et donc la forme de sortie requise en tant qu'argument de constructeur n'est pas suffisamment d'informations pour créer la variable seule. Puisque l'appel d'un Layer lit également la valeur de la variable, une restauration doit avoir lieu entre la création de la variable et sa première utilisation.

Pour prendre en charge cet idiome, tf.train.Checkpoint rétablit les files d'attente qui n'ont pas encore de variable correspondante.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Inspection manuelle des points de contrôle

tf.train.load_checkpoint renvoie un CheckpointReader qui donne un accès de niveau inférieur au contenu du point de contrôle. Il contient les mappages de la clé de chaque variable à la forme et au type de chaque variable du point de contrôle. La clé d'une variable est son chemin d'objet, comme dans les graphiques affichés ci-dessus.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Donc, si vous êtes intéressé par la valeur de net.l1.kernel vous pouvez obtenir la valeur avec le code suivant:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Il fournit également une méthode get_tensor vous permettant d'inspecter la valeur d'une variable:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Suivi des listes et des dictionnaires

Comme pour les attributions directes d'attributs comme self.l1 = tf.keras.layers.Dense(5) , l'affectation de listes et de dictionnaires à des attributs suivra leur contenu.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Vous remarquerez peut-être des objets wrapper pour les listes et les dictionnaires. Ces wrappers sont des versions vérifiables des structures de données sous-jacentes. Tout comme le chargement basé sur les attributs, ces wrappers restaurent la valeur d'une variable dès qu'elle est ajoutée au conteneur.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Le même suivi est automatiquement appliqué aux sous-classes de tf.keras.Model , et peut être utilisé par exemple pour suivre des listes de couches.

Résumé

Les objets TensorFlow fournissent un mécanisme automatique simple pour enregistrer et restaurer les valeurs des variables qu'ils utilisent.