Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

Eğitim kontrol noktaları

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

"TensorFlow modelini kaydetme" ifadesi tipik olarak iki şeyden biri anlamına gelir:

  1. Kontrol noktaları, VEYA
  2. SavedModel.

Denetim noktaları, bir model tarafından kullanılan tüm parametrelerin ( tf.Variable nesneler) tam değerini yakalar. Kontrol noktaları, model tarafından tanımlanan hesaplamanın herhangi bir açıklamasını içermez ve bu nedenle tipik olarak yalnızca kaydedilen parametre değerlerini kullanacak kaynak kodu mevcut olduğunda yararlıdır.

Öte yandan SavedModel formatı, parametre değerlerine (kontrol noktası) ek olarak model tarafından tanımlanan hesaplamanın serileştirilmiş bir açıklamasını içerir. Bu biçimdeki modeller, modeli oluşturan kaynak koddan bağımsızdır. Bu nedenle TensorFlow Serving, TensorFlow Lite, TensorFlow.js veya diğer programlama dillerindeki programlar (C, C ++, Java, Go, Rust, C # vb. TensorFlow API'leri) aracılığıyla dağıtım için uygundurlar.

Bu kılavuz, kontrol noktaları yazmak ve okumak için API'leri kapsar.

Kurmak

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras eğitim API'lerinden tasarruf

Kaydetme ve geri yükleme ile ilgili tf.keras kılavuzuna bakın.

tf.keras.Model.save_weights , bir TensorFlow denetim noktası kaydeder.

net.save_weights('easy_checkpoint')

Kontrol noktaları yazma

TensorFlow modelinin kalıcı durumu tf.Variable nesnelerinde saklanır. Bunlar doğrudan yapılandırılabilir, ancak genellikletf.keras.layers veya tf.keras.Model gibitf.keras.layers düzey API'ler aracılığıyla oluşturulur.

Değişkenleri yönetmenin en kolay yolu, onları Python nesnelerine eklemek ve ardından bu nesnelere başvurmaktır.

tf.train.Checkpoint , tf.keras.layers.Layer ve tf.keras.Model alt sınıfları, özniteliklerine atanan değişkenleri otomatik olarak izler. Aşağıdaki örnek, basit bir doğrusal model oluşturur, ardından modelin tüm değişkenleri için değerleri içeren kontrol noktaları yazar.

Model.save_weights ile bir model kontrol noktasını kolayca kaydedebilirsiniz.

Manuel kontrol noktası belirleme

Kurmak

tf.train.Checkpoint tüm özelliklerini göstermeye yardımcı olmak için bir oyuncak veri seti ve optimizasyon adımı tanımlayın:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Kontrol noktası nesnelerini oluşturun

Manuel olarak bir kontrol noktası oluşturmak için bir tf.train.Checkpoint nesnesine ihtiyacınız olacaktır. Kontrol etmek istediğiniz nesnelerin nesne üzerinde nitelikler olarak ayarlandığı yer.

Bir tf.train.CheckpointManager , birden çok denetim noktasını yönetmek için de yardımcı olabilir.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Modeli eğitin ve kontrol edin

Aşağıdaki eğitim döngüsü, modelin ve optimize edicinin bir örneğini oluşturur, ardından bunları bir tf.train.Checkpoint nesnesinde toplar. Her veri grubu üzerinde bir döngü içinde eğitim adımını çağırır ve kontrol noktalarını düzenli aralıklarla diske yazar.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 27.29
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 20.70
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 14.14
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 7.68
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 2.07

Geri yükleyin ve eğitime devam edin

İlkinden sonra yeni bir model ve menajeri geçebilirsiniz, ancak eğitimi tam olarak kaldığınız yerden devam ettirin:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.04
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.81
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.67
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

tf.train.CheckpointManager nesnesi eski denetim noktalarını siler. Yukarıda, yalnızca en son üç kontrol noktasını tutacak şekilde yapılandırılmıştır.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Bu yollar, örneğin './tf_ckpts/ckpt-10' diskteki dosyalar değildir. Bunun yerine, bir index dosyası ve değişken değerleri içeren bir veya daha fazla veri dosyası için öneklerdir. Bu önekler, CheckpointManager durumunu kaydettiği tek bir checkpoint dosyasında ( './tf_ckpts/checkpoint' ) birlikte gruplanır.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mekaniği yükleme

TensorFlow, yüklenen nesneden başlayarak, adlandırılmış kenarları olan yönlendirilmiş bir grafiğin üzerinden geçerek değişkenleri kontrol noktalı değerlerle eşleştirir. Kenar adları genellikle nesnelerdeki öznitelik adlarından gelir, örneğin self.l1 = tf.keras.layers.Dense(5) içindeki "l1" . tf.train.Checkpoint , tf.train.Checkpoint(step=...) "step" olduğu gibi kendi anahtar kelime argüman adlarını kullanır.

Yukarıdaki örnekteki bağımlılık grafiği şuna benzer:

Örnek eğitim döngüsü için bağımlılık grafiğinin görselleştirilmesi

Optimize edici kırmızı, normal değişkenler mavi ve optimize edici alan değişkenleri turuncu renkte. Örneğin tf.train.Checkpoint temsil eden diğer düğümler siyahtır.

Yuva değişkenleri, optimize edicinin durumunun bir parçasıdır, ancak belirli bir değişken için oluşturulur. Örneğin, yukarıdaki 'm' kenarları, Adam optimize edicinin her değişken için izlediği momentuma karşılık gelir. Yuva değişkenleri, yalnızca değişken ve optimize edicinin her ikisi de, dolayısıyla kesikli kenarlar kaydedilecekse bir kontrol noktasına kaydedilir.

Bir tf.train.Checkpoint nesnesinde restore() 'nin çağrılması, istenen geri tf.train.Checkpoint sıraya tf.train.Checkpoint ve Checkpoint nesnesinden eşleşen bir yol olduğu anda değişken değerlerini geri yükler. Örneğin, yukarıda tanımladığımız modelden sadece önyargıyı, ona ağ ve katman üzerinden bir yolu yeniden yapılandırarak yükleyebilirsiniz.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9204693 3.3416245 2.7418654 2.9938312 4.2179084]

Bu yeni nesneler için bağımlılık grafiği, yukarıda yazdığımız daha büyük kontrol noktasının çok daha küçük bir alt grafiğidir. Yalnızca önyargı ve tf.train.Checkpoint denetim noktalarını tf.train.Checkpoint kullandığı bir kaydetme sayacı içerir.

Sapma değişkeni için bir alt grafiğin görselleştirilmesi

restore() , isteğe bağlı iddialara sahip bir durum nesnesi döndürür. Yeni Checkpoint oluşturulan tüm nesneler geri yüklendi, bu nedenle status.assert_existing_objects_matched() geçer.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbab44668d0>

Kontrol noktasında, katmanın çekirdeği ve optimize edicinin değişkenleri de dahil olmak üzere eşleşmeyen birçok nesne vardır. status.assert_consumed() yalnızca denetim noktası ve program tam olarak eşleştiğinde geçer ve burada bir istisna atar.

Gecikmiş restorasyonlar

TensorFlow'daki Layer nesneleri, giriş şekilleri mevcut olduğunda değişkenlerin oluşturulmasını ilk çağrılarına kadar geciktirebilir. Örneğin, bir Dense katman çekirdeğinin şekli hem katmanın giriş hem de çıkış şekillerine bağlıdır ve bu nedenle yapıcı argümanı olarak gereken çıktı şekli, değişkeni kendi başına oluşturmak için yeterli bilgi değildir. Bir Layer çağırmak aynı zamanda değişkenin değerini de okuduğundan, değişkenin oluşturulması ile ilk kullanımı arasında bir geri yükleme gerçekleşmelidir.

Bu deyimi desteklemek için tf.train.Checkpoint kuyrukları, henüz eşleşen bir değişkeni olmayanları geri yükler.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6967545 4.656775  4.8799973 4.961913  4.893354 ]]

