Punti di controllo della formazione

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

La frase "Salvataggio di un modello TensorFlow" in genere significa una delle due cose:

  1. Checkpoint, OR
  2. Modello salvato.

Punti di controllo catturano l'esatto valore di tutti i parametri ( tf.Variable oggetti) utilizzati da un modello. I checkpoint non contengono alcuna descrizione del calcolo definito dal modello e quindi sono generalmente utili solo quando è disponibile il codice sorgente che utilizzerà i valori dei parametri salvati.

Il formato SavedModel include invece una descrizione serializzata del calcolo definito dal modello oltre ai valori dei parametri (checkpoint). I modelli in questo formato sono indipendenti dal codice sorgente che ha creato il modello. Sono quindi adatti per l'implementazione tramite TensorFlow Serving, TensorFlow Lite, TensorFlow.js o programmi in altri linguaggi di programmazione (le API C, C++, Java, Go, Rust, C# ecc. TensorFlow).

Questa guida copre le API per la scrittura e la lettura dei checkpoint.

Impostare

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Salvataggio da tf.keras API di formazione

Vedi le tf.keras guidano il salvataggio e il ripristino.

tf.keras.Model.save_weights salva un posto di blocco tensorflow.

net.save_weights('easy_checkpoint')

Scrivere checkpoint

Lo stato persistente di un modello tensorflow viene memorizzato in tf.Variable oggetti. Questi possono essere costruiti direttamente, ma vengono spesso creati tramite API ad alto livello come tf.keras.layers o tf.keras.Model .

Il modo più semplice per gestire le variabili è collegarle agli oggetti Python, quindi fare riferimento a tali oggetti.

Sottoclassi di tf.train.Checkpoint , tf.keras.layers.Layer , e tf.keras.Model monitorare automaticamente le variabili assegnate ai loro attributi. L'esempio seguente costruisce un semplice modello lineare, quindi scrive checkpoint che contengono valori per tutte le variabili del modello.

Si può facilmente salvare un modello checkpoint con Model.save_weights .

Checkpoint manuale

Impostare

Per aiutare a dimostrare tutte le caratteristiche di tf.train.Checkpoint , definire un set di dati giocattolo e fase di ottimizzazione:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crea gli oggetti checkpoint

Utilizzare un tf.train.Checkpoint oggetto per creare manualmente un posto di blocco, in cui gli oggetti che si desidera al punto di controllo sono impostati come attributi dell'oggetto.

Un tf.train.CheckpointManager può anche essere utile per la gestione di più posti di blocco.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Addestra e controlla il modello

Il seguente ciclo di formazione crea un'istanza del modello e di un ottimizzatore, poi li raccoglie in un tf.train.Checkpoint oggetto. Chiama la fase di addestramento in un ciclo su ogni batch di dati e scrive periodicamente i checkpoint su disco.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.77
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.18
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.62
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.16
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.09

Ripristina e continua ad allenarti

Dopo il primo ciclo di formazione puoi passare un nuovo modello e manager, ma riprendere la formazione esattamente da dove avevi interrotto:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.33
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.90
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.62
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.27
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

Il tf.train.CheckpointManager oggetto elimina i vecchi posti di blocco. Sopra è configurato per mantenere solo i tre checkpoint più recenti.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Questi percorsi, ad esempio './tf_ckpts/ckpt-10' , non sono i file su disco. Invece sono prefissi per un index di file e uno o più file di dati che contengono i valori delle variabili. Questi prefissi sono raggruppati in un unico checkpoint di file ( './tf_ckpts/checkpoint' ) dove il CheckpointManager salva il suo stato.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Meccanica di caricamento

TensorFlow abbina le variabili ai valori checkpoint attraversando un grafo diretto con bordi denominati, a partire dall'oggetto che viene caricato. Nomi di bordo in genere provengono da nomi degli attributi negli oggetti, ad esempio, il "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilizza i suoi nomi degli argomenti di parole chiave, come nel "step" in tf.train.Checkpoint(step=...) .

Il grafico delle dipendenze dell'esempio sopra ha il seguente aspetto:

Visualizzazione del grafico delle dipendenze per il ciclo di allenamento di esempio

