Zapisz i załaduj modele

TensorFlow.js zapewnia funkcjonalność zapisywania i ładowania modeli, które zostały utworzone za pomocą interfejsu Layers API lub przekonwertowane z istniejących modeli TensorFlow. Mogą to być modele, które sam wytrenowałeś lub które przeszkolili inni. Kluczową zaletą korzystania z interfejsu API Layers jest to, że utworzone za jego pomocą modele można serializować i właśnie to omówimy w tym samouczku.

Ten samouczek skupi się na zapisywaniu i ładowaniu modeli TensorFlow.js (identyfikowanych przez pliki JSON). Możemy również importować modele TensorFlow Python. Ładowanie tych modeli opisano w dwóch poniższych samouczkach:

Zapisz model tf

Zarówno tf.Model jak i tf.Sequential udostępniają funkcję model.save , która pozwala zapisać topologię i wagi modelu.

  • Topologia: Jest to plik opisujący architekturę modelu (tj. jakie operacje wykorzystuje). Zawiera odniesienia do ciężarów modeli, które są przechowywane zewnętrznie.

  • Wagi: Są to pliki binarne przechowujące wagi danego modelu w wydajnym formacie. Zazwyczaj są one przechowywane w tym samym folderze co topologia.

Przyjrzyjmy się jak wygląda kod zapisu modelu

const saveResult = await model.save('localstorage://my-model-1');

Kilka rzeczy, na które warto zwrócić uwagę:

  • Metoda save przyjmuje argument w postaci ciągu znaków przypominający adres URL, który zaczyna się od schematu . Opisuje typ miejsca docelowego, w którym próbujemy zapisać model. W powyższym przykładzie schemat to localstorage://
  • Po schemacie następuje ścieżka . W powyższym przykładzie ścieżka to my-model-1 .
  • Metoda save jest asynchroniczna.
  • Wartość zwracana przez model.save to obiekt JSON, który zawiera informacje, takie jak wielkość bajtów topologii i wagi modelu.
  • Środowisko użyte do zapisania modelu nie ma wpływu na to, które środowiska mogą załadować model. Zapisanie modelu w node.js nie uniemożliwia załadowania go w przeglądarce.

Poniżej przeanalizujemy różne dostępne schematy.

Pamięć lokalna (tylko przeglądarka)

Schemat: localstorage://

await model.save('localstorage://my-model');

Spowoduje to zapisanie modelu pod nazwą my-model w pamięci lokalnej przeglądarki. Będzie się to utrzymywać między odświeżeniami, chociaż pamięć lokalna może zostać wyczyszczona przez użytkowników lub samą przeglądarkę, jeśli problemem stanie się miejsce. Każda przeglądarka ustala także własny limit ilości danych, które mogą być przechowywane w pamięci lokalnej dla danej domeny.

IndexedDB (tylko przeglądarka)

Schemat: indexeddb://

await model.save('indexeddb://my-model');

Spowoduje to zapisanie modelu w pamięci IndexedDB przeglądarki. Podobnie jak pamięć lokalna, utrzymuje się ona pomiędzy odświeżeniami, ma również większe ograniczenia dotyczące rozmiaru przechowywanych obiektów.

Pobieranie plików (tylko przeglądarka)

Schemat: downloads://

await model.save('downloads://my-model');

Spowoduje to, że przeglądarka pobierze pliki modelu na komputer użytkownika. Zostaną utworzone dwa pliki:

  1. Tekstowy plik JSON o nazwie [my-model].json , który zawiera topologię i odniesienie do pliku wag opisanego poniżej.
  2. Plik binarny zawierający wartości wag o nazwie [my-model].weights.bin .

Możesz zmienić nazwę [my-model] aby uzyskać pliki o innej nazwie.

Ponieważ plik .json wskazuje na .bin przy użyciu ścieżki względnej, oba pliki powinny znajdować się w tym samym folderze.

Żądanie HTTP(S).

Schemat: http:// lub https://

await model.save('http://model-server.domain/upload')

Spowoduje to utworzenie żądania internetowego w celu zapisania modelu na zdalnym serwerze. Powinieneś mieć kontrolę nad tym zdalnym serwerem, aby mieć pewność, że jest on w stanie obsłużyć żądanie.

Model zostanie wysłany do określonego serwera HTTP za pośrednictwem żądania POST . Treść testu POST jest w formacie multipart/form-data i składa się z dwóch plików

  1. Tekstowy plik JSON o nazwie model.json , który zawiera topologię i odniesienie do pliku wag opisanego poniżej.
  2. Plik binarny zawierający wartości wag o nazwie model.weights.bin .

