Сохранение и загрузка моделей

TensorFlow.js предоставляет функциональные возможности для сохранения и загрузки моделей, созданных с помощью Layers API или преобразованных из существующих моделей TensorFlow. Это могут быть модели, которых вы обучили сами, или модели, обученные другими. Ключевым преимуществом использования Layers API является то, что модели, созданные с его помощью, можно сериализовать, и это то, что мы рассмотрим в этом руководстве.

В этом руководстве основное внимание будет уделено сохранению и загрузке моделей TensorFlow.js (идентифицируемых по файлам JSON). Мы также можем импортировать модели TensorFlow Python. Загрузка этих моделей описана в следующих двух руководствах:

Сохраните tf.Model

tf.Model и tf.Sequential предоставляют функцию model.save , которая позволяет сохранять топологию и веса модели.

  • Топология: это файл, описывающий архитектуру модели (т.е. какие операции она использует). Он содержит ссылки на веса моделей, которые хранятся снаружи.

  • Веса: это двоичные файлы, в которых хранятся веса данной модели в эффективном формате. Обычно они хранятся в той же папке, что и топология.

Давайте посмотрим, как выглядит код сохранения модели

const saveResult = await model.save('localstorage://my-model-1');

Несколько вещей, на которые следует обратить внимание:

  • Метод save принимает строковый аргумент типа URL, который начинается со схемы . Это описывает тип места назначения, в котором мы пытаемся сохранить модель. В примере выше схема localstorage://
  • По схеме следует путь . В приведенном выше примере путь — my-model-1 .
  • Метод save является асинхронным.
  • Возвращаемое значение model.save — это объект JSON, который содержит такую ​​информацию, как размеры топологии модели в байтах и ​​ее веса.
  • Среда, используемая для сохранения модели, не влияет на то, какие среды могут загружать модель. Сохранение модели в node.js не препятствует ее загрузке в браузере.

Ниже мы рассмотрим различные доступные схемы.

Локальное хранилище (только браузер)

Схема: localstorage://

await model.save('localstorage://my-model');

При этом модель сохраняется под именем my-model в локальном хранилище браузера. Это будет сохраняться между обновлениями, хотя локальное хранилище может быть очищено пользователями или самим браузером, если пространство становится проблемой. Каждый браузер также устанавливает свой собственный лимит на объем данных, которые могут храниться в локальном хранилище для данного домена.

IndexedDB (только браузер)

Схема: indexeddb://

await model.save('indexeddb://my-model');

При этом модель сохраняется в хранилище IndexedDB браузера. Подобно локальному хранилищу, оно сохраняется между обновлениями, но также имеет тенденцию иметь более высокие ограничения на размер хранимых объектов.

Загрузка файлов (только браузер)

Схема: downloads://

await model.save('downloads://my-model');

Это заставит браузер загрузить файлы модели на компьютер пользователя. Будут созданы два файла:

  1. Текстовый файл JSON с именем [my-model].json , который содержит топологию и ссылку на файл весов, описанный ниже.
  2. Двоичный файл, содержащий значения веса, с именем [my-model].weights.bin .

Вы можете изменить имя [my-model] , чтобы получить файлы с другим именем.

Поскольку файл .json указывает на .bin по относительному пути, эти два файла должны находиться в одной папке.

HTTP(S) запрос

Схема: http:// или https://

await model.save('http://model-server.domain/upload')

Это создаст веб-запрос для сохранения модели на удаленном сервере. Вы должны контролировать этот удаленный сервер, чтобы быть уверенным, что он сможет обработать запрос.

Модель будет отправлена ​​на указанный HTTP-сервер посредством POST- запроса. Тело POST имеет формат multipart/form-data и состоит из двух файлов.

  1. Текстовый файл JSON с именем model.json , который содержит топологию и ссылку на файл весов, описанный ниже.
  2. Бинарный файл, содержащий значения веса, с именем model.weights.bin .