L'ottimizzatore è in rosso, le variabili regolari sono in blu e le variabili dello slot dell'ottimizzatore sono in arancione. Gli altri nodi, ad esempio, che rappresentano la tf.train.Checkpoint -sono in nero.

Le variabili slot fanno parte dello stato dell'ottimizzatore, ma vengono create per una variabile specifica. Ad esempio i 'm' bordi superiori corrispondono ad impulso, che le tracce ottimizzatore Adam per ciascuna variabile. Le variabili slot vengono salvate in un checkpoint solo se la variabile e l'ottimizzatore sarebbero stati salvati entrambi, quindi i bordi tratteggiati.

Chiamata restore su un tf.train.Checkpoint oggetto code i restauri richiesti, ripristinando i valori delle variabili non appena c'è un percorso di corrispondenza dal Checkpoint dell'oggetto. Ad esempio, puoi caricare solo il bias dal modello che hai definito sopra ricostruendo un percorso ad esso attraverso la rete e il livello.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9851578 3.6375327 2.9331083 3.8130412 4.778274 ]

Il grafico delle dipendenze per questi nuovi oggetti è un sottografo molto più piccolo del checkpoint più grande che hai scritto sopra. Esso include solo la distorsione e un salvataggio contatore che tf.train.Checkpoint utilizza per posti di blocco numero.

Visualizzazione di un sottografo per la variabile bias

restore restituisce un oggetto di stato, che ha asserzioni opzionali. Tutti gli oggetti creati nel nuovo Checkpoint sono stati restaurati, così status.assert_existing_objects_matched passaggi.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f027807e050>

Ci sono molti oggetti nel checkpoint che non corrispondono, incluso il kernel del livello e le variabili dell'ottimizzatore. status.assert_consumed passa solo se il checkpoint e la partita del programma esattamente, e sarebbe un'eccezione qui.

Restauri ritardati

Layer oggetti in tensorflow possono ritardare la creazione di variabili alla loro prima chiamata, quando le forme di ingresso sono disponibili. Ad esempio la forma di un Dense kernel del livello dipende forme di ingresso e uscita sia del livello, e così la forma di uscita richiesta come argomento del costruttore non è abbastanza informazioni per creare la variabile da solo. Dal momento che chiamare un Layer legge anche il valore della variabile, un ripristino deve avvenire tra la creazione della variabile e il suo primo utilizzo.

A sostegno di questo idioma, tf.train.Checkpoint accoda ripristini che non hanno ancora una variabile corrispondente.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6800494 4.607369  4.8321466 4.816245  4.8435326]]

Ispezione manuale dei checkpoint

tf.train.load_checkpoint restituisce un CheckpointReader che dà accesso livello inferiore al contenuto checkpoint. Contiene mappature dalla chiave di ogni variabile, alla forma e al dtype per ogni variabile nel checkpoint. La chiave di una variabile è il percorso dell'oggetto, come nei grafici visualizzati sopra.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Quindi, se siete interessati al valore di net.l1.kernel è possibile ottenere il valore con il seguente codice:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Fornisce inoltre un get_tensor metodo che consente di ispezionare il valore di una variabile:

reader.get_tensor(key)
array([[4.6800494, 4.607369 , 4.8321466, 4.816245 , 4.8435326]],
      dtype=float32)

Monitoraggio di elenchi e dizionari

Come nel caso di affidamenti diretti di attributi come self.l1 = tf.keras.layers.Dense(5) , l'assegnazione di liste e dizionari per gli attributi seguirà il loro contenuto.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Potresti notare oggetti wrapper per elenchi e dizionari. Questi wrapper sono versioni controllabili delle strutture dati sottostanti. Proprio come il caricamento basato sull'attributo, questi wrapper ripristinano il valore di una variabile non appena viene aggiunta al contenitore.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

La stessa inseguimento viene applicato automaticamente a sottoclassi di tf.keras.Model , e può essere utilizzato ad esempio per monitorare elenchi di strati.

Riepilogo

Gli oggetti TensorFlow forniscono un semplice meccanismo automatico per salvare e ripristinare i valori delle variabili che utilizzano.