Các mô hình dựa trên TensorFlow GraphDef (thường được tạo thông qua API Python) có thể được lưu ở một trong các định dạng sau:
- TensorFlow SavedModel
- Mô hình đông lạnh
- Mô-đun Tensorflow Hub
Tất cả các định dạng trên có thể được chuyển đổi bởi trình chuyển đổi TensorFlow.js thành một định dạng có thể tải trực tiếp vào TensorFlow.js để suy luận.
(Lưu ý: TensorFlow không dùng định dạng gói phiên nữa, vui lòng chuyển mô hình của bạn sang định dạng SavedModel.)
Yêu cầu
Thủ tục chuyển đổi yêu cầu môi trường Python; bạn có thể muốn giữ một cái riêng biệt bằng cách sử dụng pipenv hoặc virtualenv . Để cài đặt bộ chuyển đổi, hãy chạy lệnh sau:
pip install tensorflowjs
Nhập một mô hình TensorFlow vào TensorFlow.js là một quá trình gồm hai bước. Đầu tiên, chuyển đổi mô hình hiện có sang định dạng web TensorFlow.js, rồi tải nó vào TensorFlow.js.
Bước 1. Chuyển đổi mô hình TensorFlow hiện có sang định dạng web TensorFlow.js
Chạy tập lệnh chuyển đổi được cung cấp bởi gói pip:
Cách sử dụng: Ví dụ SavedModel:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Ví dụ về mô hình đông lạnh:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Ví dụ về mô-đun Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Lập luận vị trí | Sự mô tả |
---|---|
input_path | Đường dẫn đầy đủ của thư mục mô hình đã lưu, thư mục gói phiên, tệp mô hình cố định hoặc đường dẫn hoặc tay cầm mô-đun TensorFlow Hub. |
output_path | Đường dẫn cho tất cả các tạo tác đầu ra. |
Tùy chọn | Sự mô tả |
---|---|
--input_format | Định dạng của mô hình đầu vào, sử dụng tf_saved_model cho SavedModel, tf_frozen_model cho mô hình cố định, tf_session_bundle cho gói phiên, tf_hub cho mô-đun TensorFlow Hub và keras cho Keras HDF5. |
--output_node_names | Tên của các nút đầu ra, được phân tách bằng dấu phẩy. |
--saved_model_tags | Chỉ áp dụng cho chuyển đổi SavedModel, Thẻ của MetaGraphDef để tải, ở định dạng được phân tách bằng dấu phẩy. Mặc định để phân serve . |
--signature_name | Chỉ áp dụng cho chuyển đổi mô-đun TensorFlow Hub, chữ ký để tải. Mặc định là default . Xem https://www.tensorflow.org/hub/common_signatures/ |
Sử dụng lệnh sau để nhận thông báo trợ giúp chi tiết:
tensorflowjs_converter --help
Chuyển đổi các tệp được tạo
Tập lệnh chuyển đổi ở trên tạo ra hai loại tệp:
-
model.json
(biểu đồ luồng dữ liệu và tệp kê khai trọng lượng) -
group1-shard\*of\*
(tập hợp các tệp trọng số nhị phân)
Ví dụ: đây là kết quả từ việc chuyển đổi MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Bước 2: Đang tải và chạy trong trình duyệt
- Cài đặt gói tfjs-converter npm
yarn add @tensorflow/tfjs
hoặc npm npm install @tensorflow/tfjs
- Khởi tạo lớp FrozenModel và chạy suy luận.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Kiểm tra bản demo MobileNet của chúng tôi.
API loadGraphModel
chấp nhận một tham số LoadOptions
bổ sung, có thể được sử dụng để gửi thông tin xác thực hoặc tiêu đề tùy chỉnh cùng với yêu cầu. Vui lòng xem tài liệu loadGraphModel () để biết thêm chi tiết.
Các hoạt động được hỗ trợ
Hiện tại TensorFlow.js hỗ trợ một số hoạt động TensorFlow giới hạn. Nếu mô hình của bạn sử dụng op không được hỗ trợ, tập lệnh tensorflowjs_converter
sẽ không thành công và in ra danh sách các hoạt động không được hỗ trợ trong mô hình của bạn. Vui lòng gửi vấn đề cho mỗi hoạt động để cho chúng tôi biết bạn cần hỗ trợ cho hoạt động nào.
Chỉ tải trọng lượng
Nếu bạn chỉ muốn tải trọng số, bạn có thể sử dụng đoạn mã sau.
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");