Entrenar checkpoints

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar notebook

La frase "Saving a TensorFlow model" significa tipicamente una de las dos cosas: (1) Checkpoints, O (2) SavedModel.

Los Checkpoints capturan el valor exacto de todos los parametros (objetos tf.Variable) usados por un modelo. Los Checkpoints no almacenan ninguna descripcion del computo utilizado por el modelo. Por lo mismo, los checkpoints solo son utiles cuando el codigo que usara los parametros almacenados esta disponible.

Por otro lado, el formato SavedModel incluye una descripcion serializada del computo definido por el modelo ademas de los valores de los parametros (checkpoint). Con este formato, los modelos son independientes al codigo que creo el mismo. Por ende son idoneos para el despliegue de los modelos a traves de TensorFlos Serving, TensorFlow Lite, TensorFlow.js, o programas en otros lenguajes de programacion (las APIs de TensorFlow para C, C++, Java, Go, Rust, C# etc.)

Esta guia cubre las APIs para leer y escribir checkpoints.

Guardando de las APIs de entrenamiento de tf.keras

Pueden referirse a la guia para guardar y restaurar de tf.keras, NOTA: al momento esta en ingles.

tf.keras.Model.save_weights tambien permite la opcion de guardar en el formato TensorFlow checkpoint. Esta guia explica el formato a mayor detalle e introduce las APIs para administrar los checkpoints en bucles de entrenamiento personalizados.

Definir checkpoints manualmente

El estado persistente de un modelo de TensorFlow es almacenado en objectos tf.Variable. Estos objetos pueden ser construidos directamente, pero comunmente con creados mediante APIs de alto nivel tales como tf.keras.layers.

La manera mas sencilla de admistrar las variables es asociandolas a objetos de Python, y despues referenciando dichos objetos. Las subclases de tf.train.Checkpoint, tf.keras.layers.Layer, y tf.keras.Model rastrean automaticamente las variables asociadas a sus atributos. El ejemplo a contunuacion construye un modelo linear simple, y posteriormente escribe checkpoints que contienen valores para todas las variables del modelo.

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
class Net(tf.keras.Model):
  """Un Modelo Linear simple."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)

Este ejemplo necesita datos y un paso de optimizacion para poder ser ejecutable aunque esta guia no se trate de esos temas. El modelo entrenara por slices de un dataset en memoria.

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat(10).batch(2)
def train_step(net, example, optimizer):
  """Entrena `net` en `example` usando `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

El siguiente buckle crea una instancia del modelo y de un optimizer, despues los recolecta en un objeto tf.train.Checkpoint. Llama el paso de entrenamiento en un ciclo para cada batch de datos, y escribe periodicamente checkpoints en disco.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  print("Restaurado de {}".format(manager.latest_checkpoint))
else:
  print("Inicializando desde cero.")

for example in toy_dataset():
  loss = train_step(net, example, opt)
  ckpt.step.assign_add(1)
  if int(ckpt.step) % 10 == 0:
    save_path = manager.save()
    print("Checkpoint almacenado para el paso {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
Inicializando desde cero.
Checkpoint almacenado para el paso 10: ./tf_ckpts/ckpt-1
loss 30.26
Checkpoint almacenado para el paso 20: ./tf_ckpts/ckpt-2
loss 23.68
Checkpoint almacenado para el paso 30: ./tf_ckpts/ckpt-3
loss 17.14
Checkpoint almacenado para el paso 40: ./tf_ckpts/ckpt-4
loss 10.71
Checkpoint almacenado para el paso 50: ./tf_ckpts/ckpt-5
loss 5.16

El snippet anterior inicializara aleatoriamente las variables del modelo en su primera corrida. Posterior a esta, reanudara el entrenamiendo en donde se quedo:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  print("Restaurado de {}".format(manager.latest_checkpoint))
else:
  print("Inicializando desde cero")

for example in toy_dataset():
  loss = train_step(net, example, opt)
  ckpt.step.assign_add(1)
  if int(ckpt.step) % 10 == 0:
    save_path = manager.save()
    print("Checkpoint almacenado para el paso {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
Restaurado de ./tf_ckpts/ckpt-5
Checkpoint almacenado para el paso 60: ./tf_ckpts/ckpt-6
loss 1.62
Checkpoint almacenado para el paso 70: ./tf_ckpts/ckpt-7
loss 1.00
Checkpoint almacenado para el paso 80: ./tf_ckpts/ckpt-8
loss 1.57
Checkpoint almacenado para el paso 90: ./tf_ckpts/ckpt-9
loss 0.89
Checkpoint almacenado para el paso 100: ./tf_ckpts/ckpt-10
loss 0.56

El objeto tf.train.CheckpointManager elimina checkpoints viejos. Arriba ha sido configurado para conservar los tres checkpoints mas recientes.

print(manager.checkpoints)  # Lista los tres checkpoints restantes
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Estos paths, e.g. './tf_ckpts/ckpt-10', no son archivos en disco. Son prefijos para un archivo tipo index y uno o mas archivos de datos que contienen los valores de las variables. Estos prefijos estan agrupados en un unico archivo de checkpoint ('./tf_ckpts/checkpoint') donde el CheckpointManager guarda su estado.

!ls ./tf_ckpts
checkpoint           ckpt-8.data-00001-of-00002
ckpt-10.data-00000-of-00002  ckpt-8.index
ckpt-10.data-00001-of-00002  ckpt-9.data-00000-of-00002
ckpt-10.index            ckpt-9.data-00001-of-00002
ckpt-8.data-00000-of-00002   ckpt-9.index

Mecanica de carga

TensorFlow hace coincider las variables con los valores almacenados como checkpoint atravesando un grafo dirigido con aristas nombradas, comenzando con el objeto que este siendo cargado. Los nombres de las aristas vienen de los nombres de los atributos en los objetos, por ejemplo el "l1" en self.l1 = tf.keras.layers.Dense(5). tf.train.Checkpoint usa sus argumentos de palabras clave como nombre, tal como el "step" en tf.train.Checkpoint(step=...).

El grafo de dependencias del ejemplo anterior se ve asi:

Visualization of the dependency graph for the example training loop

Con el optimizer en rojo, las variables regulares en azul, y variables slot del optimizer en naranja. Los otros nodos, por ejemplo el que representa el tf.train.Checkpoint, son negros.

Las variables Slot son parte del estado del optimizer, pero son creadas para una variable especifica. Por ejemplo las aristas 'm' de arriba corresponden a un momentum, los cuales son rastreados por el Adam's optimizer para cada variable. Las variables Slot solo son almacenadas en un checkpoint si la variable y el optimizer serian ambas almacenadas, por eso los aristas punteados.

La llamada restore() de un objeto tf.train.Checkpoint hace cola las restauraciones requeridas, restaurando los valores de las variables tan pronto como se encuentre un path correspondiente en el objeto Checkpoint. Por ejemplo podemos cargar solo el kernel del model que definimos anteriormente recosntruyendo un path a el mediante la red y la capa (layer).

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # Puros ceros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # Ahora obtenemos el valor restaurado
[0. 0. 0. 0. 0.]
[1.6114205 3.0766637 4.3280964 3.8758667 4.7540035]

El grafo de dependencias para estos objetos nuevos es un sub-grafo del checkpoint que escribimos anteriormente. Solo incluye el bias y un save counter que el tf.train.Checkpoint usa para enumerar los checkpoints.

Visualization of a subgraph for the bias variable

restore() regresa el estado del objeto, que tiene afirmaciones (assertions) opcionales. Todos los objetos que hemos creado en nuestro nuevo Checkpoint han sido restaurados, asi que status.assert_existing_objects_matched() pasa exitosamente.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2d007cdc88>

Hay muchos objetos en el checkpoint que no han sido emparejados, incluyendo el kernel de la capa y las variables del optimized. status.assert_consumed() solo pasa si el checkpoint y el programa empatan exactamente, y arrojara una excepcion en este caso.

Restauraciones retrasadas

Los objetos Layer en TensorFlow pueden retrasar la creacion de variables para su primera llamada, cuando las dimensiones de entrada estan disponibles. Por ejemplo, las dimensiones de un layer kernel Dense dependen tanto de las entradas de la capa como de las dimensiones de salida, y por ende solo la dimension de salida que es requerida como argumento de construccion no es suficiente informacion para la creacion de las variables. Como la llamada a Layer tambien lee el valor de la variable, una restauracion debe pasa entre las variables de creacion y su primer uso.

Para dar soporte a este idioma, tf.train.Checkpoint forma una lista de restauranciones que no tienen una vatiable que empate aun.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # No restaurado; siguen siendo ceros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restaurado
[[0. 0. 0. 0. 0.]]
[[4.756683  4.677118  4.688866  4.8417225 4.8637376]]

Revision manual de los checkpoints

tf.train.list_variables lista las checkpoint keys y las dimensiones de las variables en un checkpoint. Las Checkpoint keys son los paths del grafo mostrado anteriormente.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Rastreo de listas y diccionarios

List and dictionary tracking

Igual que en las asignaciones directas de atributos, e.g. self.l1 = tf.keras.layers.Dense(5), la asignacion de listas y diccionarios a atributos rastreara sus contenidos.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # No ha sido restaurado aun
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Puede notar que existen wrappers de objetos para listas y diccionarios. Estos wrappers pueden ser incluidos en versiones checkpoint de las estructuras de datos subyacientes. Asi como la carga basada en atributos, estos wrappers restauran el valor de una variable al momento de ser agregada al contenedor.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restaurar v1, del restore() de la celda anterior
assert 1. == v1.numpy()
ListWrapper([])

El mismo rastreo es aplicado automaticamente a subclases de tf.keras.Model, y puede ser usado para rastrear listas de capas por ejemplo.

Guardar checkponts basados en objetos con Estimator

Ver la guia a Estimator. NOTA: documentacion en ingles.

Los Estimators guardan checkpoints por default con nombres de variables en lugar de el ogjeto grafo descrito en las secciones anteriores. tf.train.Checkpoint aceptara ckeckpoints basadon en nombres, pero los nombres de las variables podrian cambiar al movr las pertes del modelo fuera del model_fn del Estimator. Guardar checkpoints basados en objetos facilita el entrenamiento de un modelo dentro de un Estimator y su posterior uso fuera de el.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Decirle al Estimator gue guarde "ckpt" en un formato basado en objeto.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_eval_distribute': None, '_keep_checkpoint_max': 5, '_global_id_in_cluster': 0, '_num_ps_replicas': 0, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_summary_steps': 100, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_save_checkpoints_secs': 600, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_save_checkpoints_steps': None, '_num_worker_replicas': 1, '_device_fn': None, '_is_chief': True, '_service': None, '_train_distribute': None, '_model_dir': './tf_estimator_example/', '_protocol': None, '_task_type': 'worker', '_evaluation_master': '', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f2ce8b38b70>, '_task_id': 0, '_tf_random_seed': None, '_master': ''}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/model_fn.py:337: scalar (from tensorflow.python.framework.tensor_shape) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.TensorShape([]).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/array_ops.py:1486: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:loss = 4.536913, step = 0
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Loss for final step: 37.50658.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f2ce8b38940>

tf.train.Checkpoint puede cargar los checkpoints del Estimator de su model_dir.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # De est.train(..., steps=10)
10

Resumen

Los objetos de TensorFlow proveen un mecanismo facil y automatico para guardar y restaurar los valores de las variables que usan.