Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

Eğitim kontrol noktaları

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

"TensorFlow modelini kaydetme" ifadesi tipik olarak iki şeyden biri anlamına gelir:

  1. Kontrol noktaları, VEYA
  2. SavedModel.

Denetim noktaları, bir model tarafından kullanılan tüm parametrelerin ( tf.Variable nesneler) tam değerini yakalar. Kontrol noktaları, model tarafından tanımlanan hesaplamanın herhangi bir açıklamasını içermez ve bu nedenle tipik olarak yalnızca kaydedilen parametre değerlerini kullanacak kaynak kodu mevcut olduğunda yararlıdır.

Öte yandan SavedModel formatı, parametre değerlerine (kontrol noktası) ek olarak model tarafından tanımlanan hesaplamanın serileştirilmiş bir açıklamasını içerir. Bu biçimdeki modeller, modeli oluşturan kaynak koddan bağımsızdır. Bu nedenle TensorFlow Serving, TensorFlow Lite, TensorFlow.js veya diğer programlama dillerindeki programlar (C, C ++, Java, Go, Rust, C # vb. TensorFlow API'leri) aracılığıyla dağıtım için uygundurlar.

Bu kılavuz, kontrol noktaları yazmak ve okumak için API'leri kapsar.

Kurmak

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras eğitim API'lerinden tasarruf

Kaydetme ve geri yükleme ile ilgili tf.keras kılavuzuna bakın.

tf.keras.Model.save_weights , bir TensorFlow denetim noktası kaydeder.

net.save_weights('easy_checkpoint')

Kontrol noktaları yazma

TensorFlow modelinin kalıcı durumu tf.Variable nesnelerinde saklanır. Bunlar doğrudan yapılandırılabilir, ancak genellikle tf.keras.layers veya tf.keras.Model gibi tf.keras.layers düzey API'ler aracılığıyla oluşturulur.

Değişkenleri yönetmenin en kolay yolu, onları Python nesnelerine eklemek ve ardından bu nesnelere başvurmaktır.

tf.train.Checkpoint , tf.keras.layers.Layer ve tf.keras.Model alt sınıfları, özniteliklerine atanan değişkenleri otomatik olarak izler. Aşağıdaki örnek, basit bir doğrusal model oluşturur, ardından modelin tüm değişkenleri için değerleri içeren kontrol noktaları yazar.

Model.save_weights ile bir model kontrol noktasını kolayca kaydedebilirsiniz.

Manuel kontrol noktası belirleme

Kurmak

tf.train.Checkpoint tüm özelliklerini göstermeye yardımcı olmak için bir oyuncak veri seti ve optimizasyon adımı tanımlayın:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Kontrol noktası nesnelerini oluşturun

Manuel olarak bir kontrol noktası oluşturmak için bir tf.train.Checkpoint nesnesine ihtiyacınız olacaktır. Kontrol etmek istediğiniz nesnelerin, nesne üzerinde nitelikler olarak ayarlandığı yer.

Bir tf.train.CheckpointManager , birden çok denetim noktasını yönetmek için de yardımcı olabilir.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Modeli eğitin ve kontrol edin

Aşağıdaki eğitim döngüsü, modelin ve optimize edicinin bir örneğini oluşturur, ardından bunları bir tf.train.Checkpoint nesnesinde toplar. Her veri grubu için bir döngü içinde eğitim adımını çağırır ve kontrol noktalarını düzenli olarak diske yazar.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.15
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.56
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.00
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.52
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.25

Geri yükleyin ve eğitime devam edin

İlkinden sonra yeni bir model ve menajeri geçebilirsiniz, ancak eğitimi tam olarak kaldığınız yerden devam ettirin:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.76
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.65
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.51
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

tf.train.CheckpointManager nesnesi eski denetim noktalarını siler. Yukarıda, yalnızca en son üç kontrol noktasını tutacak şekilde yapılandırılmıştır.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Bu yollar, örneğin './tf_ckpts/ckpt-10' diskteki dosyalar değildir. Bunun yerine, bir index dosyası ve değişken değerleri içeren bir veya daha fazla veri dosyası için öneklerdir. Bu önekler, CheckpointManager durumunu kaydettiği tek bir checkpoint dosyasında ( './tf_ckpts/checkpoint' ) birlikte gruplanır.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mekaniği yükleme

TensorFlow, yüklenen nesneden başlayarak, adlandırılmış kenarları olan yönlendirilmiş bir grafiğin üzerinden geçerek değişkenleri kontrol noktalı değerlerle eşleştirir. Kenar adları tipik olarak nesnelerdeki öznitelik adlarından gelir, örneğin self.l1 = tf.keras.layers.Dense(5) içindeki "l1" . tf.train.Checkpoint , tf.train.Checkpoint(step=...) "step" olduğu gibi, anahtar kelime argüman adlarını kullanır.

Yukarıdaki örnekteki bağımlılık grafiği şuna benzer:

Örnek eğitim döngüsü için bağımlılık grafiğinin görselleştirilmesi

Optimize edici kırmızı, normal değişkenler mavi ve optimize edici alan değişkenleri turuncu renkte. Örneğin tf.train.Checkpoint temsil eden diğer düğümler siyahtır.

Yuva değişkenleri, optimize edicinin durumunun bir parçasıdır, ancak belirli bir değişken için oluşturulur. Örneğin, yukarıdaki 'm' kenarları, Adam optimize edicinin her değişken için izlediği momentuma karşılık gelir. Yuva değişkenleri, yalnızca değişken ve optimize edicinin her ikisi de, dolayısıyla kesikli kenarlar kaydedilecekse bir kontrol noktasına kaydedilir.

Bir tf.train.Checkpoint nesnesinde restore() 'nin çağrılması, istenen geri tf.train.Checkpoint sıraya tf.train.Checkpoint ve Checkpoint nesnesinden eşleşen bir yol olduğu anda değişken değerlerini geri yükler. Örneğin, yukarıda tanımladığımız modelden sadece önyargıyı, ona ağ ve katman aracılığıyla bir yolu yeniden yapılandırarak yükleyebiliriz.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.5548685 2.8931093 2.3509905 3.5525272 4.017799 ]

Bu yeni nesnelerin bağımlılık grafiği, yukarıda yazdığımız daha büyük kontrol noktasının çok daha küçük bir alt grafiğidir. Yalnızca önyargıyı ve tf.train.Checkpoint denetim noktalarını tf.train.Checkpoint kullandığı bir kaydetme sayacını içerir.

Sapma değişkeni için bir alt grafiğin görselleştirilmesi

restore() , isteğe bağlı iddialara sahip bir durum nesnesi döndürür. Yeni Checkpoint status.assert_existing_objects_matched() oluşturduğumuz tüm nesneler geri yüklendi, bu nedenle status.assert_existing_objects_matched() geçer.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fea0c3c3860>

Kontrol noktasında, katmanın çekirdeği ve optimize edicinin değişkenleri de dahil olmak üzere eşleşmeyen birçok nesne vardır. status.assert_consumed() yalnızca denetim noktası ve program tam olarak eşleştiğinde geçer ve burada bir istisna atar.

Gecikmiş restorasyonlar

TensorFlow'daki Layer nesneleri, giriş şekilleri mevcut olduğunda değişkenlerin oluşturulmasını ilk çağrılarına kadar geciktirebilir. Örneğin, Dense katman çekirdeğinin şekli hem katmanın giriş hem de çıkış şekillerine bağlıdır ve bu nedenle yapıcı argümanı olarak gereken çıktı şekli, değişkeni kendi başına oluşturmak için yeterli bilgi değildir. Bir Layer çağırmak aynı zamanda değişkenin değerini de okuduğundan, değişkenin oluşturulması ile ilk kullanımı arasında bir geri yükleme gerçekleşmelidir.

Bu deyimi desteklemek için tf.train.Checkpoint kuyrukları, henüz eşleşen bir değişkeni olmayanları geri yükler.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.453001  4.6668463 4.9372597 4.90143   4.9549575]]

