Pos pemeriksaan pelatihan

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Ungkapan "Menyimpan model TensorFlow" biasanya berarti salah satu dari dua hal:

  1. Pos pemeriksaan, OR
  2. Model Tersimpan.

Pos pemeriksaan menangkap nilai yang tepat dari semua parameter ( objek tf.Variable ) yang digunakan oleh model. Pos pemeriksaan tidak berisi deskripsi komputasi yang ditentukan oleh model dan dengan demikian biasanya hanya berguna ketika kode sumber yang akan menggunakan nilai parameter yang disimpan tersedia.

Format SavedModel di sisi lain mencakup deskripsi serial dari perhitungan yang ditentukan oleh model di samping nilai parameter (pos pemeriksaan). Model dalam format ini tidak bergantung pada kode sumber yang membuat model. Oleh karena itu, mereka cocok untuk diterapkan melalui TensorFlow Serving, TensorFlow Lite, TensorFlow.js, atau program dalam bahasa pemrograman lain (API C, C++, Java, Go, Rust, C# dll. TensorFlow).

Panduan ini mencakup API untuk menulis dan membaca pos pemeriksaan.

Mempersiapkan

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Menyimpan dari API pelatihan tf.keras

Lihat panduan tf.keras tentang menyimpan dan memulihkan .

tf.keras.Model.save_weights menyimpan pos pemeriksaan TensorFlow.

net.save_weights('easy_checkpoint')

Menulis pos pemeriksaan

Status persisten model TensorFlow disimpan di objek tf.Variable . Ini dapat dibuat secara langsung, tetapi sering dibuat melalui API tingkat tinggi seperti tf.keras.layers atau tf.keras.Model .

Cara termudah untuk mengelola variabel adalah dengan melampirkannya ke objek Python, lalu mereferensikan objek tersebut.

Subkelas tf.train.Checkpoint , tf.keras.layers.Layer , dan tf.keras.Model secara otomatis melacak variabel yang ditetapkan ke atributnya. Contoh berikut membangun model linier sederhana, kemudian menulis pos pemeriksaan yang berisi nilai untuk semua variabel model.

Anda dapat dengan mudah menyimpan model-checkpoint dengan Model.save_weights .

Pos pemeriksaan manual

Mempersiapkan

Untuk membantu mendemonstrasikan semua fitur tf.train.Checkpoint , tentukan kumpulan data mainan dan langkah pengoptimalan:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Buat objek pos pemeriksaan

Gunakan objek tf.train.Checkpoint untuk membuat pos pemeriksaan secara manual, di mana objek yang ingin Anda periksa ditetapkan sebagai atribut pada objek.

Sebuah tf.train.CheckpointManager juga dapat membantu untuk mengelola beberapa pos pemeriksaan.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Latih dan pos pemeriksaan model

Loop pelatihan berikut membuat instance model dan pengoptimal, lalu mengumpulkannya menjadi objek tf.train.Checkpoint . Ini memanggil langkah pelatihan dalam satu lingkaran pada setiap kumpulan data, dan secara berkala menulis pos pemeriksaan ke disk.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Pulihkan dan lanjutkan pelatihan

Setelah siklus pelatihan pertama, Anda dapat melewati model dan manajer baru, tetapi melanjutkan pelatihan tepat di tempat terakhir Anda tinggalkan:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

Objek tf.train.CheckpointManager menghapus pos pemeriksaan lama. Di atasnya dikonfigurasi untuk hanya menyimpan tiga pos pemeriksaan terbaru.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Jalur ini, misalnya './tf_ckpts/ckpt-10' , bukan file di disk. Sebaliknya mereka adalah awalan untuk file index dan satu atau lebih file data yang berisi nilai variabel. Awalan ini dikelompokkan bersama dalam satu file checkpoint ( './tf_ckpts/checkpoint' ) tempat CheckpointManager menyimpan statusnya.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mekanika pemuatan

TensorFlow mencocokkan variabel dengan nilai checkpoint dengan melintasi grafik berarah dengan tepi bernama, mulai dari objek yang dimuat. Nama tepi biasanya berasal dari nama atribut di objek, misalnya "l1" di self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint menggunakan nama argumen kata kuncinya, seperti pada "step" di tf.train.Checkpoint(step=...) .

Grafik ketergantungan dari contoh di atas terlihat seperti ini:

Visualisasi grafik ketergantungan untuk contoh pelatihan loop

Pengoptimal berwarna merah, variabel reguler berwarna biru, dan variabel slot pengoptimal berwarna oranye. Node lain—misalnya, mewakili tf.train.Checkpoint —berwarna hitam.

Variabel slot adalah bagian dari status pengoptimal, tetapi dibuat untuk variabel tertentu. Misalnya, tepi 'm' di atas sesuai dengan momentum, yang dilacak oleh pengoptimal Adam untuk setiap variabel. Variabel slot hanya disimpan di pos pemeriksaan jika variabel dan pengoptimal keduanya akan disimpan, sehingga tepi putus-putus.

Memanggil restore pada objek tf.train.Checkpoint mengantre pemulihan yang diminta, memulihkan nilai variabel segera setelah ada jalur yang cocok dari objek Checkpoint . Misalnya, Anda dapat memuat hanya bias dari model yang Anda definisikan di atas dengan merekonstruksi satu jalur ke sana melalui jaringan dan lapisan.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

Grafik ketergantungan untuk objek baru ini adalah subgraf yang jauh lebih kecil dari pos pemeriksaan yang lebih besar yang Anda tulis di atas. Ini hanya mencakup bias dan penghitung simpan yang digunakan tf.train.Checkpoint untuk memberi nomor pos pemeriksaan.

Visualisasi subgraf untuk variabel bias

restore mengembalikan objek status, yang memiliki pernyataan opsional. Semua objek yang dibuat di Checkpoint baru telah dipulihkan, jadi status.assert_existing_objects_matched lolos.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

Ada banyak objek di pos pemeriksaan yang belum cocok, termasuk kernel layer dan variabel pengoptimal. status.assert_consumed hanya lolos jika pos pemeriksaan dan program sama persis, dan akan mengeluarkan pengecualian di sini.

Restorasi tertunda

Objek Layer di TensorFlow dapat menunda pembuatan variabel ke panggilan pertama, saat bentuk input tersedia. Misalnya, bentuk kernel layer Dense bergantung pada bentuk input dan output layer, sehingga bentuk output yang diperlukan sebagai argumen konstruktor bukanlah informasi yang cukup untuk membuat variabel sendiri. Karena memanggil Layer juga membaca nilai variabel, pemulihan harus terjadi antara pembuatan variabel dan penggunaan pertama.

Untuk mendukung idiom ini, tf.train.Checkpoint yang belum memiliki variabel yang cocok.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Memeriksa pos pemeriksaan secara manual

tf.train.load_checkpoint mengembalikan CheckpointReader yang memberikan akses tingkat rendah ke konten pos pemeriksaan. Ini berisi pemetaan dari setiap kunci variabel, ke bentuk dan tipe d untuk setiap variabel di pos pemeriksaan. Kunci variabel adalah jalur objeknya, seperti pada grafik yang ditampilkan di atas.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Jadi jika Anda tertarik dengan nilai net.l1.kernel Anda bisa mendapatkan nilainya dengan kode berikut:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ini juga menyediakan metode get_tensor yang memungkinkan Anda memeriksa nilai variabel:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

Pelacakan objek

Pos pemeriksaan menyimpan dan memulihkan nilai objek tf.Variable dengan "melacak" setiap variabel atau objek yang dapat dilacak yang ditetapkan dalam salah satu atributnya. Saat menjalankan penyimpanan, variabel dikumpulkan secara rekursif dari semua objek terlacak yang dapat dijangkau.

Seperti penetapan atribut langsung seperti self.l1 = tf.keras.layers.Dense(5) , menetapkan daftar dan kamus ke atribut akan melacak isinya.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Anda mungkin melihat objek pembungkus untuk daftar dan kamus. Pembungkus ini adalah versi yang dapat diperiksa dari struktur data yang mendasarinya. Sama seperti pemuatan berbasis atribut, pembungkus ini mengembalikan nilai variabel segera setelah ditambahkan ke penampung.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Objek yang dapat dilacak termasuk tf.train.Checkpoint , tf.Module dan subkelasnya (misalnya keras.layers.Layer dan keras.Model ), dan wadah Python yang dikenali:

  • dict (dan collections.OrderedDict )
  • list
  • tuple (dan collections.namedtuple , typing.NamedTuple )

Jenis penampung lainnya tidak didukung , termasuk:

  • collections.defaultdict
  • set

Semua objek Python lainnya diabaikan , termasuk:

  • int
  • string
  • float

Ringkasan

Objek TensorFlow menyediakan mekanisme otomatis yang mudah untuk menyimpan dan memulihkan nilai variabel yang mereka gunakan.