Salvar e carregar modelos

O TensorFlow.js fornece funcionalidade para salvar e carregar modelos que foram criados com a API Layers ou convertidos de modelos existentes do TensorFlow. Podem ser modelos que você mesmo treinou ou treinados por outras pessoas. Um dos principais benefícios de usar a API Layers é que os modelos criados com ela são serializáveis ​​e é isso que exploraremos neste tutorial.

Este tutorial se concentrará em salvar e carregar modelos TensorFlow.js (identificáveis ​​por arquivos JSON). Também podemos importar modelos TensorFlow Python. O carregamento desses modelos é abordado nos dois tutoriais a seguir:

Salvar um tf.Model

tf.Model e tf.Sequential fornecem uma função model.save que permite salvar a topologia e os pesos de um modelo.

  • Topologia: Este é um arquivo que descreve a arquitetura de um modelo (ou seja, quais operações ele utiliza). Contém referências aos pesos dos modelos que são armazenados externamente.

  • Pesos: São arquivos binários que armazenam os pesos de um determinado modelo em um formato eficiente. Geralmente eles são armazenados na mesma pasta da topologia.

Vamos dar uma olhada na aparência do código para salvar um modelo

const saveResult = await model.save('localstorage://my-model-1');

Algumas coisas a serem observadas:

  • O método save usa um argumento de string semelhante a uma URL que começa com um esquema . Isto descreve o tipo de destino no qual estamos tentando salvar um modelo. No exemplo acima o esquema é localstorage://
  • O esquema é seguido por um caminho . No exemplo acima, o caminho é my-model-1 .
  • O método save é assíncrono.
  • O valor de retorno de model.save é um objeto JSON que carrega informações como tamanhos de bytes da topologia e pesos do modelo.
  • O ambiente usado para salvar o modelo não afeta quais ambientes podem carregar o modelo. Salvar um modelo em node.js não impede que ele seja carregado no navegador.

Abaixo examinaremos os diferentes esquemas disponíveis.

Armazenamento local (somente navegador)

Esquema: localstorage://

await model.save('localstorage://my-model');

Isso salva um modelo com o nome my-model no armazenamento local do navegador. Isso persistirá entre as atualizações, embora o armazenamento local possa ser limpo pelos usuários ou pelo próprio navegador se o espaço se tornar uma preocupação. Cada navegador também define seu próprio limite de quantidade de dados que podem ser armazenados no armazenamento local para um determinado domínio.

IndexedDB (somente navegador)

Esquema: indexeddb://

await model.save('indexeddb://my-model');

Isso salva um modelo no armazenamento IndexedDB do navegador. Assim como o armazenamento local, ele persiste entre as atualizações, mas também tende a ter limites maiores no tamanho dos objetos armazenados.

Downloads de arquivos (somente navegador)

Esquema: downloads://

await model.save('downloads://my-model');

Isso fará com que o navegador baixe os arquivos do modelo para a máquina do usuário. Serão produzidos dois arquivos:

  1. Um arquivo JSON de texto denominado [my-model].json , que carrega a topologia e a referência ao arquivo de pesos descrito abaixo.
  2. Um arquivo binário que contém os valores de peso denominado [my-model].weights.bin .

Você pode alterar o nome [my-model] para obter arquivos com um nome diferente.

Como o arquivo .json aponta para .bin usando um caminho relativo, os dois arquivos devem estar na mesma pasta.

Solicitação HTTP(S)

Esquema: http:// ou https://

await model.save('http://model-server.domain/upload')

Isso criará uma solicitação da web para salvar um modelo em um servidor remoto. Você deve estar no controle desse servidor remoto para garantir que ele seja capaz de lidar com a solicitação.

O modelo será enviado ao servidor HTTP especificado por meio de uma solicitação POST . O corpo do POST está no formato multipart/form-data e consiste em dois arquivos

  1. Um arquivo JSON de texto denominado model.json , que carrega a topologia e a referência ao arquivo de pesos descrito abaixo.
  2. Um arquivo binário que contém os valores de peso denominado model.weights.bin .

