Points de contrôle de la formation

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

L'expression "Enregistrer un modèle TensorFlow" signifie généralement l'une des deux choses suivantes :

  1. Points de contrôle, OU
  2. Modèle enregistré.

Les postes de contrôle saisissent la valeur exacte de tous les paramètres ( tf.Variable objets) utilisés par un modèle. Les points de contrôle ne contiennent aucune description du calcul défini par le modèle et ne sont donc généralement utiles que lorsque le code source qui utilisera les valeurs de paramètres enregistrées est disponible.

Le format SavedModel, quant à lui, comprend une description sérialisée du calcul défini par le modèle en plus des valeurs des paramètres (checkpoint). Les modèles dans ce format sont indépendants du code source qui a créé le modèle. Ils sont donc adaptés au déploiement via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, ou des programmes dans d'autres langages de programmation (les API C, C++, Java, Go, Rust, C# etc. TensorFlow).

Ce guide couvre les API pour l'écriture et la lecture des points de contrôle.

Installer

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Enregistrement de tf.keras API de formation

Voir les tf.keras guide sur la sauvegarde et la restauration.

tf.keras.Model.save_weights enregistre un point de contrôle de tensorflow.

net.save_weights('easy_checkpoint')

Écriture de points de contrôle

L'état persistant d'un modèle de tensorflow est stocké dans tf.Variable objets. Ceux - ci peuvent être construits directement, mais sont souvent créés par des API de haut niveau comme tf.keras.layers ou tf.keras.Model .

Le moyen le plus simple de gérer les variables est de les attacher à des objets Python, puis de référencer ces objets.

Sous - classes de tf.train.Checkpoint , tf.keras.layers.Layer et tf.keras.Model suivre automatiquement les variables affectées à leurs attributs. L'exemple suivant construit un modèle linéaire simple, puis écrit des points de contrôle qui contiennent des valeurs pour toutes les variables du modèle.

Vous pouvez facilement enregistrer un modèle de point de contrôle avec Model.save_weights .

Point de contrôle manuel

Installer

Pour aider à démontrer toutes les fonctionnalités de tf.train.Checkpoint , définir un ensemble de données de jouet et l' étape d'optimisation:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Créer les objets de point de contrôle

Utilisez un tf.train.Checkpoint objet pour créer manuellement un point de contrôle, où les objets à point de contrôle sont définis comme des attributs de l'objet.

Un tf.train.CheckpointManager peut également être utile pour gérer plusieurs points de contrôle.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Former et contrôler le modèle

La boucle de formation suivante crée une instance du modèle et d'un optimiseur, puis les rassemble dans un tf.train.Checkpoint objet. Il appelle l'étape d'apprentissage en boucle sur chaque lot de données et écrit périodiquement des points de contrôle sur le disque.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.77
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.18
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.62
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.16
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.09

Restaurer et continuer la formation

Après le premier cycle de formation, vous pouvez passer un nouveau modèle et manager, mais reprendre la formation exactement là où vous vous étiez arrêté :

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.33
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.90
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.62
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.27
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

Le tf.train.CheckpointManager objet supprime les anciens points de contrôle. Ci-dessus, il est configuré pour ne conserver que les trois points de contrôle les plus récents.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Ces chemins, par exemple './tf_ckpts/ckpt-10' , sont des fichiers pas sur le disque. Au contraire , ils sont des préfixes pour un index fichier et un ou plusieurs fichiers de données qui contiennent les valeurs des variables. Ces préfixes sont regroupés dans un seul checkpoint de './tf_ckpts/checkpoint' CheckpointManager checkpoint fichier ( './tf_ckpts/checkpoint' ) où le CheckpointManager enregistre son état.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mécanique de chargement

TensorFlow fait correspondre les variables aux valeurs de point de contrôle en parcourant un graphe orienté avec des arêtes nommées, à partir de l'objet en cours de chargement. Les noms Edge sont généralement des noms d'attributs dans les objets, par exemple le "l1" dans self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilise ses noms d'argument mot - clé, comme dans la "step" dans tf.train.Checkpoint(step=...) .

Le graphique de dépendance de l'exemple ci-dessus ressemble à ceci :

