Các điểm kiểm tra mô hình

Khả năng lưu và khôi phục trạng thái của mô hình là rất quan trọng đối với một số ứng dụng, chẳng hạn như trong học chuyển giao hoặc thực hiện suy luận bằng cách sử dụng các mô hình được huấn luyện trước. Lưu các tham số của mô hình (trọng số, độ lệch, v.v.) trong tệp hoặc thư mục điểm kiểm tra là một cách để thực hiện việc này.

Mô-đun này cung cấp giao diện cấp cao để tải và lưu các điểm kiểm tra định dạng TensorFlow v2 , cũng như các thành phần cấp thấp hơn ghi và đọc từ định dạng tệp này.

Tải và lưu các mô hình đơn giản

Bằng cách tuân thủ giao thức Checkpointable , nhiều mô hình đơn giản có thể được tuần tự hóa thành các điểm kiểm tra mà không cần bất kỳ mã bổ sung nào:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

và sau đó có thể đọc điểm kiểm tra đó bằng cách sử dụng:

try model.readCheckpoint(from: directory, name: "LeNet")

Việc triển khai mặc định này để tải và lưu mô hình sẽ sử dụng sơ đồ đặt tên dựa trên đường dẫn cho mỗi tensor trong mô hình dựa trên tên của các thuộc tính trong cấu trúc mô hình. Ví dụ: trọng số và độ lệch trong tích chập đầu tiên trong mô hình LeNet-5 sẽ được lưu với tên tương ứng conv1/filterconv1/bias . Khi tải, trình đọc điểm kiểm tra sẽ tìm kiếm các tensor có tên này.

Tùy chỉnh tải và lưu mô hình

Nếu bạn muốn có quyền kiểm soát tốt hơn đối với các tensor nào được lưu và tải hoặc đặt tên cho các tensor đó, giao thức Checkpointable sẽ cung cấp một số điểm tùy chỉnh.

Để bỏ qua các thuộc tính trên một số loại nhất định, bạn có thể cung cấp cách triển khai ignoredTensorPaths trên mô hình của mình để trả về một Tập hợp các chuỗi ở dạng Type.property . Ví dụ: để bỏ qua thuộc tính scale trên mỗi lớp Chú ý, bạn có thể trả về ["Attention.scale"] .

Theo mặc định, dấu gạch chéo lên được sử dụng để phân tách từng cấp độ sâu hơn trong mô hình. Bạn có thể tùy chỉnh điều này bằng cách triển khai checkpointSeparator trên mô hình của mình và cung cấp một chuỗi mới để sử dụng cho dấu phân cách này.

Cuối cùng, để có mức độ tùy chỉnh cao nhất trong việc đặt tên tensor, bạn có thể triển khai tensorNameMap và cung cấp một hàm ánh xạ từ tên chuỗi mặc định được tạo cho tensor trong mô hình tới tên chuỗi mong muốn trong điểm kiểm tra. Thông thường nhất, điều này sẽ được sử dụng để tương tác với các điểm kiểm tra được tạo bằng các khung khác, mỗi khung có quy ước đặt tên và cấu trúc mô hình riêng. Chức năng ánh xạ tùy chỉnh mang lại mức độ tùy chỉnh cao nhất cho cách đặt tên cho các tensor này.

Một số hàm trợ giúp tiêu chuẩn được cung cấp, chẳng hạn như CheckpointWriter.identityMap mặc định (chỉ sử dụng tên đường dẫn tensor được tạo tự động cho các điểm kiểm tra) hoặc hàm CheckpointWriter.lookupMap(table:) có thể xây dựng ánh xạ từ từ điển.

Để biết ví dụ về cách thực hiện ánh xạ tùy chỉnh, vui lòng xem mô hình GPT-2 , sử dụng chức năng ánh xạ để khớp với sơ đồ đặt tên chính xác được sử dụng cho các điểm kiểm tra của OpenAI.

Các thành phần CheckpointReader và CheckpointWriter

