Eğitim kontrol noktaları

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

"Bir TensorFlow modelini kaydetme" ifadesi tipik olarak iki şeyden biri anlamına gelir:

  1. Kontrol noktaları, VEYA
  2. Kayıtlı Model.

Kontrol noktaları, bir model tarafından kullanılan tüm parametrelerin ( tf.Variable nesneleri) tam değerini yakalar. Kontrol noktaları, model tarafından tanımlanan hesaplamanın herhangi bir tanımını içermez ve bu nedenle, tipik olarak yalnızca kaydedilen parametre değerlerini kullanacak kaynak kodu mevcut olduğunda faydalıdır.

SavedModel formatı ise parametre değerlerine (kontrol noktası) ek olarak model tarafından tanımlanan hesaplamanın serileştirilmiş bir açıklamasını içerir. Bu formattaki modeller, modeli oluşturan kaynak koddan bağımsızdır. Bu nedenle TensorFlow Serving, TensorFlow Lite, TensorFlow.js veya diğer programlama dillerindeki programlar (C, C++, Java, Go, Rust, C# vb. TensorFlow API'leri) aracılığıyla dağıtım için uygundurlar.

Bu kılavuz, kontrol noktaları yazmak ve okumak için API'leri kapsar.

Kurmak

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
-yer tutucu2 l10n-yer
net = Net()

tf.keras eğitim API'lerinden kaydetme

Kaydetme ve geri yükleme ile ilgili tf.keras kılavuzuna bakın.

tf.keras.Model.save_weights , bir TensorFlow kontrol noktası kaydeder.

net.save_weights('easy_checkpoint')

Kontrol noktaları yazma

Bir TensorFlow modelinin kalıcı durumu, tf.Variable nesnelerinde depolanır. Bunlar doğrudan oluşturulabilir, ancak genellikle tf.keras.layers veya tf.keras.Model gibi üst düzey API'ler aracılığıyla oluşturulur.

Değişkenleri yönetmenin en kolay yolu, onları Python nesnelerine eklemek ve ardından bu nesnelere referans vermektir.

tf.train.Checkpoint , tf.keras.layers.Layer ve tf.keras.Model alt sınıfları, özniteliklerine atanan değişkenleri otomatik olarak izler. Aşağıdaki örnek, basit bir doğrusal model oluşturur, ardından modelin tüm değişkenleri için değerler içeren kontrol noktaları yazar.

Model.save_weights ile bir model kontrol noktasını kolayca kaydedebilirsiniz.

Manuel kontrol noktası

Kurmak

tf.train.Checkpoint tüm özelliklerini göstermeye yardımcı olmak için bir oyuncak veri kümesi ve optimizasyon adımı tanımlayın:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
tutucu5 l10n-yer
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Kontrol noktası nesnelerini oluşturun

Kontrol noktası oluşturmak istediğiniz nesnelerin nesne üzerinde nitelikler olarak ayarlandığı bir kontrol noktası oluşturmak için bir tf.train.Checkpoint nesnesi kullanın.

Bir tf.train.CheckpointManager , birden çok kontrol noktasını yönetmek için de yardımcı olabilir.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Modeli eğitin ve kontrol edin

Aşağıdaki eğitim döngüsü, modelin ve optimize edicinin bir örneğini oluşturur ve ardından bunları bir tf.train.Checkpoint nesnesinde toplar. Her veri grubu üzerinde eğitim adımını bir döngü içinde çağırır ve diske periyodik olarak kontrol noktaları yazar.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
-yer tutucu9 l10n-yer
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Eğitimi geri yükle ve devam et

İlk eğitim döngüsünden sonra yeni bir model ve yöneticiye geçebilirsiniz, ancak eğitime tam olarak kaldığınız yerden devam edebilirsiniz:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
tutucu11 l10n-yer
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

tf.train.CheckpointManager nesnesi eski kontrol noktalarını siler. Yukarıda, yalnızca en son üç kontrol noktasını tutacak şekilde yapılandırılmıştır.

print(manager.checkpoints)  # List the three remaining checkpoints
tutucu13 l10n-yer
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Bu yollar, örneğin './tf_ckpts/ckpt-10' diskteki dosyalar değildir. Bunun yerine, bir index dosyası ve değişken değerlerini içeren bir veya daha fazla veri dosyası için öneklerdir. Bu önekler, CheckpointManager durumunu kaydettiği tek bir checkpoint dosyasında ( './tf_ckpts/checkpoint' ) birlikte gruplanır.

ls ./tf_ckpts
tutucu15 l10n-yer
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

yükleme mekaniği

TensorFlow, yüklenen nesneden başlayarak adlandırılmış kenarlara sahip yönlendirilmiş bir grafiğin üzerinden geçerek değişkenleri kontrol noktası değerleriyle eşleştirir. Kenar adları tipik olarak nesnelerdeki öznitelik adlarından gelir, örneğin self.l1 = tf.keras.layers.Dense(5) içindeki "l1" . tf.train.Checkpoint , tf.train.Checkpoint(step=...) içindeki "step" da olduğu gibi anahtar kelime bağımsız değişken adlarını kullanır.

Yukarıdaki örnekteki bağımlılık grafiği şöyle görünür:

Örnek eğitim döngüsü için bağımlılık grafiğinin görselleştirilmesi

Optimize edici kırmızı, normal değişkenler mavi ve optimize edici slot değişkenleri turuncu. Diğer düğümler (örneğin, tf.train.Checkpoint temsil eden) siyah renktedir.

Yuva değişkenleri, optimize edicinin durumunun bir parçasıdır, ancak belirli bir değişken için oluşturulur. Örneğin, yukarıdaki 'm' kenarları, Adam optimize edicinin her değişken için izlediği momentuma karşılık gelir. Yuva değişkenleri, yalnızca değişken ve optimize edicinin her ikisi de, dolayısıyla kesikli kenarlar kaydedilecekse bir kontrol noktasına kaydedilir.

Bir tf.train.Checkpoint nesnesindeki restore çağırmak, Checkpoint nesnesinden eşleşen bir yol bulunur bulunmaz değişken değerleri geri yükleyerek, istenen geri yüklemeleri sıraya alır. Örneğin, ağ ve katman üzerinden bir yolu yeniden yapılandırarak yukarıda tanımladığınız modelden yalnızca sapmayı yükleyebilirsiniz.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
tutucu17 l10n-yer
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

Bu yeni nesneler için bağımlılık grafiği, yukarıda yazdığınız daha büyük kontrol noktasının çok daha küçük bir alt grafiğidir. Yalnızca önyargıyı ve tf.train.Checkpoint kontrol noktalarını numaralandırmak için kullandığı bir kaydetme sayacını içerir.

Önyargı değişkeni için bir alt grafiğin görselleştirilmesi

restore , isteğe bağlı iddiaları olan bir durum nesnesi döndürür. Yeni Checkpoint oluşturulan tüm nesneler geri yüklendi, bu nedenle status.assert_existing_objects_matched .

status.assert_existing_objects_matched()
tutucu19 l10n-yer
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

Kontrol noktasında, katmanın çekirdeği ve optimize edicinin değişkenleri dahil, eşleşmeyen birçok nesne var. status.assert_consumed yalnızca kontrol noktası ve program tam olarak eşleşirse geçer ve burada bir istisna atar.

Ertelenmiş restorasyonlar

TensorFlow'daki Layer nesneleri, girdi şekilleri mevcut olduğunda, değişkenlerin oluşturulmasını ilk çağrılarına erteleyebilir. Örneğin, bir Dense katmanın çekirdeğinin şekli, katmanın hem girdi hem de çıktı şekillerine bağlıdır ve bu nedenle bir yapıcı argümanı olarak gereken çıktı şekli, değişkeni kendi başına oluşturmak için yeterli bilgi değildir. Bir Layer çağırmak aynı zamanda değişkenin değerini de okuduğundan, değişkenin oluşturulması ile ilk kullanımı arasında bir geri yükleme olması gerekir.

Bu deyimi desteklemek için tf.train.Checkpoint , henüz eşleşen bir değişkene sahip olmayan geri yüklemeleri erteler.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
tutucu21 l10n-yer
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Kontrol noktalarını manuel olarak denetleme

tf.train.load_checkpoint , kontrol noktası içeriğine daha düşük düzeyde erişim sağlayan bir CheckpointReader döndürür. Her bir değişkenin anahtarından, kontrol noktasındaki her bir değişkenin şekline ve tipine eşlemeler içerir. Bir değişkenin anahtarı, yukarıda gösterilen grafiklerde olduğu gibi nesne yoludur.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
tutucu23 l10n-yer
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Yani net.l1.kernel değeriyle ilgileniyorsanız, değeri aşağıdaki kodla alabilirsiniz:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
tutucu25 l10n-yer
Shape: [1, 5]
Dtype: float32

Ayrıca bir değişkenin değerini incelemenize izin veren bir get_tensor yöntemi sağlar:

reader.get_tensor(key)
tutucu27 l10n-yer
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

nesne izleme

Kontrol noktaları, özniteliklerinden birinde ayarlanmış herhangi bir değişkeni veya izlenebilir nesneyi "izleyerek" tf.Variable nesnelerinin değerlerini kaydeder ve geri yükler. Bir kaydetme yürütülürken, değişkenler, erişilebilir tüm izlenen nesnelerden özyinelemeli olarak toplanır.

self.l1 = tf.keras.layers.Dense(5) gibi doğrudan nitelik atamalarında olduğu gibi, niteliklere listeler ve sözlükler atamak onların içeriklerini izleyecektir.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Listeler ve sözlükler için sarmalayıcı nesneler görebilirsiniz. Bu sarmalayıcılar, temel alınan veri yapılarının denetlenebilir sürümleridir. Öznitelik tabanlı yükleme gibi, bu sarmalayıcılar, kapsayıcıya eklenir eklenmez bir değişkenin değerini geri yükler.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
tutucu30 l10n-yer
ListWrapper([])

İzlenebilir nesneler arasında tf.train.Checkpoint , tf.Module ve alt sınıfları (örneğin keras.layers.Layer ve keras.Model ) ve tanınan Python kapsayıcıları bulunur:

  • dict (ve collections.OrderedDict )
  • list
  • tuple (ve collections.namedtuple , typing.NamedTuple )

Aşağıdakiler dahil diğer kapsayıcı türleri desteklenmez :

  • collections.defaultdict
  • set

Aşağıdakiler dahil tüm diğer Python nesneleri yoksayılır :

  • int
  • string
  • float

Özet

TensorFlow nesneleri, kullandıkları değişkenlerin değerlerini kaydetmek ve geri yüklemek için kolay bir otomatik mekanizma sağlar.