Các mô hình dựa trên TensorFlow GraphDef (thường được tạo thông qua API Python) có thể được lưu ở một trong các định dạng sau:
- Mô hình đã lưu TensorFlow
- người mẫu đông lạnh
- Mô-đun Tensorflow Hub
Tất cả các định dạng trên có thể được chuyển đổi bằng trình chuyển đổi TensorFlow.js thành định dạng có thể được tải trực tiếp vào TensorFlow.js để suy luận.
(Lưu ý: TensorFlow đã ngừng sử dụng định dạng gói phiên. Vui lòng di chuyển các mô hình của bạn sang định dạng SavingModel.)
Yêu cầu
Quy trình chuyển đổi yêu cầu môi trường Python; bạn có thể muốn giữ một cái biệt lập bằng cách sử dụng pipenv hoặc virtualenv .
Để cài đặt bộ chuyển đổi, hãy chạy lệnh sau:
pip install tensorflowjs
Nhập mô hình TensorFlow vào TensorFlow.js là quy trình gồm hai bước. Đầu tiên, chuyển đổi một mô hình hiện có sang định dạng web TensorFlow.js, sau đó tải mô hình đó vào TensorFlow.js.
Bước 1. Chuyển đổi mô hình TensorFlow hiện có sang định dạng web TensorFlow.js
Chạy tập lệnh chuyển đổi do gói pip cung cấp:
Ví dụ về SavingModel:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Ví dụ về mô hình đông lạnh:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Ví dụ về mô-đun Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Đối số vị trí | Sự miêu tả |
---|---|
input_path | Đường dẫn đầy đủ của thư mục mô hình đã lưu, thư mục gói phiên, tệp mô hình bị đóng băng hoặc đường dẫn hoặc xử lý mô-đun TensorFlow Hub. |
output_path | Đường dẫn cho tất cả các tạo phẩm đầu ra. |
Tùy chọn | Sự miêu tả |
---|---|
--input_format | Định dạng của mô hình đầu vào. Sử dụng tf_saved_model cho SavingModel, tf_frozen_model cho mô hình cố định, tf_session_bundle cho gói phiên, tf_hub cho mô-đun TensorFlow Hub và máy ảnh cho Keras HDF5. |
--output_node_names | Tên của các nút đầu ra, được phân tách bằng dấu phẩy. |
--saved_model_tags | Chỉ áp dụng cho chuyển đổi SavingModel. Các thẻ của MetaGraphDef cần tải, ở định dạng được phân tách bằng dấu phẩy. Mặc định để serve . |
--signature_name | Chỉ áp dụng cho chuyển đổi mô-đun TensorFlow Hub, chữ ký thành tải. Mặc định để default . Xem https://www.tensorflow.org/hub/common_signatures/ |
Sử dụng lệnh sau để nhận thông báo trợ giúp chi tiết:
tensorflowjs_converter --help
Chuyển đổi tập tin được tạo
Tập lệnh chuyển đổi ở trên tạo ra hai loại tệp:
-
model.json
: Biểu đồ luồng dữ liệu và bảng kê khai trọng lượng -
group1-shard\*of\*
: Tập hợp các tệp trọng số nhị phân
Ví dụ: đây là kết quả từ việc chuyển đổi MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Bước 2: Load và chạy trên trình duyệt
- Cài đặt gói tfjs-converter npm:
yarn add @tensorflow/tfjs
hoặc npm install @tensorflow/tfjs
- Khởi tạo lớp FrozenModel và chạy suy luận.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Kiểm tra bản trình diễn MobileNet .
API loadGraphModel
chấp nhận tham số LoadOptions
bổ sung, tham số này có thể được sử dụng để gửi thông tin đăng nhập hoặc tiêu đề tùy chỉnh cùng với yêu cầu. Để biết chi tiết, hãy xem tài liệu loadGraphModel() .
hoạt động được hỗ trợ
Hiện tại TensorFlow.js hỗ trợ một nhóm hoạt động TensorFlow giới hạn. Nếu mô hình của bạn sử dụng op không được hỗ trợ, tập lệnh tensorflowjs_converter
sẽ không thành công và in ra danh sách các op không được hỗ trợ trong mô hình của bạn. Vui lòng gửi một vấn đề cho mỗi op để cho chúng tôi biết op nào bạn cần hỗ trợ.
Chỉ tải trọng lượng
Nếu bạn chỉ muốn tải các trọng số, bạn có thể sử dụng đoạn mã sau:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...