Zapisz datę! Google I / O powraca w dniach 18-20 maja Zarejestruj się teraz
Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Importowanie modeli opartych na TensorFlow GraphDef do TensorFlow.js

Modele oparte na TensorFlow GraphDef (zazwyczaj tworzone za pośrednictwem interfejsu API języka Python) można zapisać w jednym z następujących formatów:

  1. TensorFlow SavedModel
  2. Model zamrożony
  3. Moduł koncentratora Tensorflow

Wszystkie powyższe formaty można przekonwertować za pomocą konwertera TensorFlow.js do formatu, który można załadować bezpośrednio do TensorFlow.js w celu wnioskowania.

(Uwaga: TensorFlow wycofał format pakietu sesji, przenieś swoje modele do formatu SavedModel).

Wymagania

Procedura konwersji wymaga środowiska Python; możesz zachować izolację za pomocą pipenv lub virtualenv . Aby zainstalować konwerter, uruchom następujące polecenie:

 pip install tensorflowjs

Importowanie modelu TensorFlow do TensorFlow.js to proces dwuetapowy. Najpierw przekonwertuj istniejący model na format sieciowy TensorFlow.js, a następnie załaduj go do TensorFlow.js.

Krok 1. Przekonwertuj istniejący model TensorFlow na format sieciowy TensorFlow.js.

Uruchom skrypt konwertera dostarczony przez pakiet pip:

Użycie: Przykład SavedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Przykład zamrożonego modelu:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Przykład modułu Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumenty pozycyjne Opis
input_path Pełna ścieżka do zapisanego katalogu modelu, katalogu pakietu sesji, zamrożonego pliku modelu lub uchwytu lub ścieżki modułu TensorFlow Hub.
output_path Ścieżka do wszystkich artefaktów wyjściowych.
Opcje Opis
--input_format Format modelu wejściowego, użyj tf_saved_model dla SavedModel, tf_frozen_model dla zamrożonego modelu, tf_session_bundle dla pakietu sesji, tf_hub dla modułu TensorFlow Hub i keras dla Keras HDF5.
--output_node_names Nazwy węzłów wyjściowych oddzielone przecinkami.
--saved_model_tags Ma zastosowanie tylko do konwersji SavedModel, tagów MetaGraphDef do załadowania, w formacie oddzielonym przecinkami. Domyślnie serve .
--signature_name Ma zastosowanie tylko do konwersji modułu TensorFlow Hub, podpis do załadowania. Domyślnie default . Zobacz https://www.tensorflow.org/hub/common_signatures/

Użyj następującego polecenia, aby uzyskać szczegółową wiadomość pomocy:

tensorflowjs_converter --help

Pliki wygenerowane przez konwerter

Powyższy skrypt konwersji tworzy dwa typy plików:

  • model.json (wykres przepływu danych i manifest wagi)
  • group1-shard\*of\* (kolekcja binarnych plików wagi)

Na przykład, oto wynik konwersji MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Krok 2: Ładowanie i uruchamianie w przeglądarce

  1. Zainstaluj pakiet tfjs-converter npm

yarn add @tensorflow/tfjs lub npm install @tensorflow/tfjs

  1. Utwórz wystąpienie klasy FrozenModel i uruchom wnioskowanie.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Sprawdź nasze demo MobileNet .

loadGraphModel API loadGraphModel akceptuje dodatkowy parametr LoadOptions , którego można użyć do wysłania poświadczeń lub niestandardowych nagłówków wraz z żądaniem. Więcej informacji można znaleźć w dokumentacji loadGraphModel () .

Obsługiwane operacje

Obecnie TensorFlow.js obsługuje ograniczony zestaw operacji TensorFlow. Jeśli model korzysta z nieobsługiwanych tensorflowjs_converter skrypt tensorflowjs_converter nie powiedzie się i wydrukuje listę nieobsługiwanych operacji w modelu. Zgłoś problem dla każdej operacji, aby poinformować nas, w których operacjach potrzebujesz wsparcia.

Ładowanie tylko ciężarków

Jeśli wolisz załadować tylko wagi, możesz użyć następującego fragmentu kodu.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");