RSVP для вашего местного мероприятия TensorFlow Everywhere сегодня!
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Контрольные точки обучения

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Фраза «Сохранение модели TensorFlow» обычно означает одно из двух:

  1. Контрольно-пропускные пункты, ИЛИ
  2. SavedModel.

Контрольные точки фиксируют точное значение всех параметров ( tf.Variable объектов), используемых моделью. Контрольные точки не содержат никакого описания вычислений, определенных моделью, и поэтому обычно полезны только тогда, когда доступен исходный код, который будет использовать сохраненные значения параметров.

С другой стороны, формат SavedModel включает сериализованное описание вычислений, определенных моделью, в дополнение к значениям параметров (контрольная точка). Модели в этом формате не зависят от исходного кода, создавшего модель. Таким образом, они подходят для развертывания с помощью TensorFlow Serving, TensorFlow Lite, TensorFlow.js или программ на других языках программирования (C, C ++, Java, Go, Rust, C # и т. Д. API TensorFlow).

В этом руководстве рассматриваются API-интерфейсы для записи и чтения контрольных точек.

Настроить

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Сохранение из обучающих API tf.keras

См. Руководство tf.keras по сохранению и восстановлению.

tf.keras.Model.save_weights сохраняет контрольную tf.keras.Model.save_weights .

net.save_weights('easy_checkpoint')

Написание контрольных точек

Постоянное состояние модели tf.Variable хранится в объектах tf.Variable . Их можно создать напрямую, но часто они создаются с помощью высокоуровневых API, таких какtf.keras.layers или tf.keras.Model .

Самый простой способ управлять переменными - присоединить их к объектам Python и затем ссылаться на эти объекты.

Подклассы tf.train.Checkpoint , tf.keras.layers.Layer и tf.keras.Model автоматически отслеживают переменные, назначенные их атрибутам. В следующем примере строится простая линейная модель, а затем записываются контрольные точки, которые содержат значения для всех переменных модели.

Вы можете легко сохранить контрольную Model.save_weights модели с помощью Model.save_weights .

Ручная установка контрольных точек

Настроить

Чтобы продемонстрировать все возможности tf.train.Checkpoint , определите набор данных игрушки и шаг оптимизации:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Создайте объекты контрольной точки

Используйте объект tf.train.Checkpoint чтобы вручную создать контрольную точку, где объекты, которые вы хотите проверить, устанавливаются как атрибуты объекта.

tf.train.CheckpointManager также может быть полезен для управления несколькими контрольными tf.train.CheckpointManager .

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Обучите и проверьте модель

Следующий цикл обучения создает экземпляр модели и оптимизатора, а затем собирает их в объект tf.train.Checkpoint . Он вызывает шаг обучения в цикле для каждого пакета данных и периодически записывает контрольные точки на диск.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.42
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.81
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.74

Восстановить и продолжить обучение

После первого цикла обучения вы можете пройти новую модель и нового менеджера, но продолжить обучение с того места, где вы остановились:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.87
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.71
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.46
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.21

Объект tf.train.CheckpointManager удаляет старые контрольные точки. Выше он настроен на сохранение только трех последних контрольных точек.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Эти пути, например './tf_ckpts/ckpt-10' , не являются файлами на диске. Вместо этого они представляют собой префиксы для index файла и одного или нескольких файлов данных, содержащих значения переменных. Эти префиксы сгруппированы в один файл checkpoint ( './tf_ckpts/checkpoint' ), где CheckpointManager сохраняет свое состояние.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Механика загрузки

TensorFlow сопоставляет переменные со значениями контрольных точек, просматривая ориентированный граф с именованными ребрами, начиная с загружаемого объекта. Имена self.l1 = tf.keras.layers.Dense(5) обычно происходят от имен атрибутов в объектах, например "l1" в self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint использует свои имена аргументов ключевого слова, как в "step" в tf.train.Checkpoint(step=...) .

График зависимостей из приведенного выше примера выглядит так:

Визуализация графа зависимостей для примера обучающего цикла

Оптимизатор отображается красным цветом, обычные переменные - синим, а переменные слота оптимизатора - оранжевым. Другие узлы - например, представляющие tf.train.Checkpoint - tf.train.Checkpoint черным цветом.

Переменные слота являются частью состояния оптимизатора, но создаются для определенной переменной. Например, рёбра 'm' выше соответствуют импульсу, который оптимизатор Adam отслеживает для каждой переменной. Переменные слота сохраняются в контрольной точке только в том случае, если и переменная, и оптимизатор будут сохранены, то есть пунктирные края.

Вызов restore на через tf.train.Checkpoint объекты очередей запрошенных реставрации, восстановление значения переменного , как только есть соответствующий путь от Checkpoint объекта. Например, вы можете загрузить только смещение из модели, которую вы определили выше, реконструировав один путь к ней через сеть и слой.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.831489  3.7156947 2.5892444 3.8669944 4.749503 ]

Граф зависимостей для этих новых объектов - это гораздо меньший подграф более крупной контрольной точки, которую вы написали выше. Он включает только смещение и счетчик сохранения, которые tf.train.Checkpoint использует для нумерации контрольных точек.

Визуализация подграфа для переменной смещения

restore возвращает объект статуса, который имеет необязательные утверждения. Все объекты, созданные в новой Checkpoint status.assert_existing_objects_matched , были восстановлены, поэтому status.assert_existing_objects_matched проходит.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>

В контрольной точке есть много объектов, которые не совпадают, включая ядро ​​уровня и переменные оптимизатора. status.assert_consumed проходит, только если контрольная точка и программа точно совпадают, и вызовет здесь исключение.

Отсроченные реставрации

Объекты Layer в TensorFlow могут отложить создание переменных до их первого вызова, когда доступны входные формы. Например, форма ядра Dense слоя зависит как от входных, так и выходных форм слоя, поэтому выходная форма, требуемая в качестве аргумента конструктора, не является достаточной информацией для создания переменной самостоятельно. Поскольку при вызове Layer также считывается значение переменной, восстановление должно происходить между созданием переменной и ее первым использованием.

Чтобы поддерживать эту идиому, очереди tf.train.Checkpoint восстанавливают, для которых еще нет соответствующей переменной.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5719748 4.6099544 4.931875  4.836442  4.8496275]]

Осмотр контрольных точек вручную

tf.train.load_checkpoint возвращает CheckpointReader который предоставляет доступ более низкого уровня к содержимому контрольной точки. Он содержит сопоставления ключа каждой переменной с формой и dtype для каждой переменной в контрольной точке. Ключ переменной - это путь к объекту, как на графиках выше.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Итак, если вас интересует значение net.l1.kernel вы можете получить значение с помощью следующего кода:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Он также предоставляет метод get_tensor позволяющий проверять значение переменной:

reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]],
      dtype=float32)

Отслеживание списков и словарей

Как и при прямом назначении атрибутов, например self.l1 = tf.keras.layers.Dense(5) , назначение списков и словарей атрибутам будет отслеживать их содержимое.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Вы можете заметить объекты-оболочки для списков и словарей. Эти оболочки являются версиями базовых структур данных с возможностью проверки. Как и при загрузке на основе атрибутов, эти оболочки восстанавливают значение переменной, как только она добавляется в контейнер.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Такое же отслеживание автоматически применяется к подклассам tf.keras.Model и может использоваться, например, для отслеживания списков слоев.

Резюме

Объекты TensorFlow предоставляют простой автоматический механизм для сохранения и восстановления значений переменных, которые они используют.