มีคำถาม? เชื่อมต่อกับชุมชนที่ฟอรัม TensorFlow เยี่ยมชมฟอรัม

การนำเข้าโมเดลที่ใช้ TensorFlow GraphDef ไปยัง TensorFlow.js

โมเดลที่ใช้ TensorFlow GraphDef (โดยทั่วไปสร้างผ่าน Python API) อาจถูกบันทึกในรูปแบบใดรูปแบบหนึ่งต่อไปนี้:

  1. TensorFlow ที่ บันทึกไว้
  2. แบบจำลองแช่แข็ง
  3. โมดูล Tensorflow Hub

รูปแบบทั้งหมดข้างต้นสามารถแปลงได้โดยตัว แปลง TensorFlow.js เป็นรูปแบบที่สามารถโหลดลงใน TensorFlow.js ได้โดยตรงเพื่อการอนุมาน

(หมายเหตุ: TensorFlow ได้เลิกใช้งานรูปแบบเซสชันบันเดิลแล้วโปรดย้ายโมเดลของคุณเป็นรูปแบบ SavedModel)

ข้อกำหนด

ขั้นตอนการแปลงต้องใช้สภาพแวดล้อม Python คุณอาจต้องการที่จะให้แยกโดยใช้ pipenv หรือ virtualenv ในการติดตั้งตัวแปลงให้รันคำสั่งต่อไปนี้:

 pip install tensorflowjs

การนำเข้าแบบจำลอง TensorFlow ไปยัง TensorFlow.js เป็นกระบวนการสองขั้นตอน ขั้นแรกให้แปลงโมเดลที่มีอยู่เป็นรูปแบบเว็บ TensorFlow.js จากนั้นโหลดลงใน TensorFlow.js

ขั้นตอนที่ 1. แปลงโมเดล TensorFlow ที่มีอยู่เป็นรูปแบบเว็บ TensorFlow.js

รันสคริปต์ตัวแปลงที่จัดเตรียมโดยแพ็คเกจ pip:

การใช้งาน: ตัวอย่าง SafedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

ตัวอย่างโมเดลที่ถูกแช่แข็ง:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

ตัวอย่างโมดูล Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
อาร์กิวเมนต์ตำแหน่ง คำอธิบาย
input_path พา ธ แบบเต็มของไดเร็กทอรีโมเดลที่บันทึกไว้ไดเร็กทอรีบันเดิลเซสชันไฟล์โมเดลที่ถูกตรึงหรือแฮนเดิลของโมดูล TensorFlow Hub หรือพา ธ
output_path เส้นทางสำหรับอาร์ติแฟกต์เอาต์พุตทั้งหมด
ตัวเลือก คำอธิบาย
--input_format รูปแบบของโมเดลอินพุตใช้ tf_saved_model สำหรับ SavedModel, tf_frozen_model สำหรับโมเดลที่แช่แข็ง, tf_session_bundle สำหรับเซสชันบันเดิล, tf_hub สำหรับโมดูล TensorFlow Hub และ Keras สำหรับ Keras HDF5
--output_node_names ชื่อของโหนดเอาต์พุตคั่นด้วยเครื่องหมายจุลภาค
--saved_model_tags ใช้ได้เฉพาะกับการแปลง SavedModel แท็กของ MetaGraphDef ที่จะโหลดในรูปแบบที่คั่นด้วยลูกน้ำ ค่าเริ่มต้นที่จะ serve
--signature_name ใช้ได้เฉพาะกับการแปลงโมดูล TensorFlow Hub ลายเซ็นในการโหลด ค่า default ดู https://www.tensorflow.org/hub/common_signatures/

ใช้คำสั่งต่อไปนี้เพื่อรับข้อความช่วยเหลือโดยละเอียด:

tensorflowjs_converter --help

แปลงไฟล์ที่สร้างขึ้น

สคริปต์การแปลงด้านบนสร้างไฟล์สองประเภท:

  • model.json (กราฟกระแสข้อมูลและรายการน้ำหนัก)
  • group1-shard\*of\* (คอลเลกชันของไฟล์ไบนารีน้ำหนัก)

ตัวอย่างเช่นนี่คือผลลัพธ์จากการแปลง MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

ขั้นตอนที่ 2: โหลดและรันในเบราว์เซอร์

  1. ติดตั้งแพ็คเกจ tfjs-converter npm

yarn add @tensorflow/tfjs หรือ npm install @tensorflow/tfjs

  1. สร้าง อินสแตนซ์คลาส FrozenModel และเรียกใช้การอนุมาน
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

ดูการ สาธิต MobileNet ของเรา

loadGraphModel API ยอมรับพารามิเตอร์ LoadOptions เพิ่มเติมซึ่งสามารถใช้เพื่อส่งหนังสือรับรองหรือส่วนหัวที่กำหนดเองพร้อมกับคำขอ โปรดดูเอกสาร loadGraphModel () สำหรับรายละเอียดเพิ่มเติม

รองรับการดำเนินงาน

ปัจจุบัน TensorFlow.js รองรับชุดปฏิบัติการ TensorFlow ที่ จำกัด หากโมเดลของคุณใช้ op ที่ไม่รองรับสคริปต์ tensorflowjs_converter จะล้มเหลวและพิมพ์รายการตัวเลือกที่ไม่รองรับในโมเดลของคุณ โปรดยื่น ปัญหา สำหรับแต่ละหน่วยงานเพื่อแจ้งให้เราทราบว่าคุณต้องการการสนับสนุนหน่วยงานใด

โหลดน้ำหนักเท่านั้น

หากคุณต้องการโหลดเฉพาะน้ำหนักคุณสามารถใช้ข้อมูลโค้ดต่อไปนี้

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");