Обратите внимание, что имена двух файлов всегда будут такими, как указано выше (имя встроено в функцию). Этот API-документ содержит фрагмент кода Python, который демонстрирует, как можно использовать веб-инфраструктуру flask для обработки запроса, исходящего от save .

Часто вам придется передавать на HTTP-сервер дополнительные аргументы или заголовки запросов (например, для аутентификации или если вы хотите указать папку, в которой должна быть сохранена модель). Вы можете получить детальный контроль над этими аспектами запросов от save , заменив аргумент строки URL в tf.io.browserHTTPRequest . Этот API обеспечивает большую гибкость в управлении HTTP-запросами.

Например:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

Собственная файловая система (только Node.js)

Схема: file://

await model.save('file:///path/to/my-model');

При работе на Node.js мы также имеем прямой доступ к файловой системе и можем сохранять модели там. Приведенная выше команда сохранит два файла по path , указанному после scheme .

  1. Текстовый файл JSON с именем [model].json , который содержит топологию и ссылку на файл весов, описанный ниже.
  2. Двоичный файл, содержащий значения веса, с именем [model].weights.bin .

Обратите внимание, что имена двух файлов всегда будут такими, как указано выше (имя встроено в функцию).

Загрузка tf.Model

Учитывая модель, сохраненную с помощью одного из вышеперечисленных методов, мы можем загрузить ее с помощью API tf.loadLayersModel .

Давайте посмотрим, как выглядит код загрузки модели

const model = await tf.loadLayersModel('localstorage://my-model-1');

Несколько вещей, на которые следует обратить внимание:

  • Как и model.save() , функция loadLayersModel принимает строковый аргумент типа URL, который начинается со схемы . Это описывает тип места назначения, из которого мы пытаемся загрузить модель.
  • По схеме следует путь . В приведенном выше примере путь — my-model-1 .
  • Строку, подобную URL-адресу, можно заменить объектом, соответствующим интерфейсу IOHandler.
  • Функция tf.loadLayersModel() является асинхронной.
  • Возвращаемое значение tf.loadLayersModeltf.Model

Ниже мы рассмотрим различные доступные схемы.

Локальное хранилище (только браузер)

Схема: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Это загружает модель с именем my-model из локального хранилища браузера.

IndexedDB (только браузер)

Схема: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

При этом модель загружается из хранилища IndexedDB браузера.

HTTP(S)

Схема: http:// или https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Это загружает модель из конечной точки http. После загрузки файла json функция выполнит запросы на соответствующие файлы .bin , на которые ссылается файл json .

Собственная файловая система (только Node.js)

Схема: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

При работе на Node.js мы также имеем прямой доступ к файловой системе и можем загружать модели оттуда. Обратите внимание, что в вызове функции выше мы ссылаемся на сам файл model.json (тогда как при сохранении мы указываем папку). Соответствующие файлы .bin должны находиться в той же папке, что и файл json .

Загрузка моделей с помощью IOHandlers

Если приведенных выше схем недостаточно для ваших нужд, вы можете реализовать собственное поведение загрузки с помощью IOHandler . Один IOHandler , предоставляемый TensorFlow.js, — это tf.io.browserFiles , который позволяет пользователям браузера загружать файлы модели в браузер. Дополнительную информацию смотрите в документации .

Сохранение и загрузка моделей с помощью пользовательских обработчиков ввода-вывода

Если приведенных выше схем недостаточно для ваших потребностей в загрузке или сохранении, вы можете реализовать собственное поведение сериализации, реализовав IOHandler .

IOHandler — это объект с методом save и load .

Функция save принимает один параметр, который соответствует интерфейсу ModelArtifacts и должен возвращать обещание, которое разрешается в объект SaveResult .

Функция load не принимает параметров и должна возвращать обещание, которое разрешается в объект ModelArtifacts . Это тот же объект, который передается в save .

См. BrowserHTTPRequest для примера реализации IOHandler.