Trang này được dịch bởi Cloud Translation API.
Switch to English

Các điểm kiểm tra đào tạo

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ tay

Cụm từ "Lưu mô hình TensorFlow" thường có nghĩa là một trong hai điều:

  1. Trạm kiểm soát, HOẶC
  2. SavedModel.

Các điểm kiểm tra nắm bắt giá trị chính xác của tất cả các tham số (đối tượng tf.Variable ) được sử dụng bởi một mô hình. Các điểm kiểm tra không chứa bất kỳ mô tả nào về tính toán được xác định bởi mô hình và do đó thường chỉ hữu ích khi có sẵn mã nguồn sử dụng các giá trị tham số đã lưu.

Mặt khác, định dạng SavedModel bao gồm một mô tả tuần tự của phép tính được xác định bởi mô hình ngoài các giá trị tham số (điểm kiểm tra). Mô hình ở định dạng này độc lập với mã nguồn đã tạo ra mô hình. Do đó, chúng phù hợp để triển khai thông qua TensorFlow Serving, TensorFlow Lite, TensorFlow.js hoặc các chương trình bằng các ngôn ngữ lập trình khác (C, C ++, Java, Go, Rust, C #, v.v. TensorFlow API).

Hướng dẫn này bao gồm các API để ghi và đọc các điểm kiểm tra.

Thiết lập

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Tiết kiệm từ các API đào tạo tf.keras

Xem hướng dẫn tf.keras về cách lưu và khôi phục.

tf.keras.Model.save_weights lưu một điểm kiểm tra TensorFlow.

net.save_weights('easy_checkpoint')

Viết điểm kiểm tra

Trạng thái liên tục của mô hình TensorFlow được lưu trữ trong các đối tượng tf.Variable . Chúng có thể được tạo trực tiếp, nhưng thường được tạo thông qua các API cấp cao như tf.keras.layers hoặc tf.keras.Model .

Cách dễ nhất để quản lý các biến là gắn chúng vào các đối tượng Python, sau đó tham chiếu đến các đối tượng đó.

Các lớp con của tf.train.Checkpoint , tf.keras.layers.Layertf.keras.Model tự động theo dõi các biến được gán cho thuộc tính của chúng. Ví dụ sau đây xây dựng một mô hình tuyến tính đơn giản, sau đó viết các điểm kiểm tra chứa các giá trị cho tất cả các biến của mô hình.

Bạn có thể dễ dàng lưu điểm kiểm tra mô hình với Model.save_weights

Kiểm tra thủ công

Thiết lập

Để giúp chứng minh tất cả các tính năng của tf.train.Checkpoint xác định một tập dữ liệu đồ chơi và bước tối ưu hóa:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Tạo các đối tượng điểm kiểm tra

Để tạo một trạm kiểm soát theo cách thủ công, bạn sẽ cần một đối tượng tf.train.Checkpoint . Nơi các đối tượng bạn muốn kiểm tra được đặt làm thuộc tính trên đối tượng.

Một tf.train.CheckpointManager cũng có thể hữu ích để quản lý nhiều điểm kiểm tra.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Huấn luyện và kiểm tra mô hình

Vòng lặp đào tạo sau đây tạo ra một thể hiện của mô hình và của trình tối ưu hóa, sau đó tập hợp chúng thành một đối tượng tf.train.Checkpoint . Nó gọi bước huấn luyện trong một vòng lặp trên mỗi lô dữ liệu và định kỳ ghi các điểm kiểm tra vào đĩa.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.15
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.56
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.00
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.52
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.25

Khôi phục và tiếp tục đào tạo

Sau lần đầu tiên, bạn có thể vượt qua một mô hình và người quản lý mới, nhưng đào tạo tiếp nhận chính xác nơi bạn đã dừng lại:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.76
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.65
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.51
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

Đối tượng tf.train.CheckpointManager xóa các điểm kiểm tra cũ. Ở trên, nó được định cấu hình để chỉ giữ ba điểm kiểm tra gần đây nhất.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Các đường dẫn này, ví dụ: './tf_ckpts/ckpt-10' , không phải là tệp trên đĩa. Thay vào đó, chúng là tiền tố cho một tệp index và một hoặc nhiều tệp dữ liệu chứa các giá trị biến. Các tiền tố này được nhóm lại với nhau trong một tệp checkpoint ( './tf_ckpts/checkpoint' ) nơi CheckpointManager lưu trạng thái của nó.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Cơ khí tải

TensorFlow đối sánh các biến với các giá trị đã kiểm tra bằng cách duyệt qua đồ thị có hướng với các cạnh được đặt tên, bắt đầu từ đối tượng đang được tải. Tên cạnh thường đến từ tên thuộc tính trong các đối tượng, ví dụ như "l1" trong self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint sử dụng tên đối số từ khóa của nó, như trong "step" trong tf.train.Checkpoint(step=...) .

Biểu đồ phụ thuộc từ ví dụ trên trông như sau:

Hình ảnh hóa biểu đồ phụ thuộc cho vòng lặp đào tạo ví dụ

Với trình tối ưu hóa có màu đỏ, các biến thông thường có màu xanh lam và các biến vị trí của trình tối ưu hóa có màu cam. Các nút khác, ví dụ đại diện cho điểm tf.train.Checkpoint , có màu đen.

Biến vị trí là một phần của trạng thái của trình tối ưu hóa, nhưng được tạo cho một biến cụ thể. Ví dụ: các cạnh 'm' ở trên tương ứng với động lượng, mà trình tối ưu hóa Adam theo dõi cho mỗi biến. Các biến vị trí chỉ được lưu trong một điểm kiểm tra nếu cả biến và trình tối ưu hóa đều được lưu, do đó các cạnh gạch ngang.

Gọi restore() trên đối tượng tf.train.Checkpoint xếp hàng các khôi phục được yêu cầu, khôi phục các giá trị biến ngay khi có một đường dẫn phù hợp từ đối tượng Checkpoint . Ví dụ, chúng ta có thể tải chỉ độ lệch từ mô hình mà chúng ta đã xác định ở trên bằng cách tạo lại một đường dẫn đến nó thông qua mạng và lớp.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.5548685 2.8931093 2.3509905 3.5525272 4.017799 ]

