Trainingskontrollpunkte

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Der Ausdruck "Speichern eines TensorFlow-Modells" bedeutet normalerweise eines von zwei Dingen:

  1. Kontrollpunkte, ODER
  2. Gespeichertes Modell.

Checkpoints erfassen den exakten Wert aller Parameter ( tf.Variable Objekte) durch ein Modell verwendet. Prüfpunkte enthalten keine Beschreibung der vom Modell definierten Berechnung und sind daher normalerweise nur nützlich, wenn Quellcode verfügbar ist, der die gespeicherten Parameterwerte verwendet.

Das SavedModel-Format hingegen enthält neben den Parameterwerten (Checkpoint) eine serialisierte Beschreibung der vom Modell definierten Berechnung. Modelle in diesem Format sind unabhängig von dem Quellcode, der das Modell erstellt hat. Sie eignen sich daher für die Bereitstellung über TensorFlow Serving, TensorFlow Lite, TensorFlow.js oder Programme in anderen Programmiersprachen (die C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).

In diesem Handbuch werden APIs zum Schreiben und Lesen von Prüfpunkten behandelt.

Aufstellen

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Speichern von tf.keras Training APIs

Siehe die tf.keras Führung auf Speichern und Wiederherstellen.

tf.keras.Model.save_weights spart TensorFlow Checkpoint.

net.save_weights('easy_checkpoint')

Kontrollpunkte schreiben

Der anhaltende Zustand eines TensorFlow Modell wird in gespeichert tf.Variable Objekten. Diese können direkt aufgebaut werden, werden jedoch oft durch High-Level - APIs wie geschaffen tf.keras.layers oder tf.keras.Model .

Der einfachste Weg, Variablen zu verwalten, besteht darin, sie an Python-Objekte anzuhängen und dann auf diese Objekte zu verweisen.

Subklassen von tf.train.Checkpoint , tf.keras.layers.Layer und tf.keras.Model automatisch zu verfolgen , um ihre zugeordneten Attribute Variablen. Im folgenden Beispiel wird ein einfaches lineares Modell erstellt und dann Prüfpunkte geschrieben, die Werte für alle Variablen des Modells enthalten.

Sie können ganz einfach ein Modell-Kontrollpunkt mit speichern Model.save_weights .

Manuelles Checkpointing

Aufstellen

Um Hilfe zeigen alle die Eigenschaften tf.train.Checkpoint , definiert einen Spielzeug - Datensatz und Optimierungsschritt:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Erstellen Sie die Prüfpunktobjekte

Verwenden Sie ein tf.train.Checkpoint Objekt , um manuell einen Kontrollpunkt zu erstellen, in dem die Objekte , die Sie Kontrollpunkt wollen als Attribute für das Objekt festgelegt.

Ein tf.train.CheckpointManager kann auch hilfreich sein , mehrere Checkpoints für die Verwaltung.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Trainieren und überprüfen Sie das Modell

Die folgende Trainingsschleife erzeugt eine Instanz des Modells und einen Optimierers, sammelt sie dann in ein tf.train.Checkpoint Objekt. Es ruft den Trainingsschritt in einer Schleife für jeden Datenstapel auf und schreibt regelmäßig Prüfpunkte auf die Festplatte.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.67
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.09
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.53
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.10
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.39

Wiederherstellung und Fortführung des Trainings

Nach dem ersten Trainingszyklus können Sie ein neues Model und einen neuen Manager bestehen, aber das Training genau dort fortsetzen, wo Sie aufgehört haben:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.64
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.17
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.69
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.19

Das tf.train.CheckpointManager Objekt löscht altes Checkpoints. Oben ist es so konfiguriert, dass nur die drei neuesten Prüfpunkte beibehalten werden.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Diese Wege, zB './tf_ckpts/ckpt-10' , werden die Dateien nicht auf der Festplatte. Stattdessen sind sie Präfixe für eine index - Datei und eine oder mehr Datendateien , die die Variablenwert enthalten. Diese Präfixe werden zusammengefasst in einem einzigen checkpoint - Datei ( './tf_ckpts/checkpoint' ) , wo der CheckpointManager seinen Zustand speichert.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Lademechanik

TensorFlow gleicht Variablen mit Prüfpunktwerten ab, indem ein gerichteter Graph mit benannten Kanten ab dem geladenen Objekt durchlaufen wird. Edge - Name kommt typischerweise von Attributnamen in Objekten, zum Beispiel des "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint verwendet seine Schlüsselwort Argumentnamen, wie in der "step" in tf.train.Checkpoint(step=...) .

Der Abhängigkeitsgraph aus dem obigen Beispiel sieht so aus:

Visualisierung des Abhängigkeitsgraphen für die Beispiel-Trainingsschleife

