Modele oparte na TensorFlow GraphDef (zazwyczaj tworzone za pośrednictwem interfejsu API języka Python) można zapisać w jednym z następujących formatów:
- TensorFlow SavedModel
- Model zamrożony
- Moduł koncentratora Tensorflow
Wszystkie powyższe formaty można przekonwertować za pomocą konwertera TensorFlow.js do formatu, który można załadować bezpośrednio do TensorFlow.js w celu wnioskowania.
(Uwaga: TensorFlow wycofał format pakietu sesji, przenieś swoje modele do formatu SavedModel).
Wymagania
Procedura konwersji wymaga środowiska Python; możesz zachować izolację za pomocą pipenv lub virtualenv . Aby zainstalować konwerter, uruchom następujące polecenie:
pip install tensorflowjs
Importowanie modelu TensorFlow do TensorFlow.js to proces dwuetapowy. Najpierw przekonwertuj istniejący model na format sieciowy TensorFlow.js, a następnie załaduj go do TensorFlow.js.
Krok 1. Przekonwertuj istniejący model TensorFlow na format sieciowy TensorFlow.js.
Uruchom skrypt konwertera dostarczony przez pakiet pip:
Użycie: Przykład SavedModel:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Przykład zamrożonego modelu:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Przykład modułu Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Argumenty pozycyjne | Opis |
---|---|
input_path | Pełna ścieżka do zapisanego katalogu modelu, katalogu pakietu sesji, zamrożonego pliku modelu lub uchwytu lub ścieżki modułu TensorFlow Hub. |
output_path | Ścieżka do wszystkich artefaktów wyjściowych. |
Opcje | Opis |
---|---|
--input_format | Format modelu wejściowego, użyj tf_saved_model dla SavedModel, tf_frozen_model dla zamrożonego modelu, tf_session_bundle dla pakietu sesji, tf_hub dla modułu TensorFlow Hub i keras dla Keras HDF5. |
--output_node_names | Nazwy węzłów wyjściowych oddzielone przecinkami. |
--saved_model_tags | Ma zastosowanie tylko do konwersji SavedModel, tagów MetaGraphDef do załadowania, w formacie oddzielonym przecinkami. Domyślnie serve . |
--signature_name | Ma zastosowanie tylko do konwersji modułu TensorFlow Hub, podpis do załadowania. Domyślnie default . Zobacz https://www.tensorflow.org/hub/common_signatures/ |
Użyj następującego polecenia, aby uzyskać szczegółową wiadomość pomocy:
tensorflowjs_converter --help
Pliki wygenerowane przez konwerter
Powyższy skrypt konwersji tworzy dwa typy plików:
-
model.json
(wykres przepływu danych i manifest wagi) -
group1-shard\*of\*
(kolekcja binarnych plików wagi)
Na przykład, oto wynik konwersji MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Krok 2: Ładowanie i uruchamianie w przeglądarce
- Zainstaluj pakiet tfjs-converter npm
yarn add @tensorflow/tfjs
lub npm install @tensorflow/tfjs
- Utwórz wystąpienie klasy FrozenModel i uruchom wnioskowanie.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Sprawdź nasze demo MobileNet .
loadGraphModel
API loadGraphModel
akceptuje dodatkowy parametr LoadOptions
, którego można użyć do wysłania poświadczeń lub niestandardowych nagłówków wraz z żądaniem. Więcej informacji można znaleźć w dokumentacji loadGraphModel () .
Obsługiwane operacje
Obecnie TensorFlow.js obsługuje ograniczony zestaw operacji TensorFlow. Jeśli model korzysta z nieobsługiwanych tensorflowjs_converter
skrypt tensorflowjs_converter
nie powiedzie się i wydrukuje listę nieobsługiwanych operacji w modelu. Zgłoś problem dla każdej operacji, aby poinformować nas, w których operacjach potrzebujesz wsparcia.
Ładowanie tylko ciężarków
Jeśli wolisz załadować tylko wagi, możesz użyć następującego fragmentu kodu.
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");