Lưu ngày! Google I / O hoạt động trở lại từ ngày 18 đến 20 tháng 5 Đăng ký ngay
Trang này được dịch bởi Cloud Translation API.
Switch to English

Các điểm kiểm tra đào tạo

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Cụm từ "Lưu mô hình TensorFlow" thường có nghĩa là một trong hai điều:

  1. Trạm kiểm soát, HOẶC
  2. SavedModel.

Các trạm kiểm soát nắm bắt giá trị chính xác của tất cả các tham số (đối tượng tf.Variable ) được sử dụng bởi một mô hình. Các điểm kiểm tra không chứa bất kỳ mô tả nào về tính toán được xác định bởi mô hình và do đó thường chỉ hữu ích khi có sẵn mã nguồn sử dụng các giá trị tham số đã lưu.

Mặt khác, định dạng SavedModel bao gồm mô tả tuần tự của phép tính được xác định bởi mô hình ngoài các giá trị tham số (điểm kiểm tra). Các mô hình ở định dạng này độc lập với mã nguồn đã tạo ra mô hình. Do đó, chúng phù hợp để triển khai thông qua TensorFlow Serving, TensorFlow Lite, TensorFlow.js hoặc các chương trình bằng các ngôn ngữ lập trình khác (C, C ++, Java, Go, Rust, C #, v.v. TensorFlow API).

Hướng dẫn này bao gồm các API để ghi và đọc các điểm kiểm tra.

Thiết lập

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Tiết kiệm từ các API đào tạo tf.keras

Xem hướng dẫn của tf.keras về cách lưu và khôi phục.

tf.keras.Model.save_weights lưu một điểm kiểm tra TensorFlow.

net.save_weights('easy_checkpoint')

Viết các trạm kiểm soát

Trạng thái liên tục của một mô hình TensorFlow được lưu trữ trong các đối tượng tf.Variable . Chúng có thể được tạo trực tiếp, nhưng thường được tạo thông qua các API cấp cao nhưtf.keras.layers hoặc tf.keras.Model .

Cách dễ nhất để quản lý các biến là gắn chúng vào các đối tượng Python, sau đó tham chiếu đến các đối tượng đó.

Các lớp con của tf.train.Checkpoint , tf.keras.layers.Layertf.keras.Model tự động theo dõi các biến được gán cho các thuộc tính của chúng. Ví dụ sau đây xây dựng một mô hình tuyến tính đơn giản, sau đó viết các điểm kiểm tra chứa các giá trị cho tất cả các biến của mô hình.

Bạn có thể dễ dàng lưu điểm kiểm tra mô hình với Model.save_weights .

Kiểm tra thủ công

Thiết lập

Để giúp chứng minh tất cả các tính năng của tf.train.Checkpoint , hãy xác định tập dữ liệu đồ chơi và bước tối ưu hóa:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Tạo các đối tượng điểm kiểm tra

Sử dụng đối tượng tf.train.Checkpoint để tạo một điểm kiểm tra theo cách thủ công, trong đó các đối tượng bạn muốn điểm kiểm tra được đặt làm thuộc tính trên đối tượng.

Một tf.train.CheckpointManager cũng có thể hữu ích để quản lý nhiều điểm kiểm tra.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Huấn luyện và kiểm tra mô hình

Vòng lặp đào tạo sau đây tạo một phiên bản của mô hình và của một trình tối ưu hóa, sau đó tập hợp chúng thành một đối tượng tf.train.Checkpoint . Nó gọi bước huấn luyện trong một vòng lặp trên mỗi lô dữ liệu và định kỳ ghi các điểm kiểm tra vào đĩa.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.42
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.81
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.74

Khôi phục và tiếp tục đào tạo

Sau chu kỳ đào tạo đầu tiên, bạn có thể vượt qua một mô hình và người quản lý mới, nhưng hãy tiếp tục đào tạo chính xác nơi bạn đã dừng lại:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.87
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.71
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.46
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.21

Đối tượng tf.train.CheckpointManager xóa các điểm kiểm tra cũ. Ở trên, nó được định cấu hình để chỉ giữ ba điểm kiểm tra gần đây nhất.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Những đường dẫn này, ví dụ: './tf_ckpts/ckpt-10' , không phải là tệp trên đĩa. Thay vào đó, chúng là tiền tố cho một tệp index và một hoặc nhiều tệp dữ liệu chứa các giá trị biến. Các tiền tố này được nhóm lại với nhau trong một tệp checkpoint duy nhất ( './tf_ckpts/checkpoint' ) nơi CheckpointManager lưu trạng thái của nó.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Cơ khí tải

TensorFlow đối sánh các biến với các giá trị đã kiểm tra bằng cách duyệt qua một biểu đồ có hướng với các cạnh được đặt tên, bắt đầu từ đối tượng đang được tải. Tên cạnh thường đến từ tên thuộc tính trong các đối tượng, ví dụ: "l1" trong self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint sử dụng tên đối số từ khóa của nó, như trong "step" trong tf.train.Checkpoint(step=...) .

Biểu đồ phụ thuộc từ ví dụ trên trông như sau:

Hình ảnh hóa biểu đồ phụ thuộc cho vòng lặp đào tạo ví dụ

Trình tối ưu hóa có màu đỏ, các biến thông thường có màu xanh lam và các biến vị trí của trình tối ưu hóa có màu cam. Các nút khác — ví dụ, đại diện cho tf.train.Checkpoint — có màu đen.

Các biến vị trí là một phần của trạng thái của trình tối ưu hóa, nhưng được tạo cho một biến cụ thể. Ví dụ: các cạnh 'm' ở trên tương ứng với động lượng mà trình tối ưu hóa Adam theo dõi cho mỗi biến. Các biến vị trí chỉ được lưu trong một điểm kiểm tra nếu cả biến và trình tối ưu hóa đều được lưu, do đó các cạnh gạch ngang.

Gọi restore trên đối tượng tf.train.Checkpoint xếp hàng các khôi phục được yêu cầu, khôi phục các giá trị biến ngay khi có một đường dẫn phù hợp từ đối tượng Checkpoint . Ví dụ: bạn có thể tải chỉ độ lệch từ mô hình bạn đã xác định ở trên bằng cách tạo lại một đường dẫn đến nó thông qua mạng và lớp.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.831489  3.7156947 2.5892444 3.8669944 4.749503 ]

