Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Checkpoint di addestramento

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica il quaderno

La frase "Salvataggio di un modello TensorFlow" in genere significa una delle due cose:

  1. Checkpoint, OR
  2. SavedModel.

I checkpoint acquisiscono il valore esatto di tutti i parametri ( tf.Variable oggetti tf.Variable ) utilizzati da un modello. I checkpoint non contengono alcuna descrizione del calcolo definito dal modello e pertanto sono generalmente utili solo quando è disponibile il codice sorgente che utilizzerà i valori dei parametri salvati.

Il formato SavedModel include invece una descrizione serializzata del calcolo definito dal modello oltre ai valori dei parametri (checkpoint). I modelli in questo formato sono indipendenti dal codice sorgente che ha creato il modello. Sono quindi adatti per la distribuzione tramite TensorFlow Serving, TensorFlow Lite, TensorFlow.js o programmi in altri linguaggi di programmazione (API TensorFlow C, C ++, Java, Go, Rust, C ecc.).

Questa guida copre le API per la scrittura e la lettura di checkpoint.

Impostare

 import tensorflow as tf
 
 class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
 
 net = Net()
 

Salvataggio dalle API di training di tf.keras

Consulta la guida di tf.keras sul salvataggio e il ripristino.

tf.keras.Model.save_weights salva un checkpoint TensorFlow.

 net.save_weights('easy_checkpoint')
 

Scrivere checkpoint

Lo stato persistente di un modello TensorFlow è memorizzato in oggetti tf.Variable . Questi possono essere costruiti direttamente, ma spesso vengono creati tramite API di alto livello come tf.keras.layers o tf.keras.Model .

Il modo più semplice per gestire le variabili è collegarle agli oggetti Python, quindi fare riferimento a tali oggetti.

Le sottoclassi di tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model traccia automaticamente delle variabili assegnate ai loro attributi. L'esempio seguente costruisce un modello lineare semplice, quindi scrive checkpoint che contengono valori per tutte le variabili del modello.

Puoi facilmente salvare un checkpoint del modello con Model.save_weights

Checkpoint manuale

Impostare

Per aiutare a dimostrare tutte le funzionalità di tf.train.Checkpoint definire un set di dati giocattolo e un passo di ottimizzazione:

 def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
 
 def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss
 

Crea gli oggetti del checkpoint

Per creare manualmente un checkpoint è necessario un oggetto tf.train.Checkpoint . Dove gli oggetti che si desidera controllare sono impostati come attributi sull'oggetto.

Un tf.train.CheckpointManager può anche essere utile per la gestione di più checkpoint.

 opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
 

Addestra e controlla il modello

Il ciclo di formazione seguente crea un'istanza del modello e di un ottimizzatore, quindi li raccoglie in un oggetto tf.train.Checkpoint . Chiama la fase di addestramento in un ciclo su ogni batch di dati e scrive periodicamente checkpoint su disco.

 def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
 
 train_and_checkpoint(net, manager)
 
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.06
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.47
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 14.93
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.50
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.70

Ripristina e continua l'allenamento

Dopo il primo puoi passare un nuovo modello e gestore, ma l'addestramento al ritiro esattamente da dove avevi interrotto:

 opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
 
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.05
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.85
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.49
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.33
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.15

L'oggetto tf.train.CheckpointManager elimina i vecchi checkpoint. Sopra è configurato per mantenere solo i tre checkpoint più recenti.

 print(manager.checkpoints)  # List the three remaining checkpoints
 
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Questi percorsi, ad esempio './tf_ckpts/ckpt-10' , non sono file sul disco. Sono invece prefissi per un file index e uno o più file di dati che contengono i valori delle variabili. Questi prefissi sono raggruppati in un unico file di checkpoint ( './tf_ckpts/checkpoint' ) in cui CheckpointManager salva il suo stato.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00001-of-00002
ckpt-10.data-00000-of-00002  ckpt-8.index
ckpt-10.data-00001-of-00002  ckpt-9.data-00000-of-00002
ckpt-10.index            ckpt-9.data-00001-of-00002
ckpt-8.data-00000-of-00002   ckpt-9.index

Meccanica di caricamento

TensorFlow abbina le variabili ai valori checkpoint attraversando un grafico diretto con i bordi con nome, a partire dall'oggetto da caricare. I nomi dei bordi in genere derivano dai nomi degli attributi negli oggetti, ad esempio "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint utilizza i nomi degli argomenti delle parole chiave, come nel "step" in tf.train.Checkpoint(step=...) .

Il grafico delle dipendenze dell'esempio sopra è simile al seguente:

Visualizzazione del grafico delle dipendenze per il ciclo di formazione di esempio

