Mam pytanie? Połącz się ze społecznością na Forum TensorFlow Odwiedź Forum

Szkoleniowe punkty kontrolne

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło w serwisie GitHub Pobierz notatnik

Wyrażenie „Zapisywanie modelu TensorFlow” zazwyczaj oznacza jedną z dwóch rzeczy:

  1. Punkty kontrolne LUB
  2. SavedModel.

Punkty kontrolne rejestrują dokładną wartość wszystkich parametrów (obiektów tf.Variable ) używanych przez model. Punkty kontrolne nie zawierają opisu obliczeń zdefiniowanych przez model i dlatego są zwykle przydatne tylko wtedy, gdy dostępny jest kod źródłowy, który będzie używał zapisanych wartości parametrów.

Z drugiej strony format SavedModel zawiera serializowany opis obliczeń zdefiniowanych przez model oprócz wartości parametrów (punkt kontrolny). Modele w tym formacie są niezależne od kodu źródłowego, który utworzył model. Są zatem odpowiednie do wdrażania za pośrednictwem usług TensorFlow Serving, TensorFlow Lite, TensorFlow.js lub programów w innych językach programowania (C, C ++, Java, Go, Rust, C # itp. API TensorFlow).

Ten przewodnik obejmuje interfejsy API do pisania i odczytywania punktów kontrolnych.

Ustawiać

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Zapisywanie z tf.keras szkoleniowych interfejsów API

Zobacz przewodnik tf.keras dotyczący zapisywania i przywracania.

tf.keras.Model.save_weights zapisuje punkt kontrolny TensorFlow.

net.save_weights('easy_checkpoint')

Pisanie punktów kontrolnych

Trwały stan modelu TensorFlow jest przechowywany w obiektach tf.Variable . Można je konstruować bezpośrednio, ale często są tworzone za pomocą interfejsów API wysokiego poziomu, takich jaktf.keras.layers lub tf.keras.Model .

Najłatwiejszym sposobem zarządzania zmiennymi jest dołączanie ich do obiektów Pythona, a następnie odwoływanie się do tych obiektów.

Podklasy tf.train.Checkpoint , tf.keras.layers.Layer i tf.keras.Model automatycznie śledzą zmienne przypisane do ich atrybutów. Poniższy przykład tworzy prosty model liniowy, a następnie zapisuje punkty kontrolne, które zawierają wartości dla wszystkich zmiennych modelu.

Możesz łatwo zapisać punkt kontrolny modelu za pomocą Model.save_weights .

Ręczne punkty kontrolne

Ustawiać

Aby pomóc zademonstrować wszystkie funkcje tf.train.Checkpoint , zdefiniuj zestaw danych zabawki i krok optymalizacji:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Utwórz obiekty punktów kontrolnych

Użyj obiektu tf.train.Checkpoint aby ręcznie utworzyć punkt kontrolny, w którym obiekty, które mają być tf.train.Checkpoint są ustawiane jako atrybuty obiektu.

tf.train.CheckpointManager może być również pomocny w zarządzaniu wieloma punktami kontrolnymi.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Wytrenuj i sprawdzaj model

Poniższa pętla szkoleniowa tworzy instancję modelu i optymalizatora, a następnie gromadzi je w obiekcie tf.train.Checkpoint . Wywołuje krok szkoleniowy w pętli dla każdej partii danych i okresowo zapisuje punkty kontrolne na dysku.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Przywróć i kontynuuj trening

Po pierwszym cyklu treningowym możesz zaliczyć nowego modela i menadżera, ale rozpocznij trening dokładnie tam, gdzie skończyłeś:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

Obiekt tf.train.CheckpointManager usuwa stare punkty kontrolne. Powyżej jest skonfigurowany tak, aby przechowywać tylko trzy najnowsze punkty kontrolne.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Te ścieżki, np. './tf_ckpts/ckpt-10' , nie są plikami na dysku. Zamiast tego są przedrostkami dla pliku index i co najmniej jednego pliku danych, który zawiera wartości zmiennych. Te przedrostki są zgrupowane razem w jednym pliku checkpoint ( './tf_ckpts/checkpoint' ), w którym CheckpointManager zapisuje swój stan.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mechanika ładowania

TensorFlow dopasowuje zmienne do wartości punktów kontrolnych, przechodząc przez skierowany wykres z nazwanymi krawędziami, zaczynając od ładowanego obiektu. Nazwy krawędzi zwykle pochodzą z nazw atrybutów w obiektach, na przykład "l1" w self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint używa swoich nazw argumentów słów kluczowych, tak jak w "step" w tf.train.Checkpoint(step=...) .

Wykres zależności z powyższego przykładu wygląda następująco:

Wizualizacja wykresu zależności dla przykładowej pętli szkoleniowej

Optymalizator jest na czerwono, zwykłe zmienne na niebiesko, a zmienne optymalizatora na pomarańczowe. Pozostałe węzły - na przykład reprezentujące tf.train.Checkpoint - są w kolorze czarnym.

Zmienne boksów są częścią stanu optymalizatora, ale są tworzone dla określonej zmiennej. Na przykład krawędzie 'm' powyżej odpowiadają pędowi, który optymalizator Adama śledzi dla każdej zmiennej. Zmienne szczeliny są zapisywane w punkcie kontrolnym tylko wtedy, gdy zarówno zmienna, jak i optymalizator zostałyby zapisane, stąd przerywane krawędzie.

Wywołanie funkcji restore na obiekcie tf.train.Checkpoint kolejkuje żądane przywrócenia, przywracając wartości zmiennych, gdy tylko znajdzie się pasująca ścieżka z obiektu Checkpoint . Na przykład możesz załadować tylko odchylenie z modelu zdefiniowanego powyżej, rekonstruując jedną ścieżkę do niego przez sieć i warstwę.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

Wykres zależności dla tych nowych obiektów jest znacznie mniejszym podgrafem większego punktu kontrolnego, który napisałeś powyżej. Zawiera tylko odchylenie i licznik tf.train.Checkpoint które tf.train.Checkpoint używa do numerowania punktów kontrolnych.

Wizualizacja podgrafu dla zmiennej odchylenia

restore zwraca obiekt stanu, który ma opcjonalne potwierdzenia. Wszystkie obiekty utworzone w nowym status.assert_existing_objects_matched Checkpoint zostały przywrócone, więc status.assert_existing_objects_matched przechodzi.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

W punkcie kontrolnym znajduje się wiele obiektów, które nie zostały dopasowane, w tym jądro warstwy i zmienne optymalizatora. status.assert_consumed przechodzi tylko wtedy, gdy punkt kontrolny i program dokładnie pasują, i status.assert_consumed tutaj wyjątek.

Opóźnione uzupełnienia

Obiekty Layer w TensorFlow mogą opóźniać tworzenie zmiennych do ich pierwszego wywołania, gdy dostępne są kształty wejściowe. Na przykład kształt jądra warstwy Dense zależy zarówno od kształtów wejściowych, jak i wyjściowych warstwy, więc kształt wyjściowy wymagany jako argument konstruktora nie jest wystarczającą informacją, aby samodzielnie utworzyć zmienną. Ponieważ wywołanie Layer odczytuje również wartość zmiennej, przywrócenie musi nastąpić między utworzeniem zmiennej a jej pierwszym użyciem.

Aby obsługiwać ten idiom, tf.train.Checkpoint odtwarza kolejki, które nie mają jeszcze pasującej zmiennej.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Ręczne sprawdzanie punktów kontrolnych

tf.train.load_checkpoint zwraca CheckpointReader który zapewnia niższy poziom dostępu do zawartości punktu kontrolnego. Zawiera odwzorowania z klucza każdej zmiennej na kształt i typ każdej zmiennej w punkcie kontrolnym. Kluczem zmiennej jest ścieżka do obiektu, jak na wykresach powyżej.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Więc jeśli interesuje Cię wartość net.l1.kernel , możesz uzyskać wartość za pomocą następującego kodu:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Udostępnia również metodę get_tensor umożliwiającą sprawdzenie wartości zmiennej:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Śledzenie list i słowników

Podobnie jak w przypadku bezpośredniego przypisywania atrybutów, takich jak self.l1 = tf.keras.layers.Dense(5) , przypisywanie list i słowników do atrybutów będzie śledzić ich zawartość.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Możesz zauważyć obiekty opakowaniowe dla list i słowników. Te opakowania są wersjami bazowych struktur danych, które można kontrolować. Podobnie jak w przypadku ładowania opartego na atrybutach, te opakowania przywracają wartość zmiennej, gdy tylko zostanie ona dodana do kontenera.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

To samo śledzenie jest automatycznie stosowane do podklas tf.keras.Model i może być używane na przykład do śledzenia list warstw.

streszczenie

Obiekty TensorFlow zapewniają łatwy automatyczny mechanizm zapisywania i przywracania wartości zmiennych, których używają.