Google I / O'daki önemli notları, ürün oturumlarını, atölyeleri ve daha fazlasını izleyin Oynatma listesine bakın

Eğitim kontrol noktaları

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

"TensorFlow modelini kaydetme" ifadesi tipik olarak iki şeyden biri anlamına gelir:

  1. Kontrol noktaları, VEYA
  2. SavedModel.

Denetim noktaları, bir model tarafından kullanılan tüm parametrelerin ( tf.Variable nesneler) tam değerini yakalar. Kontrol noktaları, model tarafından tanımlanan hesaplamanın herhangi bir açıklamasını içermez ve bu nedenle tipik olarak yalnızca kaydedilen parametre değerlerini kullanacak kaynak kodu mevcut olduğunda yararlıdır.

Öte yandan SavedModel biçimi, parametre değerlerine (kontrol noktası) ek olarak model tarafından tanımlanan hesaplamanın serileştirilmiş bir açıklamasını içerir. Bu biçimdeki modeller, modeli oluşturan kaynak koddan bağımsızdır. Bu nedenle TensorFlow Serving, TensorFlow Lite, TensorFlow.js veya diğer programlama dillerindeki programlar (C, C ++, Java, Go, Rust, C # vb. TensorFlow API'leri) aracılığıyla dağıtım için uygundurlar.

Bu kılavuz, denetim noktaları yazmak ve okumak için API'leri kapsar.

Kurulum

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras eğitim API'lerinden tasarruf

Kaydetme ve geri yükleme ile ilgili tf.keras kılavuzuna bakın.

tf.keras.Model.save_weights , bir TensorFlow denetim noktası kaydeder.

net.save_weights('easy_checkpoint')

Kontrol noktaları yazma

TensorFlow modelinin kalıcı durumu tf.Variable nesnelerinde saklanır. Bunlar doğrudan yapılandırılabilir, ancak genellikletf.keras.layers veya tf.keras.Model gibitf.keras.layers düzey API'ler aracılığıyla oluşturulur.

Değişkenleri yönetmenin en kolay yolu, onları Python nesnelerine eklemek ve ardından bu nesnelere başvurmaktır.

tf.train.Checkpoint , tf.keras.layers.Layer ve tf.keras.Model alt sınıfları, özniteliklerine atanan değişkenleri otomatik olarak izler. Aşağıdaki örnek, basit bir doğrusal model oluşturur, ardından modelin tüm değişkenleri için değerleri içeren denetim noktaları yazar.

Model.save_weights ile bir model kontrol noktasını kolayca kaydedebilirsiniz.

Manuel kontrol noktası belirleme

Kurulum

tf.train.Checkpoint tüm özelliklerini göstermeye yardımcı olmak için bir oyuncak veri kümesi ve optimizasyon adımı tanımlayın:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Kontrol noktası nesnelerini oluşturun

Kontrol noktasını el ile oluşturmak için bir tf.train.Checkpoint nesnesi kullanın, burada kontrol etmek istediğiniz nesneler nesnede öznitelikler olarak ayarlanır.

Bir tf.train.CheckpointManager , birden çok denetim noktasını yönetmek için de yardımcı olabilir.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Modeli eğitin ve kontrol edin

Aşağıdaki eğitim döngüsü, modelin ve optimize edicinin bir örneğini oluşturur, ardından bunları bir tf.train.Checkpoint nesnesinde toplar. Her veri grubu üzerinde bir döngü içinde eğitim adımını çağırır ve kontrol noktalarını düzenli aralıklarla diske yazar.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

Geri yükleyin ve eğitime devam edin

İlk eğitim döngüsünden sonra yeni bir model ve yönetici geçebilirsiniz, ancak eğitimi tam olarak kaldığınız yerden devam ettirebilirsiniz:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

tf.train.CheckpointManager nesnesi eski denetim noktalarını siler. Yukarıda, yalnızca en son üç kontrol noktasını tutacak şekilde yapılandırılmıştır.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Bu yollar, örneğin './tf_ckpts/ckpt-10' diskteki dosyalar değildir. Bunun yerine, bir index dosyası ve değişken değerleri içeren bir veya daha fazla veri dosyası için öneklerdir. Bu önekler, CheckpointManager durumunu kaydettiği tek bir checkpoint dosyasında ( './tf_ckpts/checkpoint' ) birlikte gruplanır.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mekaniği yükleme

TensorFlow, yüklenen nesneden başlayarak adlandırılmış kenarları olan yönlendirilmiş bir grafiğin üzerinden geçerek değişkenleri kontrol noktalı değerlerle eşleştirir. Kenar adları tipik olarak nesnelerdeki öznitelik adlarından gelir, örneğin self.l1 = tf.keras.layers.Dense(5) içindeki "l1" . tf.train.Checkpoint , tf.train.Checkpoint(step=...) "step" olduğu gibi, anahtar kelime argüman adlarını kullanır.

Yukarıdaki örnekteki bağımlılık grafiği şuna benzer:

Örnek eğitim döngüsü için bağımlılık grafiğinin görselleştirilmesi

Optimize edici kırmızı, normal değişkenler mavi ve optimize edici alan değişkenleri turuncudur. Diğer düğümler - örneğin, tf.train.Checkpoint temsil eden - siyah renktedir.

Yuva değişkenleri, optimize edicinin durumunun bir parçasıdır, ancak belirli bir değişken için oluşturulur. Örneğin, yukarıdaki 'm' kenarları, Adam optimize edicinin her değişken için izlediği momentuma karşılık gelir. Yuva değişkenleri, yalnızca değişken ve optimize edicinin her ikisi de, dolayısıyla kesikli kenarlar kaydedilecekse bir kontrol noktasına kaydedilir.

Bir tf.train.Checkpoint nesnesinde restore tf.train.Checkpoint , istenen geri tf.train.Checkpoint sıraya alır ve Checkpoint nesnesinden eşleşen bir yol olduğu anda değişken değerlerini geri yükler. Örneğin, yukarıda tanımladığınız modelden sadece önyargıyı, ağ ve katman aracılığıyla ona giden bir yolu yeniden yapılandırarak yükleyebilirsiniz.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

Bu yeni nesneler için bağımlılık grafiği, yukarıda yazdığınız daha büyük kontrol noktasının çok daha küçük bir alt grafiğidir. Yalnızca önyargı ve tf.train.Checkpoint denetim noktalarını tf.train.Checkpoint kullandığı bir kaydetme sayacı içerir.

Sapma değişkeni için bir alt grafiğin görselleştirilmesi

restore , isteğe bağlı iddialara sahip bir durum nesnesi döndürür. Yeni Checkpoint oluşturulan tüm nesneler geri yüklendi, bu nedenle status.assert_existing_objects_matched geçer.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

Kontrol noktasında, katmanın çekirdeği ve optimize edicinin değişkenleri dahil, eşleşmeyen birçok nesne vardır. status.assert_consumed yalnızca denetim noktası ve program tam olarak status.assert_consumed geçer ve burada bir istisna atar.

Gecikmiş restorasyonlar

TensorFlow'daki Layer nesneleri, giriş şekilleri mevcut olduğunda değişkenlerin oluşturulmasını ilk çağrılarına kadar geciktirebilir. Örneğin, Dense katman çekirdeğinin şekli hem katmanın giriş hem de çıkış şekillerine bağlıdır ve bu nedenle yapıcı argümanı olarak gereken çıktı şekli, değişkeni kendi başına oluşturmak için yeterli bilgi değildir. Bir Layer çağırmak değişkenin değerini de okuduğundan, değişkenin oluşturulması ve ilk kullanımı arasında bir geri yükleme gerçekleşmelidir.

Bu deyimi desteklemek için tf.train.Checkpoint kuyrukları, henüz eşleşen bir değişkeni olmayanları geri yükler.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

Kontrol noktalarını manuel olarak incelemek

tf.train.load_checkpoint , denetim noktası içeriğine daha düşük düzeyde erişim sağlayan bir CheckpointReader döndürür. Her değişkenin anahtarından kontrol noktasındaki her değişken için şekil ve dtype için eşlemeler içerir. Bir değişkenin anahtarı, yukarıda görüntülenen grafiklerde olduğu gibi, nesnenin yoludur.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Dolayısıyla, net.l1.kernel değeriyle ilgileniyorsanız, aşağıdaki kodla değeri elde edebilirsiniz:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ayrıca bir değişkenin değerini incelemenizi sağlayan bir get_tensor yöntemi sağlar:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

Liste ve sözlük takibi

self.l1 = tf.keras.layers.Dense(5) gibi doğrudan öznitelik atamalarında olduğu gibi, özniteliklere liste ve sözlük atamak içeriklerini izler.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Listeler ve sözlükler için sarmalayıcı nesneleri fark edebilirsiniz. Bu sarmalayıcılar, temeldeki veri yapılarının kontrol noktalı sürümleridir. Öznitelik tabanlı yüklemede olduğu gibi, bu sarmalayıcılar bir değişkenin değerini kaba eklenir eklenmez geri yükler.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Aynı izleme otomatik olarak tf.keras.Model alt sınıflarına uygulanır ve örneğin katman listelerini izlemek için kullanılabilir.

Özet

TensorFlow nesneleri, kullandıkları değişkenlerin değerlerini kaydetmek ve geri yüklemek için kolay bir otomatik mekanizma sağlar.