Biểu đồ phụ thuộc cho các đối tượng mới này là một đồ thị con nhỏ hơn nhiều của trạm kiểm soát lớn hơn mà bạn đã viết ở trên. Nó chỉ bao gồm thiên vị và một bộ đếm lưu mà tf.train.Checkpoint sử dụng để đánh số các điểm kiểm tra.

Hình dung một đồ thị con cho biến thiên vị

restore trả về một đối tượng trạng thái, đối tượng này có các xác nhận tùy chọn. Tất cả các đối tượng được tạo trong Checkpoint mới đã được khôi phục, do đó, status.assert_existing_objects_matched vượt qua.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>

Có nhiều đối tượng trong trạm kiểm soát chưa khớp, bao gồm nhân của lớp và các biến của trình tối ưu hóa. status.assert_consumed chỉ vượt qua nếu điểm kiểm tra và chương trình khớp chính xác và sẽ đưa ra một ngoại lệ ở đây.

Phục hồi bị trì hoãn

Layer đối tượng Layer trong TensorFlow có thể trì hoãn việc tạo các biến cho lần gọi đầu tiên của chúng, khi các hình dạng đầu vào có sẵn. Ví dụ: hình dạng của nhân của lớp Dense phụ thuộc vào cả hình dạng đầu vào và đầu ra của lớp, và do đó hình dạng đầu ra được yêu cầu như một đối số của hàm tạo không đủ thông tin để tự tạo biến. Vì việc gọi một Layer cũng đọc giá trị của biến, nên việc khôi phục phải xảy ra giữa quá trình tạo biến và lần sử dụng đầu tiên của nó.

Để hỗ trợ thành ngữ này, các hàng đợi tf.train.Checkpoint khôi phục các hàng đợi chưa có biến phù hợp.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5719748 4.6099544 4.931875  4.836442  4.8496275]]

Kiểm tra các trạm kiểm soát theo cách thủ công

tf.train.load_checkpoint trả về CheckpointReader cấp quyền truy cập cấp thấp hơn vào nội dung điểm kiểm tra. Nó chứa các ánh xạ từ khóa của mỗi vartiable đến hình dạng và kiểu cho mỗi biến trong trạm kiểm soát. Chìa khóa của một biến là đường dẫn đối tượng của nó, giống như trong đồ thị được hiển thị ở trên.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Vì vậy, nếu bạn quan tâm đến giá trị của net.l1.kernel bạn có thể lấy giá trị bằng đoạn mã sau:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Nó cũng cung cấp một phương thức get_tensor cho phép bạn kiểm tra giá trị của một biến:

reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]],
      dtype=float32)

Theo dõi danh sách và từ điển

Như với các phép gán thuộc tính trực tiếp như self.l1 = tf.keras.layers.Dense(5) , việc gán danh sách và từ điển cho các thuộc tính sẽ theo dõi nội dung của chúng.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Bạn có thể nhận thấy các đối tượng bao bọc cho danh sách và từ điển. Các trình bao bọc này là các phiên bản có thể kiểm tra của cấu trúc dữ liệu cơ bản. Cũng giống như cách tải dựa trên thuộc tính, các trình bao bọc này khôi phục giá trị của một biến ngay sau khi nó được thêm vào vùng chứa.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Việc theo dõi tương tự được tự động áp dụng cho các lớp con của tf.keras.Model và có thể được sử dụng chẳng hạn như để theo dõi danh sách các lớp.

Tóm lược

Các đối tượng TensorFlow cung cấp một cơ chế tự động dễ dàng để lưu và khôi phục các giá trị của các biến mà chúng sử dụng.