Sehen Sie sich Keynotes, Produktsitzungen, Workshops und mehr in Google I / O an. Siehe Wiedergabeliste

Trainingskontrollpunkte

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Der Ausdruck "Speichern eines TensorFlow-Modells" bedeutet normalerweise eines von zwei Dingen:

  1. Checkpoints ODER
  2. SavedModel.

Prüfpunkte erfassen den genauen Wert aller von einem Modell verwendeten Parameter ( tf.Variable Objekte). Prüfpunkte enthalten keine Beschreibung der vom Modell definierten Berechnung und sind daher normalerweise nur dann nützlich, wenn Quellcode verfügbar ist, der die gespeicherten Parameterwerte verwendet.

Das SavedModel-Format enthält dagegen zusätzlich zu den Parameterwerten (Prüfpunkt) eine serialisierte Beschreibung der vom Modell definierten Berechnung. Modelle in diesem Format sind unabhängig vom Quellcode, mit dem das Modell erstellt wurde. Sie eignen sich daher für die Bereitstellung über TensorFlow Serving, TensorFlow Lite, TensorFlow.js oder Programme in anderen Programmiersprachen (C, C ++, Java, Go, Rust, C # usw. TensorFlow-APIs).

Dieses Handbuch behandelt APIs zum Schreiben und Lesen von Prüfpunkten.

Einrichten

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Speichern von tf.keras Trainings-APIs

Weitere tf.keras zum Speichern und Wiederherstellen finden Sie im tf.keras Handbuch.

tf.keras.Model.save_weights speichert einen TensorFlow-Prüfpunkt.

net.save_weights('easy_checkpoint')

Checkpoints schreiben

Der persistente Status eines TensorFlow-Modells wird in tf.Variable Objekten gespeichert. Diese können direkt erstellt werden, werden jedoch häufig übertf.keras.layers APIs wietf.keras.layers oder tf.keras.Model .

Der einfachste Weg, Variablen zu verwalten, besteht darin, sie an Python-Objekte anzuhängen und dann auf diese Objekte zu verweisen.

Unterklassen von tf.train.Checkpoint , tf.keras.layers.Layer und tf.keras.Model automatisch Variablen, die ihren Attributen zugewiesen sind. Im folgenden Beispiel wird ein einfaches lineares Modell erstellt und anschließend Prüfpunkte geschrieben, die Werte für alle Variablen des Modells enthalten.

Sie können einen Modellprüfpunkt einfach mit Model.save_weights .

Manuelles Checkpointing

Einrichten

Definieren Sie einen Spielzeugdatensatz und einen Optimierungsschritt, um alle Funktionen von tf.train.Checkpoint zu demonstrieren:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Erstellen Sie die Prüfpunktobjekte

Verwenden Sie ein tf.train.Checkpoint Objekt, um manuell einen Prüfpunkt zu erstellen, bei dem die Objekte, die Sie tf.train.Checkpoint möchten, als Attribute für das Objekt festgelegt werden.

Ein tf.train.CheckpointManager kann auch hilfreich sein, um mehrere Prüfpunkte zu verwalten.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Trainiere und überprüfe das Modell

Die folgende Trainingsschleife erstellt eine Instanz des Modells und eines Optimierers und sammelt sie dann in einem tf.train.Checkpoint Objekt. Es ruft den Trainingsschritt in einer Schleife für jeden Datenstapel auf und schreibt regelmäßig Prüfpunkte auf die Festplatte.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Stellen Sie das Training wieder her und setzen Sie es fort

Nach dem ersten Trainingszyklus können Sie ein neues Modell und einen neuen Manager übergeben, aber das Training genau dort beginnen, wo Sie aufgehört haben:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

Das Objekt tf.train.CheckpointManager löscht alte Prüfpunkte. Oben ist es so konfiguriert, dass nur die drei neuesten Prüfpunkte beibehalten werden.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Diese Pfade, z. B. './tf_ckpts/ckpt-10' , sind keine Dateien auf der Festplatte. Stattdessen sind sie Präfixe für eine index und eine oder mehrere Datendateien, die die Variablenwerte enthalten. Diese Präfixe sind in einer einzigen checkpoint ( './tf_ckpts/checkpoint' ) zusammengefasst, in der der CheckpointManager seinen Status speichert.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Lademechanik

TensorFlow vergleicht Variablen mit Prüfpunktwerten, indem ein gerichtetes Diagramm mit benannten Kanten ausgehend vom geladenen Objekt durchlaufen wird. self.l1 = tf.keras.layers.Dense(5) stammen normalerweise von Attributnamen in Objekten, z. B. "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint verwendet seine Schlüsselwortargumentnamen wie im "step" in tf.train.Checkpoint(step=...) .

Das Abhängigkeitsdiagramm aus dem obigen Beispiel sieht folgendermaßen aus:

Visualisierung des Abhängigkeitsgraphen für die Beispieltrainingsschleife

Der Optimierer ist rot, reguläre Variablen sind blau und die Optimierungssteckplatzvariablen sind orange. Die anderen Knoten, die beispielsweise den tf.train.Checkpoint sind schwarz.

Slot-Variablen sind Teil des Optimierungsstatus, werden jedoch für eine bestimmte Variable erstellt. Zum Beispiel entsprechen die 'm' Kanten oben dem Impuls, den der Adam-Optimierer für jede Variable verfolgt. Slot-Variablen werden nur dann in einem Prüfpunkt gespeichert, wenn sowohl die Variable als auch der Optimierer gespeichert würden, also die gestrichelten Kanten.

Durch Aufrufen der restore für ein tf.train.Checkpoint Objekt werden die angeforderten Wiederherstellungen in die Warteschlange gestellt und Variablenwerte wiederhergestellt, sobald ein übereinstimmender Pfad vom Checkpoint Objekt vorhanden ist. Beispielsweise können Sie nur die Verzerrung aus dem oben definierten Modell laden, indem Sie einen Pfad durch das Netzwerk und die Schicht zu diesem Modell rekonstruieren.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

Das Abhängigkeitsdiagramm für diese neuen Objekte ist ein viel kleinerer Untergraph des größeren Prüfpunkts, den Sie oben geschrieben haben. Es enthält nur den Bias und einen Sicherungszähler, mit dem tf.train.Checkpoint Checkpoints nummeriert.

Visualisierung eines Untergraphen für die Bias-Variable

restore gibt ein Statusobjekt zurück, das optionale Zusicherungen enthält. Alle Objekte in der neu geschaffenen Checkpoint wiederhergestellt wurden, so status.assert_existing_objects_matched Pässe.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

Es gibt viele Objekte im Prüfpunkt, die nicht übereinstimmen, einschließlich des Kernels der Ebene und der Variablen des Optimierers. status.assert_consumed nur übergeben, wenn der Prüfpunkt und das Programm genau übereinstimmen, und würde hier eine Ausnahme status.assert_consumed .

Verzögerte Restaurationen

Layer in TensorFlow können die Erstellung von Variablen bis zum ersten Aufruf verzögern, wenn Eingabeformen verfügbar sind. Beispielsweise hängt die Form des Kernels einer Dense Ebene sowohl von der Eingabe- als auch von der Ausgabeform der Ebene ab. Daher reicht die als Konstruktorargument erforderliche Ausgabeform nicht aus, um die Variable selbst zu erstellen. Da beim Aufrufen einer Layer auch der Wert der Variablen gelesen wird, muss zwischen der Erstellung der Variablen und ihrer ersten Verwendung eine Wiederherstellung erfolgen.

Um diese Redewendung zu unterstützen, werden in tf.train.Checkpoint Warteschlangen Wiederherstellungen durchgeführt, für die noch keine übereinstimmende Variable vorhanden ist.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Manuelle Inspektion von Kontrollpunkten

tf.train.load_checkpoint gibt einen CheckpointReader , der den Zugriff auf den Checkpoint-Inhalt auf niedrigerer Ebene ermöglicht. Es enthält Zuordnungen vom Schlüssel jeder Variablen zur Form und zum Typ für jede Variable im Prüfpunkt. Der Schlüssel einer Variablen ist ihr Objektpfad, wie in den oben gezeigten Diagrammen.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Wenn Sie also am Wert von net.l1.kernel , können Sie den Wert mit dem folgenden Code net.l1.kernel :

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Es bietet auch eine get_tensor Methode, mit der Sie den Wert einer Variablen überprüfen können:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Listen- und Wörterbuchverfolgung

Wie bei direkten Attributzuweisungen wie self.l1 = tf.keras.layers.Dense(5) wird durch das Zuweisen von Listen und Wörterbüchern zu Attributen deren Inhalt verfolgt.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Möglicherweise bemerken Sie Wrapper-Objekte für Listen und Wörterbücher. Diese Wrapper sind checkpointable Versionen der zugrunde liegenden Datenstrukturen. Genau wie beim attributbasierten Laden stellen diese Wrapper den Wert einer Variablen wieder her, sobald sie dem Container hinzugefügt wird.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Das gleiche Tracking wird automatisch auf Unterklassen von tf.keras.Model angewendet und kann beispielsweise zum Verfolgen von tf.keras.Model verwendet werden.

Zusammenfassung

TensorFlow-Objekte bieten einen einfachen automatischen Mechanismus zum Speichern und Wiederherstellen der Werte der von ihnen verwendeten Variablen.