Points de contrôle de la formation

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

L'expression "Enregistrer un modèle TensorFlow" signifie généralement l'une des deux choses suivantes :

  1. Points de contrôle, OU
  2. Modèle enregistré.

Les points de contrôle capturent la valeur exacte de tous les paramètres (objets tf.Variable ) utilisés par un modèle. Les points de contrôle ne contiennent aucune description du calcul défini par le modèle et ne sont donc généralement utiles que lorsque le code source qui utilisera les valeurs de paramètre enregistrées est disponible.

Le format SavedModel, quant à lui, inclut une description sérialisée du calcul défini par le modèle en plus des valeurs des paramètres (point de contrôle). Les modèles dans ce format sont indépendants du code source qui a créé le modèle. Ils sont donc adaptés au déploiement via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, ou des programmes dans d'autres langages de programmation (les API C, C++, Java, Go, Rust, C# etc. TensorFlow).

Ce guide couvre les API pour l'écriture et la lecture des points de contrôle.

Installer

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Enregistrement à partir des API de formation tf.keras

Voir le guide tf.keras sur l'enregistrement et la restauration des fichiers .

tf.keras.Model.save_weights enregistre un point de contrôle TensorFlow.

net.save_weights('easy_checkpoint')

Rédaction de points de contrôle

L'état persistant d'un modèle TensorFlow est stocké dans des objets tf.Variable . Ceux-ci peuvent être construits directement, mais sont souvent créés via des API de haut niveau comme tf.keras.layers ou tf.keras.Model .

Le moyen le plus simple de gérer des variables consiste à les attacher à des objets Python, puis à référencer ces objets.

Les sous-classes de tf.train.Checkpoint , tf.keras.layers.Layer et tf.keras.Model automatiquement les variables affectées à leurs attributs. L'exemple suivant construit un modèle linéaire simple, puis écrit des points de contrôle qui contiennent des valeurs pour toutes les variables du modèle.

Vous pouvez facilement enregistrer un point de contrôle de modèle avec Model.save_weights .

Point de contrôle manuel

Installer

Pour vous aider à démontrer toutes les fonctionnalités de tf.train.Checkpoint , définissez un jeu de données jouet et une étape d'optimisation :

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Créer les objets de point de contrôle

Utilisez un objet tf.train.Checkpoint pour créer manuellement un point de contrôle, où les objets que vous souhaitez contrôler sont définis en tant qu'attributs sur l'objet.

Un tf.train.CheckpointManager peut également être utile pour gérer plusieurs points de contrôle.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Former et contrôler le modèle

La boucle de formation suivante crée une instance du modèle et d'un optimiseur, puis les rassemble dans un objet tf.train.Checkpoint . Il appelle l'étape d'apprentissage en boucle sur chaque lot de données et écrit périodiquement des points de contrôle sur le disque.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Restaurer et continuer l'entraînement

Après le premier cycle de formation, vous pouvez réussir un nouveau modèle et manager, mais reprendre la formation exactement là où vous l'avez laissée :

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

L'objet tf.train.CheckpointManager supprime les anciens points de contrôle. Ci-dessus, il est configuré pour ne conserver que les trois points de contrôle les plus récents.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Ces chemins, par exemple './tf_ckpts/ckpt-10' , ne sont pas des fichiers sur le disque. Au lieu de cela, ce sont des préfixes pour un fichier d' index et un ou plusieurs fichiers de données qui contiennent les valeurs des variables. Ces préfixes sont regroupés dans un seul fichier de point de checkpoint ( './tf_ckpts/checkpoint' ) où le CheckpointManager enregistre son état.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mécanique de chargement

TensorFlow fait correspondre les variables aux valeurs à points de contrôle en parcourant un graphe orienté avec des arêtes nommées, à partir de l'objet en cours de chargement. Les noms d'arêtes proviennent généralement des noms d'attributs dans les objets, par exemple le "l1" dans self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilise ses noms d'arguments de mots-clés, comme dans "step" dans tf.train.Checkpoint(step=...) .

Le graphique de dépendance de l'exemple ci-dessus ressemble à ceci :

Visualisation du graphique de dépendance pour l'exemple de boucle d'entraînement

L'optimiseur est en rouge, les variables régulières sont en bleu et les variables d'emplacement de l'optimiseur sont en orange. Les autres nœuds, par exemple, représentant le tf.train.Checkpoint , sont en noir.

