یک مدل TensorFlow را به TensorFlow.js وارد کنید

مدل های مبتنی بر TensorFlow GraphDef (معمولاً از طریق Python API ایجاد می شوند) می توانند در یکی از فرمت های زیر ذخیره شوند:

  1. TensorFlow SavedModel
  2. مدل منجمد
  3. ماژول تنسورفلو هاب

همه فرمت های فوق را می توان توسط مبدل TensorFlow.js به قالبی تبدیل کرد که می تواند مستقیماً برای استنباط در TensorFlow.js بارگذاری شود.

(توجه: TensorFlow قالب session bundle را منسوخ کرده است. لطفاً مدل‌های خود را به قالب SavedModel منتقل کنید.)

الزامات

روش تبدیل به یک محیط پایتون نیاز دارد. ممکن است بخواهید با استفاده از pipenv یا virtualenv یک ایزوله نگه دارید.

برای نصب مبدل دستور زیر را اجرا کنید:

 pip install tensorflowjs

وارد کردن یک مدل TensorFlow به TensorFlow.js یک فرآیند دو مرحله ای است. ابتدا یک مدل موجود را به قالب وب TensorFlow.js تبدیل کنید و سپس آن را در TensorFlow.js بارگذاری کنید.

مرحله 1. یک مدل TensorFlow موجود را به قالب وب TensorFlow.js تبدیل کنید

اسکریپت مبدل ارائه شده توسط بسته pip را اجرا کنید:

مثال SavedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

نمونه مدل منجمد:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

مثال ماژول Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
استدلال های موضعی شرح
input_path مسیر کامل دایرکتوری مدل ذخیره شده، پوشه بسته جلسه، فایل مدل ثابت یا دسته یا مسیر ماژول TensorFlow Hub.
output_path مسیر برای تمام مصنوعات خروجی.
گزینه ها شرح
--input_format قالب مدل ورودی از tf_saved_model برای SavedModel، tf_frozen_model برای مدل فریز شده، tf_session_bundle برای Session bundle، tf_hub برای ماژول TensorFlow Hub و keras برای Keras HDF5 استفاده کنید.
--output_node_names نام گره های خروجی که با کاما از هم جدا شده اند.
--saved_model_tags فقط برای تبدیل SavedModel قابل استفاده است. برچسب های MetaGraphDef برای بارگیری، در قالب جدا شده با کاما. پیش‌فرض برای serve
--signature_name فقط برای تبدیل ماژول TensorFlow Hub، امضا به بارگذاری قابل اجرا است. پیش‌فرض به default . به https://www.tensorflow.org/hub/common_signatures/ مراجعه کنید

از دستور زیر برای دریافت پیام راهنمای دقیق استفاده کنید:

tensorflowjs_converter --help

تبدیل فایل های تولید شده

اسکریپت تبدیل بالا دو نوع فایل تولید می کند:

  • model.json : نمودار جریان داده و وزن آشکار می شود
  • group1-shard\*of\* : مجموعه ای از فایل های وزن باینری

به عنوان مثال، در اینجا خروجی تبدیل MobileNet v2 است:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

مرحله 2: بارگیری و اجرا در مرورگر

  1. بسته tfjs-converter npm را نصب کنید:

yarn add @tensorflow/tfjs یا npm install @tensorflow/tfjs

  1. کلاس FrozenModel را نمونه سازی کنید و استنتاج را اجرا کنید.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

نسخه ی نمایشی MobileNet را بررسی کنید.

loadGraphModel API یک پارامتر LoadOptions اضافی را می پذیرد، که می تواند برای ارسال اعتبار یا هدرهای سفارشی همراه با درخواست استفاده شود. برای جزئیات، به مستندات loadGraphModel() مراجعه کنید.

عملیات پشتیبانی شده

در حال حاضر TensorFlow.js از مجموعه محدودی از عملیات های TensorFlow پشتیبانی می کند. اگر مدل شما از یک عملیات پشتیبانی‌نشده استفاده می‌کند، اسکریپت tensorflowjs_converter خراب می‌شود و فهرستی از عملیات‌های پشتیبانی‌نشده در مدل شما را چاپ می‌کند. لطفاً برای هر عملیات مشکلی ایجاد کنید تا به ما اطلاع دهید که برای کدام عملیات نیاز به پشتیبانی دارید.

بارگیری وزنه ها فقط

اگر ترجیح می دهید فقط وزنه ها را بارگیری کنید، می توانید از قطعه کد زیر استفاده کنید:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...