Visualisation du graphe de dépendance pour l'exemple de boucle d'entraînement

L'optimiseur est en rouge, les variables normales sont en bleu et les variables d'emplacement de l'optimiseur sont en orange. Les autres noeuds, par exemple, représentant le tf.train.Checkpoint -Y en noir.

Les variables d'emplacement font partie de l'état de l'optimiseur, mais sont créées pour une variable spécifique. Par exemple , les 'm' bords ci - dessus correspondent à impulsion, que les pistes d'optimisation Adam pour chaque variable. Les variables d'emplacement ne sont enregistrées dans un point de contrôle que si la variable et l'optimiseur seraient tous les deux enregistrés, donc les bords en pointillés.

Appel de restore sur une tf.train.Checkpoint files d'attente objet les restaurations demandées, la restauration des valeurs de variables dès qu'il ya un chemin correspondant du Checkpoint objet. Par exemple, vous pouvez charger uniquement le biais du modèle que vous avez défini ci-dessus en reconstruisant un chemin vers celui-ci à travers le réseau et la couche.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9851578 3.6375327 2.9331083 3.8130412 4.778274 ]

Le graphique de dépendance de ces nouveaux objets est un sous-graphique beaucoup plus petit du point de contrôle plus grand que vous avez écrit ci-dessus. Il ne comprend que le parti pris et une sauvegarde contre que tf.train.Checkpoint utilise pour les postes de contrôle numériques.

Visualisation d'un sous-graphe pour la variable de biais

restore retourne un objet d'état, qui a des affirmations en option. Tous les objets créés dans le nouveau Checkpoint de status.assert_existing_objects_matched Checkpoint ont été restaurés, si status.assert_existing_objects_matched passe.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f027807e050>

De nombreux objets du point de contrôle ne correspondent pas, y compris le noyau de la couche et les variables de l'optimiseur. status.assert_consumed ne passe si le point de contrôle et le match du programme exactement, et jetterait une exception ici.

Restaurations différées

Layer des objets dans tensorflow peuvent retarder la création de variables à leur premier appel, lorsque des formes d'entrée sont disponibles. Par exemple , la forme d'un Dense noyau de couche dépend à la fois d' entrée de la couche et des formes de sortie, et ainsi la forme de sortie requise comme argument du constructeur est pas assez d' informations pour créer la variable lui - même. Depuis l' appel d' une Layer lit également la valeur de la variable, une restauration doit se produire entre la création de la variable et sa première utilisation.

Pour soutenir cet idiome, tf.train.Checkpoint files d' attente qui n'ont des restaurations pas encore une variable correspondant.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6800494 4.607369  4.8321466 4.816245  4.8435326]]

Inspection manuelle des points de contrôle

tf.train.load_checkpoint retourne un CheckpointReader qui donne accès au niveau inférieur au contenu du point de contrôle. Il contient des mappages de la clé de chaque variable, à la forme et au type de chaque variable dans le point de contrôle. La clé d'une variable est son chemin d'objet, comme dans les graphiques affichés ci-dessus.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Donc , si vous êtes intéressé à la valeur de net.l1.kernel vous pouvez obtenir la valeur avec le code suivant:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Il fournit également une get_tensor méthode permettant de contrôler la valeur d'une variable:

reader.get_tensor(key)
array([[4.6800494, 4.607369 , 4.8321466, 4.816245 , 4.8435326]],
      dtype=float32)

Suivi des listes et des dictionnaires

Comme pour les affectations d'attributs directs comme self.l1 = tf.keras.layers.Dense(5) , l' attribution des listes et des dictionnaires aux attributs permettra de suivre leur contenu.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Vous remarquerez peut-être des objets wrapper pour les listes et les dictionnaires. Ces enveloppes sont des versions point de contrôle des structures de données sous-jacentes. Tout comme le chargement basé sur les attributs, ces wrappers restaurent la valeur d'une variable dès qu'elle est ajoutée au conteneur.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Le même suivi est automatiquement appliqué à des sous - classes de tf.keras.Model , et peut être utilisée par exemple pour suivre des listes de couches.

Sommaire

Les objets TensorFlow fournissent un mécanisme automatique simple pour enregistrer et restaurer les valeurs des variables qu'ils utilisent.