Pontos de verificação de treinamento

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

A frase "Salvando um modelo do TensorFlow" normalmente significa uma de duas coisas:

  1. Pontos de verificação, OU
  2. SavedModel.

Checkpoints capturar o valor exato de todos os parâmetros ( tf.Variable objetos) usadas por um modelo. Os pontos de verificação não contêm nenhuma descrição do cálculo definido pelo modelo e, portanto, normalmente só são úteis quando o código-fonte que usará os valores de parâmetro salvos está disponível.

O formato SavedModel, por outro lado, inclui uma descrição serializada do cálculo definido pelo modelo, além dos valores de parâmetro (ponto de verificação). Os modelos neste formato são independentes do código-fonte que criou o modelo. Portanto, eles são adequados para implantação por meio do TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (C, C ++, Java, Go, Rust, C # etc. APIs do TensorFlow).

Este guia cobre APIs para escrever e ler pontos de verificação.

Configurar

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Salvando a partir tf.keras APIs de treinamento

Veja as tf.keras orientar sobre salvar e restaurar.

tf.keras.Model.save_weights poupa um posto de controle TensorFlow.

net.save_weights('easy_checkpoint')

Escrevendo pontos de verificação

O estado persistente de um modelo TensorFlow é armazenado em tf.Variable objetos. Estes podem ser construídos diretamente, mas são muitas vezes criados por meio de APIs de alto nível como tf.keras.layers ou tf.keras.Model .

A maneira mais fácil de gerenciar variáveis ​​é anexando-as a objetos Python e, em seguida, referenciando esses objetos.

Subclasses de tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model rastrear automaticamente variáveis atribuídas aos seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava pontos de verificação que contêm valores para todas as variáveis ​​do modelo.

Você pode facilmente salvar um modelo de ponto de verificação com Model.save_weights .

Ponto de verificação manual

Configurar

Para ajudar a demonstrar todas as características de tf.train.Checkpoint , definir um conjunto de dados de brinquedo e etapa de otimização:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crie os objetos de checkpoint

Use um tf.train.Checkpoint objeto para criar manualmente um posto de controle, onde os objetos que deseja checkpoint são definidos como atributos de objeto.

A tf.train.CheckpointManager também pode ser útil para gerenciar vários postos de controle.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Treinar e verificar o modelo

O ciclo seguinte formação cria uma instância do modelo e de um otimizador, em seguida, reúne-los em um tf.train.Checkpoint objeto. Ele chama a etapa de treinamento em um loop em cada lote de dados e grava pontos de verificação no disco periodicamente.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.77
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.18
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.62
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.16
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.09

Restaurar e continuar o treinamento

Após o primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas retome o treinamento exatamente de onde parou:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.33
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.90
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.62
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.27
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

O tf.train.CheckpointManager objeto exclui checkpoints idade. Acima está configurado para manter apenas os três pontos de verificação mais recentes.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Esses caminhos, por exemplo './tf_ckpts/ckpt-10' , não são arquivos no disco. Em vez disso, são prefixos para um index arquivo e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados em um único checkpoint de arquivos ( './tf_ckpts/checkpoint' ) onde o CheckpointManager salva seu estado.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Carregando mecânica

O TensorFlow combina variáveis ​​com valores de checkpoint, percorrendo um gráfico direcionado com bordas nomeadas, começando a partir do objeto que está sendo carregado. Nomes de Borda normalmente vêm de nomes de atributos em objetos, por exemplo, o "l1" no self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint usa seus nomes de argumentos de palavra-chave, como no "step" na tf.train.Checkpoint(step=...) .

O gráfico de dependência do exemplo acima se parece com isto:

Visualização do gráfico de dependência para o loop de treinamento de exemplo

O otimizador está em vermelho, as variáveis ​​regulares estão em azul e as variáveis ​​de slot do otimizador estão em laranja. Os outros nós, por exemplo, que representam o tf.train.Checkpoint -são em preto.

As variáveis ​​de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as 'm' bordas acima correspondem a impulso, que as faixas optimizer Adam para cada variável. As variáveis ​​de slot são salvas em um ponto de verificação apenas se a variável e o otimizador forem salvos, portanto, as bordas tracejadas.

Chamando restore em um tf.train.Checkpoint objeto filas As restaurações solicitados, restaurando valores de variáveis, logo que há um caminho correspondente do Checkpoint objeto. Por exemplo, você pode carregar apenas o viés do modelo definido acima, reconstruindo um caminho para ele através da rede e da camada.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9851578 3.6375327 2.9331083 3.8130412 4.778274 ]

O gráfico de dependência para esses novos objetos é um subgráfico muito menor do maior ponto de verificação que você escreveu acima. Ele inclui apenas o viés e uma economia de contador que tf.train.Checkpoint usa para checkpoints numéricas.

Visualização de um subgráfico para a variável de polarização

restore retorna um objeto de status, que tem afirmações opcionais. Todos os objetos criados no novo Checkpoint foram restaurados, assim status.assert_existing_objects_matched passes.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f027807e050>

Existem muitos objetos no ponto de verificação que não corresponderam, incluindo o kernel da camada e as variáveis ​​do otimizador. status.assert_consumed passa apenas se o ponto de verificação eo jogo programa exatamente, e iria lançar uma exceção aqui.

Restaurações atrasadas

Layer objetos em TensorFlow pode atrasar a criação de variáveis para a sua primeira convocação, quando as formas de entrada estão disponíveis. Por exemplo, a forma de um Dense núcleo da camada depende das duas formas de entrada e saída da camada, e assim a forma de saída necessária como um argumento construtor não é informação suficiente para criar a variável por conta própria. Desde chamar uma Layer também lê o valor da variável, uma restauração deve acontecer entre a criação da variável e seu primeiro uso.

Para suportar este idioma, tf.train.Checkpoint filas restaurações que ainda não têm uma variável correspondente.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6800494 4.607369  4.8321466 4.816245  4.8435326]]

Inspecionando manualmente os pontos de verificação

tf.train.load_checkpoint retorna um CheckpointReader que dá acesso nível inferior ao conteúdo de ponto de verificação. Ele contém mapeamentos da chave de cada variável, para a forma e o tipo de cada variável no ponto de verificação. A chave de uma variável é seu caminho de objeto, como nos gráficos exibidos acima.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Então, se você está interessado no valor de net.l1.kernel você pode obter o valor com o seguinte código:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ele também fornece uma get_tensor método que lhe permite inspecionar o valor de uma variável:

reader.get_tensor(key)
array([[4.6800494, 4.607369 , 4.8321466, 4.816245 , 4.8435326]],
      dtype=float32)

Rastreamento de lista e dicionário

Tal como acontece com as atribuições directos de atributos como self.l1 = tf.keras.layers.Dense(5) , a atribuição de listas e dicionários de atributos irá rastrear os seus conteúdos.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Você pode notar objetos de invólucro para listas e dicionários. Esses wrappers são versões verificáveis ​​das estruturas de dados subjacentes. Assim como o carregamento baseado em atributo, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao contêiner.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

O mesmo acompanhamento é aplicada automaticamente às subclasses de tf.keras.Model , e pode ser usado por exemplo, para rastrear listas de camadas.

Resumo

Os objetos do TensorFlow fornecem um mecanismo automático fácil para salvar e restaurar os valores das variáveis ​​que usam.