Der Optimierer ist rot, reguläre Variablen sind blau und die Slot-Variablen des Optimierers sind orange. Die anderen Knoten, zum Beispiel der vertretende tf.train.Checkpoint -sind in schwarz.

Slot-Variablen sind Teil des Status des Optimierers, werden jedoch für eine bestimmte Variable erstellt. Zum Beispiel kann die 'm' Kanten oben entspricht Dynamik, welche die ADAM - Optimierer Spuren für jede Variable. Slot-Variablen werden nur dann in einem Checkpoint gespeichert, wenn sowohl die Variable als auch der Optimierer gespeichert würden, also die gestrichelten Kanten.

Aufruf restore auf einer tf.train.Checkpoint Objekt Warteschlangen der angeforderten Restaurationen, die Wiederherstellung Variablenwert, sobald ein passender Weg von dem ist Checkpoint - Objekt. Sie können beispielsweise nur den Bias aus dem oben definierten Modell laden, indem Sie einen Pfad zu diesem durch das Netzwerk und den Layer rekonstruieren.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.3119967 2.088805  3.9098527 3.9504364 4.7226586]

Der Abhängigkeitsgraph für diese neuen Objekte ist ein viel kleinerer Untergraph des größeren Checkpoints, den Sie oben geschrieben haben. Es enthält nur die Vorspannung und einen Zähler speichern , dass tf.train.Checkpoint Anzahl Checkpoints verwendet.

Visualisierung eines Untergraphen für die Bias-Variable

restore gibt einen Statusobjekt, das optional Behauptungen hat. Alle Objekte in der neu geschaffenen Checkpoint wiederhergestellt wurden, so status.assert_existing_objects_matched Pässe.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f967c028e10>

Es gibt viele Objekte im Prüfpunkt, die nicht übereinstimmen, einschließlich des Kernels der Schicht und der Variablen des Optimierers. status.assert_consumed geht nur , wenn der Kontrollpunkt und das Programm genau übereinstimmen, und würde hier eine Ausnahme werfen.

Verzögerte Restaurationen

Layer Objekte in TensorFlow können die Erstellung von Variablen zu ihrem ersten Anruf verzögern, wenn die Eingänge Formen zur Verfügung stehen. Zum Beispiel der Form eines Dense kernel Layers hängt sowohl von der Eingangs- und Ausgang Formen der Schicht, und so die Ausgangsform als Konstruktor Argument erforderlich ist , nicht genügend Informationen , um die Variable auf seinem eigenen zu erstellen. Da ruft eine Layer auch den Wert der Variablen liest, muss zwischen den variablen Schöpfung und seinem ersten Einsatz eine Wiederherstellung geschehen.

Um dieses Idiom zu unterstützen, tf.train.Checkpoint Warteschlangen Wiederherstellungen , die noch nicht eine passende Variable haben.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.603108  4.814235  4.7161555 4.818163  4.8451676]]

Kontrollpunkte manuell prüfen

tf.train.load_checkpoint gibt einen CheckpointReader , die untere Ebene Zugang zu dem Kontrollpunkt Inhalt gibt. Es enthält Zuordnungen vom Schlüssel jeder Variablen zur Form und zum dtype für jede Variable im Prüfpunkt. Der Schlüssel einer Variablen ist ihr Objektpfad, wie in den oben gezeigten Diagrammen.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Also , wenn Sie in dem Wert interessiert sind net.l1.kernel können Sie den Wert mit dem folgenden Code erhalten:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Es bietet auch eine get_tensor Methode in dem Sie den Wert einer Variablen überprüfen:

reader.get_tensor(key)
array([[4.603108 , 4.814235 , 4.7161555, 4.818163 , 4.8451676]],
      dtype=float32)

Listen- und Wörterbuchverfolgung

Wie bei direkter Attributzuordnungen wie self.l1 = tf.keras.layers.Dense(5) , werden Listen und Wörterbücher Attribute zuweisen deren Inhalte verfolgen.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Möglicherweise bemerken Sie Wrapper-Objekte für Listen und Wörterbücher. Diese Wrapper sind prüfpunktfähige Versionen der zugrunde liegenden Datenstrukturen. Genau wie beim attributbasierten Laden stellen diese Wrapper den Wert einer Variablen wieder her, sobald sie dem Container hinzugefügt wird.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Das gleiche Tracking wird automatisch auf Subklassen der angewandten tf.keras.Model , und kann beispielsweise verwendet werden Listen von Schichten zu verfolgen.

Zusammenfassung

TensorFlow-Objekte bieten einen einfachen automatischen Mechanismus zum Speichern und Wiederherstellen der Werte von Variablen, die sie verwenden.