Kontrol noktalarını manuel olarak incelemek

tf.train.list_variables , bir kontrol noktasındaki değişkenlerin kontrol noktası anahtarlarını ve şekillerini listeler. Denetim noktası anahtarları, yukarıda görüntülenen grafikteki yollardır.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Liste ve sözlük takibi

self.l1 = tf.keras.layers.Dense(5) gibi doğrudan öznitelik atamalarında olduğu gibi, özniteliklere listeler ve sözlükler atamak içeriklerini izler.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Listeler ve sözlükler için sarmalayıcı nesneleri fark edebilirsiniz. Bu sarmalayıcılar, temeldeki veri yapılarının kontrol noktalı sürümleridir. Öznitelik tabanlı yüklemede olduğu gibi, bu sarmalayıcılar bir değişkenin değerini konteynere eklenir eklenmez geri yükler.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Aynı izleme, otomatik olarak tf.keras.Model alt sınıflarına uygulanır ve örneğin katman listelerini izlemek için kullanılabilir.

Estimator ile nesne tabanlı kontrol noktalarını kaydetme

Tahminci kılavuzuna bakın.

Tahmin ediciler varsayılan olarak kontrol noktalarını önceki bölümlerde açıklanan nesne grafiği yerine değişken adlarıyla kaydeder. tf.train.Checkpoint isme dayalı denetim noktalarını kabul eder, ancak bir modelin parçalarını Tahmincinin model_fn dışına taşırken değişken adları değişebilir. Nesne tabanlı kontrol noktalarının kaydedilmesi, bir Tahmincide bir model eğitmeyi ve daha sonra onu bir modelin dışında kullanmayı kolaylaştırır.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.388644, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 34.98601.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fea648fbf60>

tf.train.Checkpoint daha sonra Estimator'ın kontrol noktalarını model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Özet

TensorFlow nesneleri, kullandıkları değişkenlerin değerlerini kaydetmek ve geri yüklemek için kolay bir otomatik mekanizma sağlar.