Đồ thị phụ thuộc cho các đối tượng mới này là một đồ thị con nhỏ hơn nhiều của trạm kiểm soát lớn hơn mà chúng tôi đã viết ở trên. Nó chỉ bao gồm bias và một bộ đếm lưu mà tf.train.Checkpoint sử dụng để đánh số các điểm kiểm tra.

Hình dung một đồ thị con cho biến thiên vị

restore() trả về một đối tượng trạng thái, có các xác nhận tùy chọn. Tất cả các đối tượng mà chúng tôi đã tạo trong Checkpoint mới của chúng tôi đã được khôi phục, vì vậy status.assert_existing_objects_matched() vượt qua.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fea0c3c3860>

Có nhiều đối tượng trong trạm kiểm soát chưa khớp, bao gồm hạt nhân của lớp và các biến của trình tối ưu hóa. status.assert_consumed() chỉ vượt qua nếu điểm kiểm tra và chương trình khớp chính xác và sẽ đưa ra một ngoại lệ ở đây.

Phục hồi bị trì hoãn

Layer đối tượng Layer trong TensorFlow có thể trì hoãn việc tạo các biến cho lần gọi đầu tiên của chúng, khi các hình dạng đầu vào có sẵn. Ví dụ: hình dạng của nhân của lớp Dense phụ thuộc vào cả hình dạng đầu vào và đầu ra của lớp, và do đó hình dạng đầu ra được yêu cầu như một đối số của hàm tạo không đủ thông tin để tự tạo biến. Vì việc gọi một Layer cũng đọc giá trị của biến, nên việc khôi phục phải xảy ra giữa lần tạo biến và lần sử dụng đầu tiên.

Để hỗ trợ thành ngữ này, các hàng đợi tf.train.Checkpoint sẽ khôi phục các hàng đợi chưa có biến phù hợp.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.453001  4.6668463 4.9372597 4.90143   4.9549575]]

Kiểm tra các trạm kiểm soát theo cách thủ công

tf.train.list_variables liệt kê các khóa điểm kiểm tra và hình dạng của các biến trong một điểm kiểm tra. Các khóa điểm kiểm tra là các đường dẫn trong biểu đồ được hiển thị ở trên.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

Theo dõi danh sách và từ điển

Như với các phép gán thuộc tính trực tiếp như self.l1 = tf.keras.layers.Dense(5) , việc gán danh sách và từ điển cho các thuộc tính sẽ theo dõi nội dung của chúng.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Bạn có thể nhận thấy các đối tượng bao bọc cho danh sách và từ điển. Các trình bao bọc này là các phiên bản có thể kiểm tra của cấu trúc dữ liệu cơ bản. Cũng giống như tải dựa trên thuộc tính, các trình bao bọc này khôi phục giá trị của một biến ngay sau khi được thêm vào vùng chứa.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Theo dõi tương tự được tự động áp dụng cho các lớp con của tf.keras.Model và có thể được sử dụng chẳng hạn để theo dõi danh sách các lớp.

Lưu các điểm kiểm tra dựa trên đối tượng với Công cụ ước tính

Xem hướng dẫn Công cụ ước tính .

Theo mặc định, công cụ ước tính lưu các điểm kiểm tra với tên biến thay vì biểu đồ đối tượng được mô tả trong các phần trước. tf.train.Checkpoint sẽ chấp nhận các điểm kiểm tra dựa trên tên, nhưng tên biến có thể thay đổi khi di chuyển các phần của mô hình bên ngoài model_fn của Công cụ ước tính. Việc lưu các điểm kiểm tra dựa trên đối tượng giúp dễ dàng đào tạo mô hình bên trong Công cụ ước tính và sau đó sử dụng nó bên ngoài một mô hình.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.388644, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 34.98601.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fea648fbf60>

tf.train.Checkpoint sau đó có thể tải các điểm kiểm tra của Công cụ ước tính từ model_dir của nó.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Tóm lược

Các đối tượng TensorFlow cung cấp một cơ chế tự động dễ dàng để lưu và khôi phục các giá trị của các biến mà chúng sử dụng.