Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Treningowe punkty kontrolne

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Wyrażenie „Zapisywanie modelu TensorFlow” zazwyczaj oznacza jedną z dwóch rzeczy:

  1. Punkty kontrolne LUB
  2. SavedModel.

Punkty kontrolne przechwytują dokładną wartość wszystkich parametrów (obiektów tf.Variable ) używanych przez model. Punkty kontrolne nie zawierają żadnego opisu obliczeń zdefiniowanych przez model i dlatego są zwykle przydatne tylko wtedy, gdy dostępny jest kod źródłowy, który będzie używał zapisanych wartości parametrów.

Z drugiej strony format SavedModel zawiera serializowany opis obliczeń zdefiniowanych przez model oprócz wartości parametrów (punkt kontrolny). Modele w tym formacie są niezależne od kodu źródłowego, który utworzył model. Są zatem odpowiednie do wdrażania za pośrednictwem usług TensorFlow Serving, TensorFlow Lite, TensorFlow.js lub programów w innych językach programowania (C, C ++, Java, Go, Rust, C # itp. API TensorFlow).

Ten przewodnik obejmuje interfejsy API do pisania i odczytywania punktów kontrolnych.

Ustawiać

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Zapisywanie z tf.keras szkoleniowych interfejsów API

Zobacz przewodnik tf.keras dotyczący zapisywania i przywracania.

tf.keras.Model.save_weights zapisuje punkt kontrolny TensorFlow.

net.save_weights('easy_checkpoint')

Pisanie punktów kontrolnych

Trwały stan modelu TensorFlow jest przechowywany w obiektach tf.Variable . Można je konstruować bezpośrednio, ale często są tworzone za pomocą interfejsów API wysokiego poziomu, takich jak tf.keras.layers lub tf.keras.Model .

Najłatwiejszym sposobem zarządzania zmiennymi jest dołączanie ich do obiektów Pythona, a następnie odwoływanie się do tych obiektów.

Podklasy tf.train.Checkpoint , tf.keras.layers.Layer i tf.keras.Model automatycznie śledzą zmienne przypisane do ich atrybutów. Poniższy przykład konstruuje prosty model liniowy, a następnie zapisuje punkty kontrolne, które zawierają wartości wszystkich zmiennych modelu.

Możesz łatwo zapisać punkt kontrolny modelu za pomocą Model.save_weights

Ręczne punkty kontrolne

Ustawiać

Aby pomóc zademonstrować wszystkie funkcje tf.train.Checkpoint zdefiniuj zestaw danych zabawki i krok optymalizacji:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Utwórz obiekty punktów kontrolnych

Aby ręcznie utworzyć punkt kontrolny, potrzebujesz obiektu tf.train.Checkpoint . Gdzie obiekty, które chcesz sprawdzić, są ustawione jako atrybuty obiektu.

tf.train.CheckpointManager może być również pomocny w zarządzaniu wieloma punktami kontrolnymi.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Wytrenuj i sprawdzaj model

Poniższa pętla szkoleniowa tworzy instancję modelu i optymalizatora, a następnie gromadzi je w obiekcie tf.train.Checkpoint . Wywołuje krok szkoleniowy w pętli dla każdej partii danych i okresowo zapisuje punkty kontrolne na dysku.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 32.40
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 25.82
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 19.26
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 12.77
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 6.48

Przywróć i kontynuuj szkolenie

Po pierwszym możesz przejść nowy model i menedżera, ale szkolenie na odbiór dokładnie tam, gdzie skończyłeś:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.88
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.44
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.41
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

Obiekt tf.train.CheckpointManager usuwa stare punkty kontrolne. Powyżej jest skonfigurowany do przechowywania tylko trzech ostatnich punktów kontrolnych.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Te ścieżki, np. './tf_ckpts/ckpt-10' , nie są plikami na dysku. Zamiast tego są przedrostkami dla pliku index i jednego lub więcej plików danych, które zawierają wartości zmiennych. Te przedrostki są zgrupowane razem w jednym pliku checkpoint ( './tf_ckpts/checkpoint' ), w którym CheckpointManager zapisuje swój stan.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mechanika ładowania

TensorFlow dopasowuje zmienne do wartości punktów kontrolnych, przechodząc przez skierowany wykres z nazwanymi krawędziami, zaczynając od ładowanego obiektu. Nazwy krawędzi zwykle pochodzą z nazw atrybutów w obiektach, na przykład "l1" w self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint używa swoich nazw argumentów słów kluczowych, jak w "step" w tf.train.Checkpoint(step=...) .

Wykres zależności z powyższego przykładu wygląda następująco:

Wizualizacja wykresu zależności dla przykładowej pętli szkoleniowej

Z optymalizatorem na czerwono, zwykłymi zmiennymi na niebiesko i zmiennymi optymalizatora na pomarańczowo. Pozostałe węzły, na przykład reprezentujące tf.train.Checkpoint , są czarne.

Zmienne boksów są częścią stanu optymalizatora, ale są tworzone dla określonej zmiennej. Na przykład krawędzie 'm' powyżej odpowiadają pędowi, który optymalizator Adama śledzi dla każdej zmiennej. Zmienne szczeliny są zapisywane w punkcie kontrolnym tylko wtedy, gdy zarówno zmienna, jak i optymalizator zostałyby zapisane, stąd przerywane krawędzie.

Wywołanie funkcji restore() w obiekcie tf.train.Checkpoint kolejkuje żądane przywrócenia, przywracając wartości zmiennych, gdy tylko znajdzie się pasująca ścieżka z obiektu Checkpoint . Na przykład możesz załadować tylko odchylenie z modelu, który zdefiniowaliśmy powyżej, rekonstruując jedną ścieżkę do niego przez sieć i warstwę.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[3.4461102 3.030825  4.4315968 3.5077076 4.7258596]

Wykres zależności dla tych nowych obiektów jest znacznie mniejszym podgrafem większego punktu kontrolnego, który napisaliśmy powyżej. Zawiera tylko odchylenie i licznik tf.train.Checkpoint które tf.train.Checkpoint używa do numerowania punktów kontrolnych.

Wizualizacja podgrafu dla zmiennej odchylenia

restore() zwraca obiekt stanu, który ma opcjonalne potwierdzenia. Wszystkie obiekty utworzone w nowym status.assert_existing_objects_matched() Checkpoint zostały przywrócone, więc status.assert_existing_objects_matched() przechodzi.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fec144bd080>

W punkcie kontrolnym znajduje się wiele obiektów, które nie zostały dopasowane, w tym jądro warstwy i zmienne optymalizatora. status.assert_consumed() przechodzi tylko wtedy, gdy punkt kontrolny i program dokładnie pasują, i status.assert_consumed() tutaj wyjątek.

Opóźnione uzupełnienia

Obiekty Layer w TensorFlow mogą opóźniać tworzenie zmiennych do ich pierwszego wywołania, gdy dostępne są kształty wejściowe. Na przykład kształt jądra warstwy Dense zależy zarówno od kształtów wejściowych, jak i wyjściowych warstwy, więc kształt wyjściowy wymagany jako argument konstruktora nie jest wystarczającą informacją, aby samodzielnie utworzyć zmienną. Ponieważ wywołanie Layer odczytuje również wartość zmiennej, przywrócenie musi nastąpić między utworzeniem zmiennej a jej pierwszym użyciem.

Aby obsługiwać ten idiom, tf.train.Checkpoint odtwarza kolejki, które nie mają jeszcze pasującej zmiennej.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.4598393 4.677273  4.655946  4.926899  4.79748  ]]

Ręczne sprawdzanie punktów kontrolnych

tf.train.list_variables zawiera listę kluczy punktów kontrolnych i kształtów zmiennych w punkcie kontrolnym. Klucze punktów kontrolnych to ścieżki na wykresie wyświetlonym powyżej.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Śledzenie list i słowników

Podobnie jak w przypadku bezpośredniego przypisywania atrybutów, takich jak self.l1 = tf.keras.layers.Dense(5) , przypisywanie list i słowników do atrybutów będzie śledzić ich zawartość.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Możesz zauważyć obiekty opakowujące listy i słowniki. Te opakowania są wersjami bazowych struktur danych, które można kontrolować. Podobnie jak w przypadku ładowania opartego na atrybutach, te opakowania przywracają wartość zmiennej, gdy tylko zostanie ona dodana do kontenera.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

To samo śledzenie jest automatycznie stosowane do podklas tf.keras.Model i może być używane na przykład do śledzenia list warstw.

streszczenie

Obiekty TensorFlow zapewniają łatwy automatyczny mechanizm zapisywania i odtwarzania wartości używanych przez nie zmiennych.