Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Pos pemeriksaan pelatihan

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Frasa "Menyimpan model TensorFlow" biasanya berarti salah satu dari dua hal:

  1. Pos pemeriksaan, ATAU
  2. SavedModel.

Pos pemeriksaan menangkap nilai pasti dari semua parameter ( tf.Variable Objek tf.Variable ) yang digunakan oleh model. Checkpoint tidak berisi deskripsi komputasi apa pun yang ditentukan oleh model dan karenanya biasanya hanya berguna jika kode sumber yang akan menggunakan nilai parameter yang disimpan tersedia.

Format SavedModel di sisi lain menyertakan deskripsi serialisasi dari komputasi yang ditentukan oleh model selain nilai parameter (checkpoint). Model dalam format ini tidak bergantung pada kode sumber yang membuat model. Karenanya, mereka cocok untuk penerapan melalui TensorFlow Serving, TensorFlow Lite, TensorFlow.js, atau program dalam bahasa pemrograman lain (API TensorFlow C, C ++, Java, Go, Rust, C # dll.).

Panduan ini mencakup API untuk menulis dan membaca pos pemeriksaan.

Mendirikan

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Menyimpan dari API pelatihan tf.keras

Lihat panduan tf.keras tentang menyimpan dan memulihkan.

tf.keras.Model.save_weights menyimpan checkpoint TensorFlow.

net.save_weights('easy_checkpoint')

Menulis pos pemeriksaan

Status persisten model TensorFlow disimpan dalam objek tf.Variable . Ini dapat dibuat secara langsung, tetapi sering kali dibuat melalui API tingkat tinggi sepertitf.keras.layers atau tf.keras.Model .

Cara termudah untuk mengelola variabel adalah dengan melampirkannya ke objek Python, lalu mereferensikan objek tersebut.

Subclass dari tf.train.Checkpoint , tf.keras.layers.Layer , dan tf.keras.Model secara otomatis melacak variabel yang ditetapkan ke atributnya. Contoh berikut membuat model linier sederhana, lalu menulis checkpoint yang berisi nilai untuk semua variabel model.

Anda dapat dengan mudah menyimpan model-checkpoint dengan Model.save_weights

Pemeriksaan manual

Mendirikan

Untuk membantu mendemonstrasikan semua fitur tf.train.Checkpoint tentukan tf.train.Checkpoint data mainan dan langkah pengoptimalan:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Buat objek pos pemeriksaan

Untuk membuat pos pemeriksaan secara manual, Anda memerlukan objek tf.train.Checkpoint . Di mana objek yang ingin Anda periksa ditetapkan sebagai atribut pada objek.

Sebuah tf.train.CheckpointManager juga dapat berguna untuk mengelola beberapa checkpoint.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Latih dan periksa modelnya

Loop pelatihan berikut membuat instance model dan pengoptimal, lalu mengumpulkannya ke dalam objek tf.train.Checkpoint . Ini memanggil langkah pelatihan dalam satu lingkaran pada setiap kumpulan data, dan secara berkala menulis pos pemeriksaan ke disk.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 27.29
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 20.70
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 14.14
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 7.68
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 2.07

Pulihkan dan lanjutkan pelatihan

Setelah yang pertama, Anda dapat meneruskan model dan manajer baru, tetapi mengambil pelatihan tepat di tempat Anda tinggalkan:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.04
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.81
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.67
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

Objek tf.train.CheckpointManager menghapus checkpoint lama. Di atasnya dikonfigurasi untuk menyimpan hanya tiga pos pemeriksaan terbaru.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

'./tf_ckpts/ckpt-10' ini, misalnya './tf_ckpts/ckpt-10' , bukanlah file pada disk. Sebaliknya mereka adalah awalan untuk file index dan satu atau lebih file data yang berisi nilai variabel. Awalan ini dikelompokkan bersama dalam satu file checkpoint ( './tf_ckpts/checkpoint' ) di mana CheckpointManager menyimpan statusnya.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mekanika pemuatan

TensorFlow mencocokkan variabel dengan nilai checkpoint dengan melintasi grafik berarah dengan tepian bernama, mulai dari objek yang dimuat. Nama tepi biasanya berasal dari nama atribut dalam objek, misalnya "l1" di self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint menggunakan nama argumen kata tf.train.Checkpoint , seperti dalam "step" di tf.train.Checkpoint(step=...) .

Grafik ketergantungan dari contoh di atas terlihat seperti ini:

Visualisasi grafik ketergantungan untuk loop pelatihan contoh

Dengan pengoptimal berwarna merah, variabel reguler berwarna biru, dan variabel slot pengoptimal berwarna oranye. Node lainnya, misalnya mewakili tf.train.Checkpoint , berwarna hitam.

Variabel slot adalah bagian dari status pengoptimal, tetapi dibuat untuk variabel tertentu. Misalnya tepi 'm' atas sesuai dengan momentum, yang dilacak oleh pengoptimal Adam untuk setiap variabel. Variabel slot hanya disimpan di checkpoint jika variabel dan pengoptimal akan disimpan, dengan demikian tepi putus-putus.

Memanggil restore() pada objek tf.train.Checkpoint mengantri restorasi yang diminta, memulihkan nilai variabel segera setelah ada jalur yang cocok dari objek Checkpoint . Misalnya, Anda dapat memuat hanya bias dari model yang kami definisikan di atas dengan merekonstruksi satu jalur ke model tersebut melalui jaringan dan lapisan.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9204693 3.3416245 2.7418654 2.9938312 4.2179084]

