Модели на основе TensorFlow GraphDef (обычно создаваемые через Python API) могут быть сохранены в одном из следующих форматов:
Все вышеперечисленные форматы могут быть преобразованы конвертером TensorFlow.js в формат, который можно загрузить непосредственно в TensorFlow.js для вывода.
(Примечание: TensorFlow устарел формат пакета сеанса, пожалуйста, перенесите свои модели в формат SavedModel.)
Требования
Для процедуры преобразования требуется среда Python; вы можете захотеть сохранить изолированный, используя pipenv или virtualenv . Чтобы установить конвертер, выполните следующую команду:
pip install tensorflowjs
Импорт модели TensorFlow в TensorFlow.js - это двухэтапный процесс. Сначала преобразуйте существующую модель в веб-формат TensorFlow.js, а затем загрузите ее в TensorFlow.js.
Шаг 1. Преобразование существующей модели TensorFlow в веб-формат TensorFlow.js
Запустите сценарий конвертера, предоставленный пакетом pip:
Использование: Пример SavedModel:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Пример замороженной модели:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Пример модуля Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Позиционные аргументы | Описание |
---|---|
input_path | Полный путь к каталогу сохраненной модели, каталог пакета сеанса, файл замороженной модели, дескриптор или путь модуля TensorFlow Hub. |
output_path | Путь для всех артефактов вывода. |
Параметры | Описание |
---|---|
--input_format | Формат входной модели: tf_saved_model для SavedModel, tf_frozen_model для замороженной модели, tf_session_bundle для пакета сеанса, tf_hub для модуля TensorFlow Hub и keras для Keras HDF5. |
--output_node_names | Имена выходных узлов, разделенные запятыми. |
--saved_model_tags | Применимо только к преобразованию SavedModel, тегам MetaGraphDef для загрузки в формате, разделенном запятыми. По умолчанию serve . |
--signature_name | Применимо только для преобразования модуля TensorFlow Hub, подпись для загрузки. По умолчанию по default . См. Https://www.tensorflow.org/hub/common_signatures/ |
Используйте следующую команду, чтобы получить подробное справочное сообщение:
tensorflowjs_converter --help
Конвертер сгенерированных файлов
Приведенный выше сценарий преобразования создает файлы двух типов:
-
model.json
(график потока данных и манифест веса) -
group1-shard\*of\*
(набор файлов двоичных весов)
Например, вот результат преобразования MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Шаг 2. Загрузка и запуск в браузере
- Установите пакет npm tfjs-converter
yarn add @tensorflow/tfjs
или npm install @tensorflow/tfjs
- Создайте экземпляр класса FrozenModel и выполните вывод.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Посмотрите нашу демонстрацию MobileNet .
API loadGraphModel
принимает дополнительный параметр LoadOptions
, который можно использовать для отправки учетных данных или пользовательских заголовков вместе с запросом. Пожалуйста, смотрите документацию loadGraphModel () для получения более подробной информации.
Поддерживаемые операции
В настоящее время TensorFlow.js поддерживает ограниченный набор операций TensorFlow. Если ваша модель использует неподдерживаемую tensorflowjs_converter
сценарий tensorflowjs_converter
завершится ошибкой и распечатает список неподдерживаемых операций в вашей модели. Пожалуйста, отправьте вопрос для каждой операции, чтобы сообщить нам, для каких операций вам нужна поддержка.
Загрузка только веса
Если вы предпочитаете загружать только веса, вы можете использовать следующий фрагмент кода.
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");