RSVP для вашего местного мероприятия TensorFlow Everywhere сегодня!
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Импорт моделей на основе TensorFlow GraphDef в TensorFlow.js

Модели на основе TensorFlow GraphDef (обычно создаваемые через Python API) могут быть сохранены в одном из следующих форматов:

  1. TensorFlow SavedModel
  2. Замороженная модель
  3. Пакет сеанса
  4. Модуль Tensorflow Hub

Все вышеперечисленные форматы могут быть преобразованы конвертером TensorFlow.js в формат, который можно загрузить непосредственно в TensorFlow.js для вывода.

(Примечание: TensorFlow устарел формат пакета сеанса, пожалуйста, перенесите свои модели в формат SavedModel.)

Требования

Для процедуры преобразования требуется среда Python; вы можете захотеть сохранить изолированный, используя pipenv или virtualenv . Чтобы установить конвертер, выполните следующую команду:

 pip install tensorflowjs

Импорт модели TensorFlow в TensorFlow.js - это двухэтапный процесс. Сначала преобразуйте существующую модель в веб-формат TensorFlow.js, а затем загрузите ее в TensorFlow.js.

Шаг 1. Преобразование существующей модели TensorFlow в веб-формат TensorFlow.js

Запустите сценарий конвертера, предоставленный пакетом pip:

Использование: Пример SavedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Пример замороженной модели:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Пример модуля Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Позиционные аргументы Описание
input_path Полный путь к каталогу сохраненной модели, каталог пакета сеанса, файл замороженной модели, дескриптор или путь модуля TensorFlow Hub.
output_path Путь для всех артефактов вывода.
Параметры Описание
--input_format Формат входной модели: tf_saved_model для SavedModel, tf_frozen_model для замороженной модели, tf_session_bundle для пакета сеанса, tf_hub для модуля TensorFlow Hub и keras для Keras HDF5.
--output_node_names Имена выходных узлов, разделенные запятыми.
--saved_model_tags Применимо только к преобразованию SavedModel, тегам MetaGraphDef для загрузки в формате, разделенном запятыми. По умолчанию serve .
--signature_name Применимо только для преобразования модуля TensorFlow Hub, подпись для загрузки. По умолчанию по default . См. Https://www.tensorflow.org/hub/common_signatures/

Используйте следующую команду, чтобы получить подробное справочное сообщение:

tensorflowjs_converter --help

Конвертер сгенерированных файлов

Приведенный выше сценарий преобразования создает файлы двух типов:

  • model.json (график потока данных и манифест веса)
  • group1-shard\*of\* (набор файлов двоичных весов)

Например, вот результат преобразования MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Шаг 2. Загрузка и запуск в браузере

  1. Установите пакет npm tfjs-converter

yarn add @tensorflow/tfjs или npm install @tensorflow/tfjs

  1. Создайте экземпляр класса FrozenModel и выполните вывод.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Посмотрите нашу демонстрацию MobileNet .

API loadGraphModel принимает дополнительный параметр LoadOptions , который можно использовать для отправки учетных данных или пользовательских заголовков вместе с запросом. Пожалуйста, смотрите документацию loadGraphModel () для получения более подробной информации.

Поддерживаемые операции

В настоящее время TensorFlow.js поддерживает ограниченный набор операций TensorFlow. Если ваша модель использует неподдерживаемую tensorflowjs_converter сценарий tensorflowjs_converter завершится ошибкой и распечатает список неподдерживаемых операций в вашей модели. Пожалуйста, отправьте вопрос для каждой операции, чтобы сообщить нам, для каких операций вам нужна поддержка.

Загрузка только веса

Если вы предпочитаете загружать только веса, вы можете использовать следующий фрагмент кода.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");