Grafik ketergantungan untuk objek baru ini adalah subgraf yang jauh lebih kecil dari pos pemeriksaan yang lebih besar yang kami tulis di atas. Ini hanya mencakup bias dan penghitung tf.train.Checkpoint digunakan tf.train.Checkpoint untuk tf.train.Checkpoint pos pemeriksaan.

Visualisasi subgraf untuk variabel bias

restore() mengembalikan objek status, yang memiliki pernyataan opsional. Semua objek yang dibuat di Checkpoint baru telah dipulihkan, sehingga status.assert_existing_objects_matched() lolos.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbab44668d0>

Ada banyak objek di pos pemeriksaan yang belum cocok, termasuk kernel lapisan dan variabel pengoptimal. status.assert_consumed() hanya lolos jika checkpoint dan program sama persis, dan akan memunculkan pengecualian di sini.

Restorasi tertunda

Objek Layer di TensorFlow dapat menunda pembuatan variabel ke panggilan pertamanya, jika bentuk masukan tersedia. Misalnya, bentuk kernel layer Dense bergantung pada bentuk input dan output layer, sehingga bentuk output yang diperlukan sebagai argumen konstruktor tidak memiliki informasi yang cukup untuk membuat variabel sendiri. Karena memanggil Layer juga membaca nilai variabel, pemulihan harus terjadi antara pembuatan variabel dan penggunaan pertama.

Untuk mendukung idiom ini, tf.train.Checkpoint antrian tf.train.Checkpoint yang belum memiliki variabel yang cocok.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6967545 4.656775  4.8799973 4.961913  4.893354 ]]

Menginspeksi pos pemeriksaan secara manual

tf.train.load_checkpoint mengembalikan CheckpointReader yang memberikan akses tingkat rendah ke konten checkpoint. Ini berisi pemetaan dari setiap kunci vartiable, ke bentuk dan dtype untuk setiap variabel di pos pemeriksaan. Kunci variabel adalah jalur objeknya, seperti pada grafik yang ditampilkan di atas.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Jadi jika Anda tertarik dengan nilai net.l1.kernel Anda bisa mendapatkan nilainya dengan kode berikut:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ini juga menyediakan metode get_tensor yang memungkinkan Anda memeriksa nilai variabel:

reader.get_tensor(key)
array([[4.6967545, 4.656775 , 4.8799973, 4.961913 , 4.893354 ]],
      dtype=float32)

Daftar dan pelacakan kamus

Seperti dengan penetapan atribut langsung seperti self.l1 = tf.keras.layers.Dense(5) , menetapkan daftar dan kamus ke atribut akan melacak isinya.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
.dll

Anda mungkin memperhatikan objek pembungkus untuk daftar dan kamus. Wrapper ini adalah versi yang dapat diperiksa dari struktur data yang mendasarinya. Sama seperti pemuatan berbasis atribut, pembungkus ini memulihkan nilai variabel segera setelah ditambahkan ke penampung.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Pelacakan yang sama diterapkan secara otomatis ke subkelas tf.keras.Model , dan dapat digunakan misalnya untuk melacak daftar lapisan.

Ringkasan

Objek TensorFlow menyediakan mekanisme otomatis yang mudah untuk menyimpan dan memulihkan nilai variabel yang mereka gunakan.