Nhập Mô hình dựa trên TensorFlow GraphDef vào TensorFlow.js

Sử dụng bộ sưu tập để sắp xếp ngăn nắp các trang Lưu và phân loại nội dung dựa trên lựa chọn ưu tiên của bạn.

Các mô hình dựa trên TensorFlow GraphDef (thường được tạo thông qua API Python) có thể được lưu ở một trong các định dạng sau:

  1. TensorFlow SavedModel
  2. Mô hình đông lạnh
  3. Mô-đun Tensorflow Hub

Tất cả các định dạng trên có thể được chuyển đổi bởi trình chuyển đổi TensorFlow.js thành một định dạng có thể tải trực tiếp vào TensorFlow.js để suy luận.

(Lưu ý: TensorFlow không dùng định dạng gói phiên nữa, vui lòng chuyển mô hình của bạn sang định dạng SavedModel.)

Yêu cầu

Thủ tục chuyển đổi yêu cầu môi trường Python; bạn có thể muốn giữ một cái riêng biệt bằng cách sử dụng pipenv hoặc virtualenv . Để cài đặt bộ chuyển đổi, hãy chạy lệnh sau:

 pip install tensorflowjs

Nhập một mô hình TensorFlow vào TensorFlow.js là một quá trình gồm hai bước. Đầu tiên, chuyển đổi mô hình hiện có sang định dạng web TensorFlow.js, rồi tải nó vào TensorFlow.js.

Bước 1. Chuyển đổi mô hình TensorFlow hiện có sang định dạng web TensorFlow.js

Chạy tập lệnh chuyển đổi được cung cấp bởi gói pip:

Cách sử dụng: Ví dụ SavedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Ví dụ về mô hình đông lạnh:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Ví dụ về mô-đun Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Lập luận vị trí Sự mô tả
input_path Đường dẫn đầy đủ của thư mục mô hình đã lưu, thư mục gói phiên, tệp mô hình cố định hoặc đường dẫn hoặc tay cầm mô-đun TensorFlow Hub.
output_path Đường dẫn cho tất cả các tạo tác đầu ra.
Tùy chọn Sự mô tả
--input_format Định dạng của mô hình đầu vào, sử dụng tf_saved_model cho SavedModel, tf_frozen_model cho mô hình cố định, tf_session_bundle cho gói phiên, tf_hub cho mô-đun TensorFlow Hub và keras cho Keras HDF5.
--output_node_names Tên của các nút đầu ra, được phân tách bằng dấu phẩy.
--saved_model_tags Chỉ áp dụng cho chuyển đổi SavedModel, Thẻ của MetaGraphDef để tải, ở định dạng được phân tách bằng dấu phẩy. Mặc định để phân serve .
--signature_name Chỉ áp dụng cho chuyển đổi mô-đun TensorFlow Hub, chữ ký để tải. Mặc định là default . Xem https://www.tensorflow.org/hub/common_signatures/

Sử dụng lệnh sau để nhận thông báo trợ giúp chi tiết:

tensorflowjs_converter --help

Chuyển đổi các tệp được tạo

Tập lệnh chuyển đổi ở trên tạo ra hai loại tệp:

  • model.json (biểu đồ luồng dữ liệu và tệp kê khai trọng lượng)
  • group1-shard\*of\* (tập hợp các tệp trọng số nhị phân)

Ví dụ: đây là kết quả từ việc chuyển đổi MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Bước 2: Đang tải và chạy trong trình duyệt

  1. Cài đặt gói tfjs-converter npm

yarn add @tensorflow/tfjs hoặc npm npm install @tensorflow/tfjs

  1. Khởi tạo lớp FrozenModel và chạy suy luận.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Kiểm tra bản demo MobileNet của chúng tôi.

API loadGraphModel chấp nhận một tham số LoadOptions bổ sung, có thể được sử dụng để gửi thông tin xác thực hoặc tiêu đề tùy chỉnh cùng với yêu cầu. Vui lòng xem tài liệu loadGraphModel () để biết thêm chi tiết.

Các hoạt động được hỗ trợ

Hiện tại TensorFlow.js hỗ trợ một số hoạt động TensorFlow giới hạn. Nếu mô hình của bạn sử dụng op không được hỗ trợ, tập lệnh tensorflowjs_converter sẽ không thành công và in ra danh sách các hoạt động không được hỗ trợ trong mô hình của bạn. Vui lòng gửi vấn đề cho mỗi hoạt động để cho chúng tôi biết bạn cần hỗ trợ cho hoạt động nào.

Chỉ tải trọng lượng

Nếu bạn chỉ muốn tải trọng số, bạn có thể sử dụng đoạn mã sau.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");