Observe que o nome dos dois arquivos sempre será exatamente como especificado acima (o nome está embutido na função). Este documento da API contém um trecho de código Python que demonstra como alguém pode usar a estrutura da web flask para lidar com a solicitação originada de save .

Muitas vezes você terá que passar mais argumentos ou solicitar cabeçalhos para o seu servidor HTTP (por exemplo, para autenticação ou se quiser especificar uma pasta na qual o modelo deve ser salvo). Você pode obter controle refinado sobre esses aspectos das solicitações ao save , substituindo o argumento da string de URL em tf.io.browserHTTPRequest . Esta API oferece maior flexibilidade no controle de solicitações HTTP.

Por exemplo:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

Sistema de arquivos nativo (somente Node.js)

Esquema: file://

await model.save('file:///path/to/my-model');

Ao executar em Node.js também temos acesso direto ao sistema de arquivos e podemos salvar modelos lá. O comando acima salvará dois arquivos no path especificado após o scheme .

  1. Um arquivo JSON de texto denominado [model].json , que carrega a topologia e a referência ao arquivo de pesos descrito abaixo.
  2. Um arquivo binário que contém os valores de peso denominado [model].weights.bin .

Observe que o nome dos dois arquivos sempre será exatamente como especificado acima (o nome está embutido na função).

Carregando um tf.Model

Dado um modelo que foi salvo usando um dos métodos acima, podemos carregá-lo usando a API tf.loadLayersModel .

Vamos dar uma olhada na aparência do código para carregar um modelo

const model = await tf.loadLayersModel('localstorage://my-model-1');

Algumas coisas a serem observadas:

  • Assim como model.save() , a função loadLayersModel usa um argumento de string semelhante a uma URL que começa com um esquema . Isso descreve o tipo de destino do qual estamos tentando carregar um modelo.
  • O esquema é seguido por um caminho . No exemplo acima, o caminho é my-model-1 .
  • A string semelhante a url pode ser substituída por um objeto que corresponda à interface IOHandler.
  • A função tf.loadLayersModel() é assíncrona.
  • O valor de retorno de tf.loadLayersModel é tf.Model

Abaixo examinaremos os diferentes esquemas disponíveis.

Armazenamento local (somente navegador)

Esquema: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Isso carrega um modelo chamado my-model do armazenamento local do navegador.

IndexedDB (somente navegador)

Esquema: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

Isso carrega um modelo do armazenamento IndexedDB do navegador.

HTTP(S)

Esquema: http:// ou https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Isso carrega um modelo de um endpoint http. Depois de carregar o arquivo json , a função fará solicitações para os arquivos .bin correspondentes aos quais o arquivo json faz referência.

Sistema de arquivos nativo (somente Node.js)

Esquema: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

Ao executar em Node.js também temos acesso direto ao sistema de arquivos e podemos carregar modelos a partir daí. Observe que na chamada de função acima fazemos referência ao próprio arquivo model.json (enquanto ao salvar especificamos uma pasta). Os arquivos .bin correspondentes devem estar na mesma pasta que o arquivo json .

Carregando modelos com IOHandlers

Se os esquemas acima não forem suficientes para suas necessidades, você poderá implementar um comportamento de carregamento personalizado com um IOHandler . Um IOHandler fornecido pelo TensorFlow.js é o tf.io.browserFiles , que permite aos usuários do navegador fazer upload de arquivos de modelo no navegador. Veja a documentação para mais informações.

Salvando e carregando modelos com IOHandlers personalizados

Se os esquemas acima não forem suficientes para suas necessidades de carregamento ou salvamento, você poderá implementar um comportamento de serialização personalizado implementando um IOHandler .

Um IOHandler é um objeto com um método save e load .

A função save usa um parâmetro que corresponde à interface ModelArtifacts e deve retornar uma promessa que resolve para um objeto SaveResult .

A função load não aceita parâmetros e deve retornar uma promessa que resolve para um objeto ModelArtifacts . Este é o mesmo objeto passado para save .

Consulte BrowserHTTPRequest para obter um exemplo de como implementar um IOHandler.