Con l'ottimizzatore in rosso, le variabili regolari in blu e le variabili di slot dell'ottimizzatore in arancione. Gli altri nodi, ad esempio che rappresentano tf.train.Checkpoint , sono neri.

Le variabili di slot fanno parte dello stato dell'ottimizzatore, ma vengono create per una variabile specifica. Ad esempio, i bordi 'm' sopra corrispondono al momento, che l'ottimizzatore Adam traccia per ogni variabile. Le variabili di slot vengono salvate in un checkpoint solo se la variabile e l'ottimizzatore vengono entrambi salvati, quindi i bordi tratteggiati.

La chiamata di restore() su un oggetto tf.train.Checkpoint coda i restauri richiesti, ripristinando i valori delle variabili non appena esiste un percorso corrispondente dall'oggetto Checkpoint . Ad esempio, possiamo caricare solo il bias dal modello che abbiamo definito sopra ricostruendo un percorso attraverso la rete e il layer.

 to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
 
[0. 0. 0. 0. 0.]
[1.5749372 3.2779367 2.3600516 4.620399  3.9607847]

Il grafico delle dipendenze per questi nuovi oggetti è un sottografo molto più piccolo del checkpoint più grande che abbiamo scritto sopra. Include solo il bias e un contatore di salvataggio che tf.train.Checkpoint utilizza per numerare i checkpoint.

Visualizzazione di un sottografo per la variabile di polarizzazione

restore() restituisce un oggetto status, che ha asserzioni opzionali. Tutti gli oggetti che abbiamo creato nel nostro nuovo Checkpoint sono stati ripristinati, quindi passa status.assert_existing_objects_matched() .

 status.assert_existing_objects_matched()
 
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fe44ad109b0>

Ci sono molti oggetti nel checkpoint che non corrispondono, incluso il kernel del layer e le variabili dell'ottimizzatore. status.assert_consumed() passa solo se il checkpoint e il programma corrispondono esattamente e genererebbe un'eccezione qui.

Restauri ritardati

Layer oggetti Layer in TensorFlow possono ritardare la creazione di variabili alla loro prima chiamata, quando sono disponibili forme di input. Ad esempio, la forma del kernel di un layer Dense dipende sia dalle forme di input che di output del layer, quindi la forma di output richiesta come argomento del costruttore non è abbastanza informazioni per creare la variabile da sola. Poiché la chiamata a un Layer legge anche il valore della variabile, deve avvenire un ripristino tra la creazione della variabile e il suo primo utilizzo.

Per supportare questo linguaggio, ripristina le code tf.train.Checkpoint che non hanno ancora una variabile corrispondente.

 delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
 
[[0. 0. 0. 0. 0.]]
[[4.733721  4.6473064 4.9289813 4.7712607 4.958914 ]]

Ispezione manuale dei punti di controllo

tf.train.list_variables elenca le chiavi del checkpoint e le forme delle variabili in un checkpoint. Le chiavi del punto di arresto sono percorsi nel grafico visualizzato sopra.

 tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
 
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Tracciamento di elenco e dizionario

Come per le assegnazioni dirette di attributi come self.l1 = tf.keras.layers.Dense(5) , l'assegnazione di elenchi e dizionari agli attributi self.l1 = tf.keras.layers.Dense(5) loro contenuto.

 save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
 

È possibile notare oggetti wrapper per elenchi e dizionari. Questi wrapper sono versioni checkpointable delle strutture dati sottostanti. Proprio come il caricamento basato sugli attributi, questi wrapper ripristinano il valore di una variabile non appena viene aggiunto al contenitore.

 restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
 
ListWrapper([])

Lo stesso tracciamento viene applicato automaticamente alle sottoclassi di tf.keras.Model e può essere utilizzato ad esempio per tenere traccia degli elenchi di livelli.

Salvataggio di checkpoint basati su oggetti con Estimator

Vedi la guida allo stimatore .

Gli stimatori per impostazione predefinita salvano i checkpoint con nomi di variabili anziché il grafico degli oggetti descritto nelle sezioni precedenti. tf.train.Checkpoint accetterà checkpoint basati sul nome, ma i nomi delle variabili possono cambiare quando si spostano parti di un modello al di fuori del model_fn dello model_fn . Il salvataggio di checkpoint basati su oggetti semplifica l'addestramento di un modello all'interno di uno stimatore e quindi lo utilizza all'esterno di uno.

 import tensorflow.compat.v1 as tf_compat
 
 def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
 
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1666: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.5144663, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.124985.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fe44a8d10b8>

tf.train.Checkpoint può quindi caricare i checkpoint dello stimatore dal suo model_dir .

 opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
 
10

Sommario

Gli oggetti TensorFlow forniscono un semplice meccanismo automatico per salvare e ripristinare i valori delle variabili che utilizzano.