Kontrol noktalarını manuel olarak incelemek

tf.train.load_checkpoint , denetim noktası içeriğine daha düşük düzeyde erişim sağlayan bir CheckpointReader döndürür. Kontrol noktasındaki her değişken için her bir değişken anahtarından şekil ve dtype için eşlemeler içerir. Bir değişkenin anahtarı, yukarıda görüntülenen grafiklerde olduğu gibi, nesnenin yoludur.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Dolayısıyla, net.l1.kernel değeriyle ilgileniyorsanız, aşağıdaki kodla değeri elde edebilirsiniz:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ayrıca, bir değişkenin değerini incelemenizi sağlayan bir get_tensor yöntemi sağlar:

reader.get_tensor(key)
array([[4.6967545, 4.656775 , 4.8799973, 4.961913 , 4.893354 ]],
      dtype=float32)

Liste ve sözlük takibi

self.l1 = tf.keras.layers.Dense(5) gibi doğrudan öznitelik atamalarında olduğu gibi, özniteliklere liste ve sözlük atamak içeriklerini izler.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Listeler ve sözlükler için sarmalayıcı nesneleri fark edebilirsiniz. Bu sarmalayıcılar, temeldeki veri yapılarının kontrol noktalı sürümleridir. Öznitelik tabanlı yüklemede olduğu gibi, bu sarmalayıcılar bir değişkenin değerini konteynere eklenir eklenmez geri yükler.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Aynı izleme otomatik olarak tf.keras.Model alt sınıflarına uygulanır ve örneğin katman listelerini izlemek için kullanılabilir.

Özet

TensorFlow nesneleri, kullandıkları değişkenlerin değerlerini kaydetmek ve geri yüklemek için kolay bir otomatik mekanizma sağlar.