Hai una domanda? Connettiti con la community al forum TensorFlow Visita il forum

Punti di controllo dell'addestramento

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica taccuino

La frase "Salvataggio di un modello TensorFlow" in genere significa due cose:

  1. Checkpoint, OR
  2. SavedModel.

I checkpoint acquisiscono il valore esatto di tutti i parametri (oggetti tf.Variable ) utilizzati da un modello. I checkpoint non contengono alcuna descrizione del calcolo definito dal modello e quindi sono in genere utili solo quando è disponibile il codice sorgente che utilizzerà i valori dei parametri salvati.

Il formato SavedModel invece include una descrizione serializzata del calcolo definito dal modello oltre ai valori dei parametri (checkpoint). I modelli in questo formato sono indipendenti dal codice sorgente che ha creato il modello. Sono quindi adatti per la distribuzione tramite TensorFlow Serving, TensorFlow Lite, TensorFlow.js o programmi in altri linguaggi di programmazione (le API C, C ++, Java, Go, Rust, C # ecc TensorFlow).

Questa guida copre le API per la scrittura e la lettura dei checkpoint.

Impostare

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Risparmio dalle API di formazione tf.keras

Consulta la guida tf.keras su salvataggio e ripristino.

tf.keras.Model.save_weights salva un checkpoint TensorFlow.

net.save_weights('easy_checkpoint')

Scrivere checkpoint

Lo stato persistente di un modello TensorFlow viene archiviato negli oggetti tf.Variable . Questi possono essere costruiti direttamente, ma spesso vengono creati tramite API di alto livello cometf.keras.layers o tf.keras.Model .

Il modo più semplice per gestire le variabili è collegarle agli oggetti Python, quindi fare riferimento a tali oggetti.

Le sottoclassi di tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model traccia automaticamente delle variabili assegnate ai loro attributi. L'esempio seguente costruisce un semplice modello lineare, quindi scrive i punti di controllo che contengono valori per tutte le variabili del modello.

Puoi facilmente salvare un punto di controllo del modello con Model.save_weights .

Checkpoint manuale

Impostare

Per aiutare a dimostrare tutte le funzionalità di tf.train.Checkpoint , definire un set di dati del giocattolo e una fase di ottimizzazione:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crea gli oggetti checkpoint

Utilizzare un oggetto tf.train.Checkpoint per creare manualmente un checkpoint, in cui gli oggetti che si desidera controllare siano impostati come attributi sull'oggetto.

Un tf.train.CheckpointManager può anche essere utile per la gestione di più checkpoint.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Addestra e controlla il modello

Il seguente ciclo di addestramento crea un'istanza del modello e di un ottimizzatore, quindi li raccoglie in un oggetto tf.train.Checkpoint . Richiama la fase di addestramento in un ciclo su ogni batch di dati e scrive periodicamente i checkpoint su disco.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Ripristina e continua la formazione

Dopo il primo ciclo di formazione puoi passare un nuovo modello e manager, ma riprendi la formazione esattamente da dove l'avevi interrotta:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

L'oggetto tf.train.CheckpointManager elimina i vecchi checkpoint. Sopra è configurato per mantenere solo i tre checkpoint più recenti.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Questi percorsi, ad esempio './tf_ckpts/ckpt-10' , non sono file su disco. Sono invece prefissi per un file di index e uno o più file di dati che contengono i valori delle variabili. Questi prefissi sono raggruppati in un unico file di checkpoint ( './tf_ckpts/checkpoint' ) in cui CheckpointManager salva il suo stato.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Meccanica di caricamento

TensorFlow abbina le variabili ai valori con checkpoint attraversando un grafico diretto con bordi denominati, a partire dall'oggetto caricato. I nomi dei bordi derivano tipicamente dai nomi degli attributi negli oggetti, ad esempio "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilizza i nomi degli argomenti delle parole chiave, come nel "step" in tf.train.Checkpoint(step=...) .

Il grafico delle dipendenze dell'esempio sopra ha questo aspetto:

Visualizzazione del grafico delle dipendenze per il ciclo di addestramento di esempio

L'ottimizzatore è in rosso, le variabili regolari sono in blu e le variabili dello slot dell'ottimizzatore sono in arancione. Gli altri nodi, ad esempio, che rappresentano tf.train.Checkpoint sono in nero.

Le variabili di slot fanno parte dello stato dell'ottimizzatore, ma vengono create per una variabile specifica. Ad esempio, i bordi 'm' sopra corrispondono alla quantità di moto, che l'ottimizzatore di Adam tiene traccia per ciascuna variabile. Le variabili di slot vengono salvate in un checkpoint solo se la variabile e l'ottimizzatore vengono salvati entrambi, quindi i bordi tratteggiati.

La chiamata al restore su un oggetto tf.train.Checkpoint accoda i ripristini richiesti, ripristinando i valori delle variabili non appena c'è un percorso corrispondente dall'oggetto Checkpoint . Ad esempio, puoi caricare solo il bias dal modello definito sopra ricostruendo un percorso ad esso attraverso la rete e il livello.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

Il grafico delle dipendenze per questi nuovi oggetti è un sottografo molto più piccolo del checkpoint più grande che hai scritto sopra. Include solo il bias e un contatore di salvataggio che tf.train.Checkpoint utilizza per numerare i checkpoint.

Visualizzazione di un sottografo per la variabile bias

restore restituisce un oggetto di stato, che ha asserzioni opzionali. Tutti gli oggetti creati nel nuovo Checkpoint sono stati ripristinati, quindi status.assert_existing_objects_matched passa.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

Ci sono molti oggetti nel checkpoint che non corrispondono, incluso il kernel del livello e le variabili dell'ottimizzatore. status.assert_consumed viene superato solo se il checkpoint e il programma corrispondono esattamente e qui viene status.assert_consumed un'eccezione.

Restauri ritardati

Layer oggetti Layer in TensorFlow possono ritardare la creazione di variabili fino alla prima chiamata, quando sono disponibili le forme di input. Ad esempio, la forma del kernel di un livello Dense dipende sia dalle forme di input che di output del livello, quindi la forma di output richiesta come argomento del costruttore non è sufficiente per creare la variabile da sola. Poiché la chiamata di un Layer legge anche il valore della variabile, deve avvenire un ripristino tra la creazione della variabile e il suo primo utilizzo.

Per supportare questo idioma, tf.train.Checkpoint ripristina le code che non hanno ancora una variabile corrispondente.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Ispezione manuale dei punti di controllo

tf.train.load_checkpoint restituisce un CheckpointReader che fornisce l'accesso di livello inferiore ai contenuti del checkpoint. Contiene le mappature dalla chiave di ciascuna variabile, alla forma e al dtype per ciascuna variabile nel checkpoint. La chiave di una variabile è il percorso dell'oggetto, come nei grafici visualizzati sopra.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Quindi, se sei interessato al valore di net.l1.kernel puoi ottenere il valore con il seguente codice:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Fornisce inoltre un metodo get_tensor che consente di ispezionare il valore di una variabile:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Monitoraggio di elenchi e dizionari

Come per le assegnazioni dirette di attributi come self.l1 = tf.keras.layers.Dense(5) , l'assegnazione di elenchi e dizionari agli attributi terrà traccia del loro contenuto.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Potresti notare oggetti wrapper per elenchi e dizionari. Questi wrapper sono versioni controllabili delle strutture dati sottostanti. Proprio come il caricamento basato sugli attributi, questi wrapper ripristinano il valore di una variabile non appena viene aggiunta al contenitore.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Lo stesso tracciamento viene applicato automaticamente alle sottoclassi di tf.keras.Model e può essere utilizzato ad esempio per tracciare elenchi di livelli.

Sommario

Gli oggetti TensorFlow forniscono un semplice meccanismo automatico per salvare e ripristinare i valori delle variabili che utilizzano.