Để ghi điểm kiểm tra, tiện ích mở rộng do giao thức Checkpointable cung cấp sử dụng sự phản chiếu và đường dẫn khóa để lặp lại các thuộc tính của mô hình và tạo một từ điển ánh xạ các đường dẫn tensor chuỗi tới các giá trị Tensor. Từ điển này được cung cấp cho CheckpointWriter cơ bản, cùng với một thư mục để ghi điểm kiểm tra. CheckpointWriter đó xử lý nhiệm vụ tạo điểm kiểm tra trên đĩa từ từ điển đó.

Ngược lại của quá trình này là đọc, trong đó CheckpointReader được cung cấp vị trí của thư mục điểm kiểm tra trên đĩa. Sau đó, nó đọc từ điểm kiểm tra đó và tạo thành một từ điển ánh xạ tên của các tensor trong điểm kiểm tra với các giá trị đã lưu của chúng. Từ điển này được sử dụng để thay thế các tensor hiện tại trong mô hình bằng các tensor trong từ điển này.

Đối với cả việc tải và lưu, giao thức Checkpointable ánh xạ các đường dẫn chuỗi tới các tenxơ thành các tên tenxơ tương ứng trên đĩa bằng cách sử dụng chức năng ánh xạ được mô tả ở trên.

Nếu giao thức Checkpointable thiếu chức năng cần thiết hoặc muốn có nhiều quyền kiểm soát hơn trong quá trình tải và lưu điểm kiểm tra thì các lớp CheckpointReaderCheckpointWriter có thể được sử dụng riêng.

Định dạng điểm kiểm tra TensorFlow v2

Định dạng điểm kiểm tra TensorFlow v2, như được mô tả ngắn gọn trong tiêu đề này , là định dạng thế hệ thứ hai cho các điểm kiểm tra mô hình TensorFlow. Định dạng thế hệ thứ hai này đã được sử dụng từ cuối năm 2016 và có một số cải tiến so với định dạng điểm kiểm tra v1. TensorFlow SavingModels sử dụng các điểm kiểm tra v2 bên trong chúng để lưu các tham số mô hình.

Điểm kiểm tra TensorFlow v2 bao gồm một thư mục có cấu trúc như sau:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

trong đó tệp đầu tiên lưu trữ siêu dữ liệu cho điểm kiểm tra và các tệp còn lại là các phân đoạn nhị phân chứa các tham số được tuần tự hóa cho mô hình.

Tệp siêu dữ liệu chỉ mục chứa loại, kích thước, vị trí và tên chuỗi của tất cả các tensor được tuần tự hóa có trong các phân đoạn. Tệp chỉ mục đó là phần có cấu trúc phức tạp nhất của điểm kiểm tra và dựa trên tensorflow::table , bản thân nó dựa trên SSTable/LevelDB. Tệp chỉ mục này bao gồm một loạt các cặp khóa-giá trị, trong đó khóa là chuỗi và giá trị là bộ đệm giao thức. Các chuỗi được sắp xếp và nén tiền tố. Ví dụ: nếu mục nhập đầu tiên là conv1/weight và tiếp theo conv1/bias thì mục nhập thứ hai chỉ sử dụng phần bias .

Tệp chỉ mục tổng thể này đôi khi được nén bằng cách nén Snappy . Tệp SnappyDecompression.swift cung cấp cách triển khai Swift giải nén Snappy gốc từ một phiên bản Dữ liệu đã nén.

Siêu dữ liệu tiêu đề chỉ mục và siêu dữ liệu tensor được mã hóa dưới dạng bộ đệm giao thức và được mã hóa/giải mã trực tiếp thông qua Swift Protobuf .

Các lớp CheckpointIndexReaderCheckpointIndexWriter xử lý việc tải và lưu các tệp chỉ mục này như một phần của các lớp CheckpointReaderCheckpointWriter tổng thể. Cái sau sử dụng các tệp chỉ mục làm cơ sở để xác định nội dung cần đọc và ghi vào các phân đoạn nhị phân có cấu trúc đơn giản hơn có chứa dữ liệu tensor.