Эта страница была переведа с помощью Cloud Translation API.
Switch to English

Контрольные точки обучения

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Фраза «Сохранение модели TensorFlow» обычно означает одно из двух:

  1. Контрольно-пропускные пункты, ИЛИ
  2. SavedModel.

Контрольные точки фиксируют точное значение всех параметров ( tf.Variable объектов), используемых моделью. Контрольные точки не содержат никакого описания вычислений, определенных моделью, и поэтому обычно полезны только тогда, когда доступен исходный код, который будет использовать сохраненные значения параметров.

С другой стороны, формат SavedModel включает сериализованное описание вычислений, определенных моделью, в дополнение к значениям параметров (контрольная точка). Модели в этом формате не зависят от исходного кода, создавшего модель. Таким образом, они подходят для развертывания через TensorFlow Serving, TensorFlow Lite, TensorFlow.js или программы на других языках программирования (C, C ++, Java, Go, Rust, C # и т. Д. API TensorFlow).

В этом руководстве рассматриваются API-интерфейсы для записи и чтения контрольных точек.

Настроить

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Сохранение из обучающих API tf.keras

См. Руководство tf.keras по сохранению и восстановлению.

tf.keras.Model.save_weights сохраняет контрольную tf.keras.Model.save_weights .

net.save_weights('easy_checkpoint')

Написание контрольных точек

Постоянное состояние модели tf.Variable хранится в объектах tf.Variable . Их можно создать напрямую, но часто они создаются с помощью высокоуровневых API, таких как tf.keras.layers или tf.keras.Model .

Самый простой способ управлять переменными - присоединить их к объектам Python и затем ссылаться на эти объекты.

Подклассы tf.train.Checkpoint , tf.keras.layers.Layer и tf.keras.Model автоматически отслеживают переменные, назначенные их атрибутам. В следующем примере строится простая линейная модель, а затем записываются контрольные точки, которые содержат значения для всех переменных модели.

Вы можете легко сохранить контрольную Model.save_weights модели с помощью Model.save_weights

Ручная установка контрольных точек

Настроить

Чтобы продемонстрировать все возможности tf.train.Checkpoint определите набор данных игрушки и шаг оптимизации:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Создайте объекты контрольной точки

Чтобы вручную создать контрольную точку, вам понадобится объект tf.train.Checkpoint . Где объекты, которые вы хотите проверить, устанавливаются как атрибуты объекта.

tf.train.CheckpointManager также может быть полезен для управления несколькими контрольными tf.train.CheckpointManager .

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Обучите и проверьте модель

Следующий цикл обучения создает экземпляр модели и оптимизатора, а затем собирает их в объект tf.train.Checkpoint . Он вызывает шаг обучения в цикле для каждого пакета данных и периодически записывает контрольные точки на диск.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.15
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.56
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.00
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.52
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.25

Восстановить и продолжить обучение

После первого вы можете пройти новую модель и менеджера, но начните обучение именно с того места, где вы остановились:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.76
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.65
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.51
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

Объект tf.train.CheckpointManager удаляет старые контрольные точки. Выше он настроен на сохранение только трех последних контрольных точек.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Эти пути, например './tf_ckpts/ckpt-10' , не являются файлами на диске. Вместо этого они представляют собой префиксы для index файла и одного или нескольких файлов данных, содержащих значения переменных. Эти префиксы сгруппированы в один файл checkpoint ( './tf_ckpts/checkpoint' ), где CheckpointManager сохраняет свое состояние.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Механика загрузки

TensorFlow сопоставляет переменные со значениями контрольных точек, просматривая ориентированный граф с именованными ребрами, начиная с загружаемого объекта. Имена self.l1 = tf.keras.layers.Dense(5) обычно берутся из имен атрибутов в объектах, например "l1" в self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint использует свои имена аргументов ключевого слова, как в "step" в tf.train.Checkpoint(step=...) .

График зависимостей из приведенного выше примера выглядит так:

Визуализация графа зависимостей для примера обучающего цикла

Оптимизатор выделен красным, обычные переменные - синим, а переменные слота оптимизатора - оранжевым. Другие узлы, например, представляющие tf.train.Checkpoint , черные.

Переменные слота являются частью состояния оптимизатора, но создаются для определенной переменной. Например, рёбра 'm' выше соответствуют импульсу, который оптимизатор Adam отслеживает для каждой переменной. Переменные слота сохраняются в контрольной точке только в том случае, если и переменная, и оптимизатор будут сохранены, то есть пунктирные края.

Вызов tf.train.Checkpoint restore() для объекта tf.train.Checkpoint очередь запрошенные восстановления, восстанавливая значения переменных, как только будет найден соответствующий путь от объекта Checkpoint . Например, мы можем загрузить только смещение из модели, которую мы определили выше, реконструировав один путь к ней через сеть и слой.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.5548685 2.8931093 2.3509905 3.5525272 4.017799 ]