Należy pamiętać, że nazwa obu plików będzie zawsze dokładnie taka, jak określono powyżej (nazwa jest wbudowana w funkcję). Ten dokument interfejsu API zawiera fragment kodu Pythona, który demonstruje, w jaki sposób można wykorzystać framework sieciowy Flask do obsługi żądania pochodzącego z save .

Często będziesz musiał przekazać więcej argumentów lub nagłówków żądań do swojego serwera HTTP (np. w celu uwierzytelnienia lub jeśli chcesz określić folder, w którym model powinien zostać zapisany). Możesz uzyskać szczegółową kontrolę nad tymi aspektami żądań z save , zastępując argument ciągu adresu URL w tf.io.browserHTTPRequest . Ten interfejs API zapewnia większą elastyczność w kontrolowaniu żądań HTTP.

Na przykład:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

Natywny system plików (tylko Node.js)

Schemat: file://

await model.save('file:///path/to/my-model');

Uruchamiając na Node.js mamy także bezpośredni dostęp do systemu plików i możemy tam zapisywać modele. Powyższe polecenie zapisze dwa pliki w path określonej według scheme .

  1. Tekstowy plik JSON o nazwie [model].json , który zawiera topologię i odniesienie do pliku wag opisanego poniżej.
  2. Plik binarny zawierający wartości wag o nazwie [model].weights.bin .

Należy pamiętać, że nazwa obu plików będzie zawsze dokładnie taka, jak określono powyżej (nazwa jest wbudowana w funkcję).

Ładowanie modelu tf

Mając model, który został zapisany jedną z powyższych metod, możemy go załadować za pomocą API tf.loadLayersModel .

Przyjrzyjmy się jak wygląda kod ładujący model

const model = await tf.loadLayersModel('localstorage://my-model-1');

Kilka rzeczy, na które warto zwrócić uwagę:

  • Podobnie jak model.save() , funkcja loadLayersModel przyjmuje argument w postaci ciągu znaków przypominający adres URL, rozpoczynający się od schematu . Opisuje typ miejsca docelowego, z którego próbujemy załadować model.
  • Po schemacie następuje ścieżka . W powyższym przykładzie ścieżka to my-model-1 .
  • Ciąg przypominający adres URL można zastąpić obiektem pasującym do interfejsu IOHandler.
  • Funkcja tf.loadLayersModel() jest asynchroniczna.
  • Wartość zwracana przez tf.loadLayersModel to tf.Model

Poniżej przeanalizujemy różne dostępne schematy.

Pamięć lokalna (tylko przeglądarka)

Schemat: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Spowoduje to załadowanie modelu o nazwie my-model z pamięci lokalnej przeglądarki.

IndexedDB (tylko przeglądarka)

Schemat: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

Spowoduje to załadowanie modelu z pamięci IndexedDB przeglądarki.

HTTP(S)

Schemat: http:// lub https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Spowoduje to załadowanie modelu z punktu końcowego http. Po załadowaniu pliku json funkcja będzie wysyłać żądania dotyczące odpowiednich plików .bin , do których odwołuje się plik json .

Natywny system plików (tylko Node.js)

Schemat: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

Działając na Node.js mamy również bezpośredni dostęp do systemu plików i możemy stamtąd ładować modele. Należy pamiętać, że w powyższym wywołaniu funkcji odwołujemy się do samego pliku model.json (podczas zapisywania określamy folder). Odpowiednie pliki .bin powinny znajdować się w tym samym folderze, co plik json .

Ładowanie modeli za pomocą IOHhandlerów

Jeśli powyższe schematy nie są wystarczające dla Twoich potrzeb, możesz zaimplementować niestandardowe zachowanie ładowania za pomocą IOHandler . Jednym z IOHandler udostępnianym przez TensorFlow.js jest tf.io.browserFiles , który umożliwia użytkownikom przeglądarki przesyłanie plików modeli do przeglądarki. Więcej informacji można znaleźć w dokumentacji .

Zapisywanie i ładowanie modeli za pomocą niestandardowych IOHhandlerów

Jeśli powyższe schematy nie są wystarczające dla Twoich potrzeb związanych z ładowaniem lub zapisywaniem, możesz zaimplementować niestandardowe zachowanie serializacji, implementując IOHandler .

IOHandler to obiekt z metodą save i load .

Funkcja save przyjmuje jeden parametr, który jest zgodny z interfejsem ModelArtifacts i powinna zwracać obietnicę, która prowadzi do obiektu SaveResult .

Funkcja load nie przyjmuje żadnych parametrów i powinna zwracać obietnicę, która prowadzi do obiektu ModelArtifacts . Jest to ten sam obiekt, który jest przekazywany do save .

Zobacz BrowserHTTPRequest , aby zapoznać się z przykładem implementacji IOHandler.