Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Punti di controllo della formazione

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica notebook

La frase "Salvataggio di un modello TensorFlow" in genere significa due cose:

  1. Checkpoint, OR
  2. SavedModel.

I checkpoint acquisiscono il valore esatto di tutti i parametri (oggetti tf.Variable ) utilizzati da un modello. I checkpoint non contengono alcuna descrizione del calcolo definito dal modello e quindi sono generalmente utili solo quando è disponibile il codice sorgente che utilizzerà i valori dei parametri salvati.

Il formato SavedModel invece include una descrizione serializzata del calcolo definito dal modello oltre ai valori dei parametri (checkpoint). I modelli in questo formato sono indipendenti dal codice sorgente che ha creato il modello. Sono quindi adatti per la distribuzione tramite TensorFlow Serving, TensorFlow Lite, TensorFlow.js o programmi in altri linguaggi di programmazione (le API C, C ++, Java, Go, Rust, C # ecc. TensorFlow).

Questa guida copre le API per la scrittura e la lettura dei checkpoint.

Impostare

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Risparmio dalle API di formazione tf.keras

Consulta la guida tf.keras su salvataggio e ripristino.

tf.keras.Model.save_weights salva un checkpoint TensorFlow.

net.save_weights('easy_checkpoint')

Scrivere checkpoint

Lo stato persistente di un modello TensorFlow viene archiviato negli oggetti tf.Variable . Questi possono essere costruiti direttamente, ma spesso vengono creati tramite API di alto livello come tf.keras.layers o tf.keras.Model .

Il modo più semplice per gestire le variabili è collegarle agli oggetti Python, quindi fare riferimento a tali oggetti.

Le sottoclassi di tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model traccia automaticamente delle variabili assegnate ai loro attributi. L'esempio seguente costruisce un semplice modello lineare, quindi scrive i punti di controllo che contengono i valori per tutte le variabili del modello.

Puoi facilmente salvare un punto di controllo del modello con Model.save_weights

Checkpoint manuale

Impostare

Per aiutare a dimostrare tutte le funzionalità di tf.train.Checkpoint definire un set di dati del giocattolo e una fase di ottimizzazione:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crea gli oggetti checkpoint

Per creare manualmente un checkpoint avrai bisogno di un oggetto tf.train.Checkpoint . Il punto in cui gli oggetti che si desidera controllare sono impostati come attributi sull'oggetto.

Un tf.train.CheckpointManager può anche essere utile per la gestione di più checkpoint.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Addestra e controlla il modello

Il seguente ciclo di addestramento crea un'istanza del modello e di un ottimizzatore, quindi li raccoglie in un oggetto tf.train.Checkpoint . Richiama la fase di addestramento in un ciclo su ogni batch di dati e scrive periodicamente i checkpoint su disco.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 32.40
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 25.82
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 19.26
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 12.77
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 6.48

Ripristina e continua la formazione

Dopo il primo puoi passare un nuovo modello e manager, ma riprendi l'addestramento esattamente da dove l'avevi interrotto:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.88
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.44
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.41
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

L'oggetto tf.train.CheckpointManager elimina i vecchi checkpoint. Sopra è configurato per mantenere solo i tre checkpoint più recenti.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Questi percorsi, ad esempio './tf_ckpts/ckpt-10' , non sono file su disco. Sono invece prefissi per un file di index e uno o più file di dati che contengono i valori delle variabili. Questi prefissi sono raggruppati in un unico file di checkpoint ( './tf_ckpts/checkpoint' ) in cui CheckpointManager salva il suo stato.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Caricamento meccanico

TensorFlow abbina le variabili ai valori con checkpoint attraversando un grafico diretto con bordi denominati, a partire dall'oggetto che viene caricato. I nomi degli self.l1 = tf.keras.layers.Dense(5) in genere derivano dai nomi degli attributi negli oggetti, ad esempio "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilizza i nomi degli argomenti delle parole chiave, come nel "step" in tf.train.Checkpoint(step=...) .

Il grafico delle dipendenze dell'esempio sopra ha questo aspetto:

Visualizzazione del grafico delle dipendenze per il ciclo di addestramento di esempio

Con l'ottimizzatore in rosso, le variabili regolari in blu e le variabili slot dell'ottimizzatore in arancione. Gli altri nodi, ad esempio che rappresentano il tf.train.Checkpoint , sono neri.

Le variabili di slot fanno parte dello stato dell'ottimizzatore, ma vengono create per una variabile specifica. Ad esempio, i bordi 'm' sopra corrispondono alla quantità di moto, che l'ottimizzatore di Adam traccia per ciascuna variabile. Le variabili di slot vengono salvate in un checkpoint solo se la variabile e l'ottimizzatore vengono salvati entrambi, quindi i bordi tratteggiati.

La chiamata a restore() su un oggetto tf.train.Checkpoint accoda i ripristini richiesti, ripristinando i valori delle variabili non appena c'è un percorso corrispondente dall'oggetto Checkpoint . Ad esempio, puoi caricare solo il bias dal modello che abbiamo definito sopra ricostruendo un percorso ad esso attraverso la rete e il livello.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[3.4461102 3.030825  4.4315968 3.5077076 4.7258596]

Il grafico delle dipendenze per questi nuovi oggetti è un sottografo molto più piccolo del checkpoint più grande che abbiamo scritto sopra. Include solo il bias e un contatore di salvataggio che tf.train.Checkpoint utilizza per numerare i checkpoint.

Visualizzazione di un sottografo per la variabile bias

restore() restituisce un oggetto di stato, che ha asserzioni opzionali. Tutti gli oggetti creati nel nuovo Checkpoint sono stati ripristinati, quindi status.assert_existing_objects_matched() passa.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fec144bd080>

Ci sono molti oggetti nel checkpoint che non corrispondono, incluso il kernel del livello e le variabili dell'ottimizzatore. status.assert_consumed() passa solo se il checkpoint e il programma corrispondono esattamente, e qui viene status.assert_consumed() un'eccezione.

Restauri ritardati

Layer oggetti Layer in TensorFlow possono ritardare la creazione di variabili fino alla prima chiamata, quando sono disponibili le forme di input. Ad esempio, la forma del kernel di un livello Dense dipende sia dalle forme di input che di output del livello, quindi la forma di output richiesta come argomento del costruttore non è sufficiente per creare la variabile da sola. Poiché la chiamata a un Layer legge anche il valore della variabile, deve avvenire un ripristino tra la creazione della variabile e il suo primo utilizzo.

Per supportare questo idioma, tf.train.Checkpoint ripristino delle code che non hanno ancora una variabile corrispondente.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.4598393 4.677273  4.655946  4.926899  4.79748  ]]

Ispezione manuale dei posti di blocco

tf.train.list_variables elenca le chiavi dei checkpoint e le forme delle variabili in un checkpoint. Le chiavi del punto di controllo sono percorsi nel grafico visualizzato sopra.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Monitoraggio di elenchi e dizionari

Come per le assegnazioni dirette di attributi come self.l1 = tf.keras.layers.Dense(5) , l'assegnazione di elenchi e dizionari agli attributi terrà traccia del loro contenuto.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Potresti notare oggetti wrapper per elenchi e dizionari. Questi wrapper sono versioni controllabili delle strutture dati sottostanti. Proprio come il caricamento basato sugli attributi, questi wrapper ripristinano il valore di una variabile non appena viene aggiunta al contenitore.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Lo stesso tracciamento viene applicato automaticamente alle sottoclassi di tf.keras.Model e può essere utilizzato ad esempio per tracciare elenchi di livelli.

Sommario

Gli oggetti TensorFlow forniscono un semplice meccanismo automatico per salvare e ripristinare i valori delle variabili che utilizzano.