Mengimpor Model berbasis TensorFlow GraphDef ke TensorFlow.js

Model berbasis TensorFlow GraphDef (biasanya dibuat melalui Python API) dapat disimpan dalam salah satu format berikut:

  1. TensorFlow SavedModel
  2. Model beku
  3. Modul Tensorflow Hub

Semua format diatas dapat dikonversi oleh TensorFlow.js converter ke format yang dapat dimuat langsung ke TensorFlow.js untuk inferensi.

(Catatan: TensorFlow tidak lagi menggunakan format bundel sesi, harap migrasikan model Anda ke format SavedModel.)

Persyaratan

Prosedur konversi membutuhkan lingkungan Python; Anda mungkin ingin menyimpan satu terisolasi menggunakan pipenv atau virtualenv . Untuk menginstal konverter, jalankan perintah berikut:

 pip install tensorflowjs

Mengimpor model TensorFlow ke TensorFlow.js adalah proses dua langkah. Pertama, konversi model yang ada ke format web TensorFlow.js, lalu muat ke TensorFlow.js.

Langkah 1. Konversi model TensorFlow yang ada ke format web TensorFlow.js

Jalankan skrip konverter yang disediakan oleh paket pip:

Penggunaan: Contoh Model Tersimpan:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Contoh model beku:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Contoh modul Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumen Posisi Keterangan
input_path Jalur lengkap direktori model yang disimpan, direktori bundel sesi, file model beku, atau pegangan atau jalur modul TensorFlow Hub.
output_path Jalur untuk semua artefak keluaran.
Pilihan Keterangan
--input_format Format model input, gunakan tf_saved_model untuk SavedModel, tf_frozen_model untuk model beku, tf_session_bundle untuk session bundle, tf_hub untuk modul TensorFlow Hub dan keras untuk Keras HDF5.
--output_node_names Nama-nama node keluaran, dipisahkan dengan koma.
--saved_model_tags Hanya berlaku untuk konversi SavedModel, Tag MetaGraphDef yang akan dimuat, dalam format yang dipisahkan koma. Default untuk serve .
--signature_name Hanya berlaku untuk konversi modul TensorFlow Hub, tanda tangan untuk dimuat. Defaultnya default . Lihat https://www.tensorflow.org/hub/common_signatures/

Gunakan perintah berikut untuk mendapatkan pesan bantuan terperinci:

tensorflowjs_converter --help

Konverter file yang dihasilkan

Script konversi di atas menghasilkan dua jenis file:

  • model.json (grafik dataflow dan manifest berat badan)
  • group1-shard\*of\* (kumpulan file berat biner)

Sebagai contoh, berikut adalah output dari konversi MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Langkah 2: Memuat dan menjalankan di browser

  1. Instal paket tfjs-converter npm

yarn add @tensorflow/tfjs atau npm install @tensorflow/tfjs

  1. Instantiate kelas FrozenModel dan menjalankan inferensi.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Lihat demo MobileNet .

The loadGraphModel API menerima tambahan LoadOptions parameter, yang dapat digunakan untuk mengirim kredensial atau header kustom bersama dengan permintaan. Silakan lihat loadGraphModel () dokumentasi untuk lebih jelasnya.

Operasi yang didukung

Saat ini TensorFlow.js mendukung serangkaian operasi TensorFlow terbatas. Jika model Anda menggunakan sebuah op yang tidak didukung, yang tensorflowjs_converter naskah akan gagal dan mencetak daftar ops tidak didukung dalam model Anda. Silakan mengajukan masalah untuk setiap op untuk memberitahu kami tahu mana ops Anda perlu dukungan untuk.

Memuat bobot saja

Jika Anda lebih suka memuat bobot saja, Anda dapat menggunakan cuplikan kode berikut.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");