このページは Cloud Translation API によって翻訳されました。
Switch to English

TensorFlow GraphDefベースのモデルをTensorFlow.jsにインポートする

TensorFlow GraphDefベースのモデル(通常はPython APIを介して作成されます)は、次のいずれかの形式で保存できます。

  1. TensorFlow SavedModel
  2. 冷凍モデル
  3. セッションバンドル
  4. Tensorflow Hubモジュール

上記のすべての形式は、 TensorFlow.jsコンバーターによって、推論のためにTensorFlow.jsに直接ロードできる形式に変換できます。

(注:TensorFlowはセッションバンドル形式を非推奨にしています。モデルをSavedModel形式に移行してください。)

必要条件

変換手順にはPython環境が必要です。 pipenvまたはvirtualenvを使用して、隔離されたものを保持することもできます。コンバーターをインストールするには、次のコマンドを実行します。

 pip install tensorflowjs

TensorFlow.jsへのTensorFlowモデルのインポートは、2ステップのプロセスです。まず、既存のモデルをTensorFlow.jsウェブ形式に変換し、それをTensorFlow.jsに読み込みます。

ステップ1.既存のTensorFlowモデルをTensorFlow.jsウェブ形式に変換する

pipパッケージで提供されているコンバータースクリプトを実行します。

使用法:SavedModelの例:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

凍結モデルの例:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Tensorflow Hubモジュールの例:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
位置引数説明
input_path 保存されたモデルディレクトリ、セッションバンドルディレクトリ、フリーズされたモデルファイル、またはTensorFlow Hubモジュールのハンドルまたはパスの完全パス。
output_path すべての出力アーティファクトのパス。
オプション説明
--input_format 入力モデルの形式。SavedModelにはtf_saved_model、凍結モデルにはtf_frozen_model、セッションバンドルにはtf_session_bundle、TensorFlow Hubモジュールにはtf_hub、Keras HDF5にはkerasを使用します。
--output_node_names コンマで区切られた出力ノードの名前。
--saved_model_tags SavedModel変換、ロードするMetaGraphDefのタグにのみ適用可能で、コンマ区切り形式です。デフォルトはserve
--signature_name TensorFlow Hubモジュール変換にのみ適用され、署名を読み込みます。デフォルトはdefaulthttps://www.tensorflow.org/hub/common_signatures/を参照してください

次のコマンドを使用して、詳細なヘルプメッセージを取得します。

tensorflowjs_converter --help

コンバーターが生成したファイル

上記の変換スクリプトは、2種類のファイルを生成します。

  • model.json (データフローグラフとウェイトマニフェスト)
  • group1-shard\*of\* (バイナリウェイトファイルのコレクション)

たとえば、MobileNet v2の変換結果は次のとおりです。

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

ステップ2:ブラウザーでのロードと実行

  1. tfjs-converter npmパッケージをインストールする

yarn add @tensorflow/tfjsまたはnpm install @tensorflow/tfjs

  1. FrozenModelクラスをインスタンス化し、推論を実行します。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

MobileNetデモをご覧ください。

loadGraphModel APIは追加のLoadOptionsパラメータを受け入れます。このパラメータを使用して、リクエストとともに認証情報またはカスタムヘッダーを送信できます。詳細については、 loadGraphModel()のドキュメントをご覧ください。

サポートされる操作

現在、TensorFlow.jsはTensorFlow演算の限定されたセットをサポートしています。モデルでサポートされていないopが使用されている場合、 tensorflowjs_converterスクリプトは失敗し、モデルでサポートされていないopのリストが出力されます。各オペレーションについて問題を報告し、サポートが必要なオペレーションをお知らせください。

ウェイトのみをロードする

ウェイトのみをロードする場合は、次のコードスニペットを使用できます。

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");