Граф зависимостей для этих новых объектов - это гораздо меньший подграф более крупной контрольной точки, которую мы написали выше. Он включает только смещение и счетчик сохранения, которые tf.train.Checkpoint использует для нумерации контрольных точек.

Визуализация подграфа для переменной смещения

restore() возвращает объект статуса, который имеет необязательные утверждения. Все объекты, которые мы создали в нашей новой Checkpoint status.assert_existing_objects_matched() , были восстановлены, поэтому status.assert_existing_objects_matched() проходит.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fea0c3c3860>

В контрольной точке есть много объектов, которые не совпадают, включая ядро ​​уровня и переменные оптимизатора. status.assert_consumed() проходит, только если контрольная точка и программа точно совпадают, и вызовут здесь исключение.

Отсроченные реставрации

Объекты Layer в TensorFlow могут отложить создание переменных до их первого вызова, когда доступны входные формы. Например, форма ядра Dense слоя зависит как от входных, так и выходных форм слоя, поэтому выходная форма, требуемая в качестве аргумента конструктора, не является достаточной информацией для создания переменной самостоятельно. Поскольку при вызове Layer также считывается значение переменной, восстановление должно происходить между созданием переменной и ее первым использованием.

Чтобы поддерживать эту идиому, очереди tf.train.Checkpoint восстанавливают, для которых еще нет соответствующей переменной.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.453001  4.6668463 4.9372597 4.90143   4.9549575]]

Осмотр контрольных точек вручную

tf.train.list_variables перечисляет ключи контрольной точки и формы переменных в контрольной точке. Ключи контрольных точек - это пути на графике, показанном выше.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Отслеживание списков и словарей

Как и в случае прямого присвоения атрибутов, например self.l1 = tf.keras.layers.Dense(5) , назначение списков и словарей атрибутам будет отслеживать их содержимое.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Вы можете заметить объекты-оболочки для списков и словарей. Эти оболочки являются версиями базовых структур данных с возможностью проверки. Как и при загрузке на основе атрибутов, эти оболочки восстанавливают значение переменной, как только она добавляется в контейнер.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Такое же отслеживание автоматически применяется к подклассам tf.keras.Model и может использоваться, например, для отслеживания списков слоев.

Сохранение объектных контрольных точек с помощью оценщика

См. Руководство по оценке.

Оценщики по умолчанию сохраняют контрольные точки с именами переменных, а не с графом объектов, описанным в предыдущих разделах. tf.train.Checkpoint будет принимать контрольные точки на основе имен, но имена переменных могут измениться при перемещении частей модели за пределы model_fn оценщика. Сохранение контрольных точек на основе объектов упрощает обучение модели внутри оценщика, а затем использование ее вне его.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.388644, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 34.98601.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fea648fbf60>

tf.train.Checkpoint может затем загрузить контрольные точки Оценщика из его model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Резюме

Объекты TensorFlow предоставляют простой автоматический механизм для сохранения и восстановления значений переменных, которые они используют.