Les variables d'emplacement font partie de l'état de l'optimiseur, mais sont créées pour une variable spécifique. Par exemple, les arêtes 'm' ci-dessus correspondent à la quantité de mouvement, que l'optimiseur Adam suit pour chaque variable. Les variables d'emplacement ne sont enregistrées dans un point de contrôle que si la variable et l'optimiseur seraient tous deux enregistrés, donc les bords en pointillés.

L'appel de restore sur un objet tf.train.Checkpoint en file d'attente les restaurations demandées, restaurant les valeurs des variables dès qu'il existe un chemin correspondant depuis l'objet Checkpoint . Par exemple, vous pouvez charger uniquement le biais du modèle que vous avez défini ci-dessus en reconstruisant un chemin vers celui-ci à travers le réseau et la couche.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

Le graphique de dépendance pour ces nouveaux objets est un sous-graphique beaucoup plus petit du point de contrôle plus grand que vous avez écrit ci-dessus. Il inclut uniquement le biais et un compteur de sauvegarde que tf.train.Checkpoint utilise pour numéroter les points de contrôle.

Visualisation d'un sous-graphe pour la variable de biais

restore renvoie un objet d'état, qui a des assertions facultatives. Tous les objets créés dans le nouveau point de Checkpoint ont été restaurés, donc status.assert_existing_objects_matched passe.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

De nombreux objets du point de contrôle ne correspondent pas, notamment le noyau de la couche et les variables de l'optimiseur. status.assert_consumed ne passe que si le point de contrôle et le programme correspondent exactement, et lèverait une exception ici.

Restaurations différées

Les objets de Layer dans TensorFlow peuvent reporter la création de variables à leur premier appel, lorsque des formes d'entrée sont disponibles. Par exemple, la forme du noyau d'une couche Dense dépend à la fois des formes d'entrée et de sortie de la couche, et donc la forme de sortie requise en tant qu'argument de constructeur n'est pas suffisante pour créer la variable seule. Étant donné que l'appel d'un Layer lit également la valeur de la variable, une restauration doit avoir lieu entre la création de la variable et sa première utilisation.

Pour prendre en charge cet idiome, tf.train.Checkpoint diffère les restaurations qui n'ont pas encore de variable correspondante.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Inspection manuelle des points de contrôle

tf.train.load_checkpoint renvoie un CheckpointReader qui donne un accès de niveau inférieur au contenu du point de contrôle. Il contient des mappages de la clé de chaque variable à la forme et au dtype de chaque variable du point de contrôle. La clé d'une variable est son chemin d'accès à l'objet, comme dans les graphiques affichés ci-dessus.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Donc, si vous êtes intéressé par la valeur de net.l1.kernel , vous pouvez obtenir la valeur avec le code suivant :

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Il fournit également une méthode get_tensor permettant d'inspecter la valeur d'une variable :

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

Suivi d'objet

Les points de contrôle enregistrent et restaurent les valeurs des objets tf.Variable en "suivant" toute variable ou objet traçable défini dans l'un de ses attributs. Lors de l'exécution d'une sauvegarde, les variables sont collectées de manière récursive à partir de tous les objets suivis accessibles.

Comme pour les attributions directes d'attributs comme self.l1 = tf.keras.layers.Dense(5) , l'attribution de listes et de dictionnaires aux attributs suivra leur contenu.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Vous remarquerez peut-être des objets wrapper pour les listes et les dictionnaires. Ces wrappers sont des versions checkpointables des structures de données sous-jacentes. Tout comme le chargement basé sur les attributs, ces wrappers restaurent la valeur d'une variable dès qu'elle est ajoutée au conteneur.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Les objets traçables incluent tf.train.Checkpoint , tf.Module et ses sous-classes (par exemple keras.layers.Layer et keras.Model ) et les conteneurs Python reconnus :

  • dict (et collections.OrderedDict )
  • list
  • tuple (et collections.namedtuple , en typing.NamedTuple )

Les autres types de conteneurs ne sont pas compatibles , notamment :

  • collections.defaultdict
  • set

Tous les autres objets Python sont ignorés , y compris :

  • int
  • string
  • float

Résumé

Les objets TensorFlow fournissent un mécanisme automatique simple pour enregistrer et restaurer les valeurs des variables qu'ils utilisent.