Pontos de verificação de treinamento

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

A frase "Salvando um modelo do TensorFlow" normalmente significa uma de duas coisas:

  1. Pontos de verificação, OU
  2. SavedModel.

Os pontos de verificação capturam o valor exato de todos os parâmetros (objetos tf.Variable ) usados ​​por um modelo. Os pontos de verificação não contêm nenhuma descrição do cálculo definido pelo modelo e, portanto, normalmente só são úteis quando o código-fonte que usará os valores de parâmetro salvos está disponível.

O formato SavedModel, por outro lado, inclui uma descrição serializada do cálculo definido pelo modelo, além dos valores de parâmetro (ponto de verificação). Os modelos neste formato são independentes do código-fonte que criou o modelo. Portanto, eles são adequados para implantação via TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (C, C ++, Java, Go, Rust, C # etc. APIs TensorFlow).

Este guia cobre APIs para escrever e ler pontos de verificação.

Configurar

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Salvando de APIs de treinamento tf.keras

Consulte o guia tf.keras sobre como salvar e restaurar.

tf.keras.Model.save_weights salva um ponto de verificação do TensorFlow.

net.save_weights('easy_checkpoint')

Escrevendo pontos de verificação

O estado persistente de um modelo do TensorFlow é armazenado em objetos tf.Variable . Eles podem ser construídos diretamente, mas geralmente são criados por meio de APIs de alto nível, comotf.keras.layers ou tf.keras.Model .

A maneira mais fácil de gerenciar variáveis ​​é anexando-as a objetos Python e, em seguida, referenciando esses objetos.

As subclasses de tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model rastreiam automaticamente as variáveis ​​atribuídas a seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava pontos de verificação que contêm valores para todas as variáveis ​​do modelo.

Você pode salvar facilmente um ponto de verificação de modelo com Model.save_weights .

Ponto de verificação manual

Configurar

Para ajudar a demonstrar todos os recursos de tf.train.Checkpoint , defina um conjunto de dados de brinquedo e uma etapa de otimização:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crie os objetos de ponto de verificação

Use um objeto tf.train.Checkpoint para criar manualmente um ponto de verificação, onde os objetos que você deseja marcar são definidos como atributos no objeto.

Um tf.train.CheckpointManager também pode ser útil para gerenciar vários pontos de verificação.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Treinar e verificar o modelo

O seguinte loop de treinamento cria uma instância do modelo e de um otimizador e os reúne em um objeto tf.train.Checkpoint . Ele chama a etapa de treinamento em um loop em cada lote de dados e grava pontos de verificação no disco periodicamente.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Restaurar e continuar o treinamento

Após o primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas retome o treinamento exatamente de onde parou:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

O objeto tf.train.CheckpointManager exclui pontos de verificação antigos. Acima está configurado para manter apenas os três pontos de verificação mais recentes.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Esses caminhos, por exemplo, './tf_ckpts/ckpt-10' , não são arquivos no disco. Em vez disso, eles são prefixos para um arquivo de index e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados em um único arquivo de checkpoint ( './tf_ckpts/checkpoint' ) onde o CheckpointManager salva seu estado.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Carregando mecânica

O TensorFlow combina variáveis ​​com valores de checkpoint, percorrendo um gráfico direcionado com bordas nomeadas, começando a partir do objeto que está sendo carregado. Os nomes de borda normalmente vêm de nomes de atributos em objetos, por exemplo, o "l1" em self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint usa seus nomes de argumento de palavra-chave, como na "step" em tf.train.Checkpoint(step=...) .

O gráfico de dependência do exemplo acima se parece com isto:

Visualização do gráfico de dependência para o exemplo de loop de treinamento

O otimizador está em vermelho, as variáveis ​​regulares estão em azul e as variáveis ​​de slot do otimizador estão em laranja. Os outros nós - por exemplo, representando o tf.train.Checkpoint - estão em preto.

As variáveis ​​de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as arestas 'm' acima correspondem ao momento, que o otimizador de Adam rastreia para cada variável. As variáveis ​​de slot são salvas em um ponto de verificação apenas se a variável e o otimizador forem salvos, portanto, as bordas tracejadas.

Chamar restore em um objeto tf.train.Checkpoint enfileira as restaurações solicitadas, restaurando valores de variáveis ​​assim que houver um caminho correspondente do objeto Checkpoint . Por exemplo, você pode carregar apenas o bias do modelo definido acima, reconstruindo um caminho para ele através da rede e da camada.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

O gráfico de dependência para esses novos objetos é um subgráfico muito menor do maior ponto de verificação que você escreveu acima. Inclui apenas a polarização e um contador de salvamento que tf.train.Checkpoint usa para numerar os pontos de verificação.

Visualização de um subgráfico para a variável de polarização

restore retorna um objeto de status, que possui asserções opcionais. Todos os objetos criados no novo Checkpoint foram restaurados, então status.assert_existing_objects_matched aprovado.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

Existem muitos objetos no ponto de verificação que não corresponderam, incluindo o kernel da camada e as variáveis ​​do otimizador. status.assert_consumed só passa se o ponto de verificação e o programa correspondem exatamente, e lançaria uma exceção aqui.

Restaurações atrasadas

Objetos de Layer no TensorFlow podem atrasar a criação de variáveis ​​para sua primeira chamada, quando as formas de entrada estão disponíveis. Por exemplo, a forma do kernel de uma camada Dense depende das formas de entrada e saída da camada e, portanto, a forma de saída exigida como um argumento do construtor não é informação suficiente para criar a variável por conta própria. Visto que chamar uma Layer também lê o valor da variável, uma restauração deve acontecer entre a criação da variável e seu primeiro uso.

Para oferecer suporte a esse idioma, tf.train.Checkpoint enfileira restaurações que ainda não têm uma variável correspondente.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Inspecionando manualmente os pontos de verificação

tf.train.load_checkpoint retorna um CheckpointReader que fornece acesso de nível inferior ao conteúdo do ponto de verificação. Ele contém mapeamentos da chave de cada variável, para a forma e o tipo de cada variável no ponto de verificação. A chave de uma variável é seu caminho de objeto, como nos gráficos exibidos acima.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Portanto, se você estiver interessado no valor de net.l1.kernel poderá obter o valor com o seguinte código:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ele também fornece um método get_tensor que permite inspecionar o valor de uma variável:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Rastreamento de lista e dicionário

Assim como acontece com atribuições diretas de atributos como self.l1 = tf.keras.layers.Dense(5) , atribuir listas e dicionários a atributos rastreará seus conteúdos.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Você pode notar objetos de invólucro para listas e dicionários. Esses wrappers são versões verificáveis ​​das estruturas de dados subjacentes. Assim como o carregamento baseado em atributo, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao contêiner.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

O mesmo rastreamento é aplicado automaticamente às subclasses de tf.keras.Model e pode ser usado, por exemplo, para rastrear listas de camadas.

Resumo

Os objetos do TensorFlow fornecem um mecanismo automático fácil para salvar e restaurar os valores das variáveis ​​que usam.