Punkty kontrolne modelu

Możliwość zapisywania i przywracania stanu modelu jest niezbędna w wielu zastosowaniach, takich jak uczenie transferowe lub wnioskowanie przy użyciu wstępnie wyszkolonych modeli. Jednym ze sposobów osiągnięcia tego jest zapisanie parametrów modelu (wagi, odchylenia itp.) w pliku lub katalogu punktów kontrolnych.

Moduł ten zapewnia interfejs wysokiego poziomu do ładowania i zapisywania punktów kontrolnych w formacie TensorFlow v2 , a także komponenty niższego poziomu, które zapisują i odczytują z tego formatu pliku.

Ładowanie i zapisywanie prostych modeli

Dzięki zgodności z protokołem Checkpointable wiele prostych modeli można serializować do punktów kontrolnych bez dodatkowego kodu:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

a następnie ten sam punkt kontrolny można odczytać za pomocą:

try model.readCheckpoint(from: directory, name: "LeNet")

Ta domyślna implementacja ładowania i zapisywania modelu będzie używać schematu nazewnictwa opartego na ścieżce dla każdego tensora w modelu, opartego na nazwach właściwości w strukturach modelu. Na przykład wagi i obciążenia w pierwszym splocie w modelu LeNet-5 zostaną zapisane odpowiednio pod nazwami conv1/filter i conv1/bias . Podczas ładowania czytnik punktów kontrolnych będzie wyszukiwał tensory o tych nazwach.

Dostosowywanie ładowania i zapisywania modelu

Jeśli chcesz mieć większą kontrolę nad tym, które tensory są zapisywane i ładowane, lub nad nazewnictwem tych tensorów, protokół Checkpointable oferuje kilka punktów dostosowywania.

Aby zignorować właściwości niektórych typów, możesz zapewnić implementację ignoredTensorPaths w swoim modelu, która zwraca zestaw ciągów w postaci Type.property . Na przykład, aby zignorować właściwość scale na każdej warstwie uwagi, możesz zwrócić ["Attention.scale"] .

Domyślnie do oddzielenia poszczególnych głębszych poziomów w modelu używany jest ukośnik. Można to dostosować, wdrażając checkpointSeparator w swoim modelu i podając nowy ciąg do użycia dla tego separatora.

Na koniec, aby uzyskać największy stopień dostosowania nazewnictwa tensorów, można zaimplementować tensorNameMap i udostępnić funkcję, która odwzorowuje domyślną nazwę ciągu wygenerowaną dla tensora w modelu na żądaną nazwę ciągu w punkcie kontrolnym. Najczęściej będzie to wykorzystywane do współpracy z punktami kontrolnymi wygenerowanymi za pomocą innych frameworków, z których każdy ma własne konwencje nazewnictwa i struktury modeli. Niestandardowa funkcja mapowania zapewnia największy stopień dostosowania nazw tych tensorów.

Dostępne są pewne standardowe funkcje pomocnicze, takie jak domyślna CheckpointWriter.identityMap (która po prostu używa automatycznie wygenerowanej nazwy ścieżki tensora dla punktów kontrolnych) lub funkcja CheckpointWriter.lookupMap(table:) , która może zbudować mapowanie ze słownika.

Przykład tego, jak można wykonać niestandardowe mapowanie, można znaleźć w modelu GPT-2 , który wykorzystuje funkcję mapowania w celu dokładnego dopasowania schematu nazewnictwa używanego dla punktów kontrolnych OpenAI.

Składniki CheckpointReader i CheckpointWriter

W przypadku zapisu w punktach kontrolnych rozszerzenie udostępniane przez protokół Checkpointable wykorzystuje odbicie i ścieżki klawiszy do iteracji po właściwościach modelu i generowania słownika, który odwzorowuje ścieżki tensora ciągów na wartości Tensora. Słownik ten jest dostarczany do bazowego CheckpointWriter wraz z katalogiem, w którym ma zostać zapisany punkt kontrolny. Ten CheckpointWriter obsługuje zadanie generowania punktu kontrolnego na dysku z tego słownika.

Odwrotnością tego procesu jest czytanie, podczas którego CheckpointReader otrzymuje lokalizację katalogu punktu kontrolnego na dysku. Następnie odczytuje dane z tego punktu kontrolnego i tworzy słownik, który odwzorowuje nazwy tensorów w punkcie kontrolnym na ich zapisane wartości. Słownik ten służy do zastępowania bieżących tensorów w modelu tymi z tego słownika.

Zarówno podczas ładowania, jak i zapisywania protokół Checkpointable odwzorowuje ścieżki ciągów na tensory na odpowiadające im nazwy tensorów na dysku, korzystając z opisanej powyżej funkcji mapowania.

Jeżeli w protokole Checkpointable brakuje potrzebnej funkcjonalności lub wymagana jest większa kontrola nad procesem ładowania i zapisywania punktów kontrolnych, można użyć klas CheckpointReader i CheckpointWriter .

Format punktu kontrolnego TensorFlow v2

Format punktu kontrolnego TensorFlow v2, jak krótko opisano w tym nagłówku , jest formatem drugiej generacji punktów kontrolnych modelu TensorFlow. Ten format drugiej generacji jest używany od końca 2016 r. i wprowadzono wiele ulepszeń w stosunku do formatu punktu kontrolnego w wersji 1. TensorFlow SavedModels wykorzystuje w sobie punkty kontrolne v2 do zapisywania parametrów modelu.

Punkt kontrolny TensorFlow v2 składa się z katalogu o strukturze podobnej do poniższej:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

gdzie pierwszy plik przechowuje metadane punktu kontrolnego, a pozostałe pliki to fragmenty binarne przechowujące serializowane parametry modelu.

Plik metadanych indeksu zawiera typy, rozmiary, lokalizacje i nazwy ciągów wszystkich serializowanych tensorów zawartych w fragmentach. Ten plik indeksu jest najbardziej złożoną strukturalnie częścią punktu kontrolnego i jest oparty na tensorflow::table , który sam w sobie jest oparty na SSTable/LevelDB. Ten plik indeksu składa się z szeregu par klucz-wartość, gdzie klucze to ciągi znaków, a wartości to bufory protokołu. Ciągi są sortowane i kompresowane przedrostkami. Na przykład: jeśli pierwszy wpis to conv1/weight , a następny conv1/bias , drugi wpis używa tylko części bias .

Ten ogólny plik indeksu jest czasami kompresowany przy użyciu kompresji Snappy . Plik SnappyDecompression.swift zapewnia natywną implementację Swift dekompresji Snappy ze skompresowanej instancji danych.

Metadane nagłówka indeksu i metadane tensora są kodowane jako bufory protokołu i kodowane/dekodowane bezpośrednio przez Swift Protobuf .

Klasy CheckpointIndexReader i CheckpointIndexWriter obsługują ładowanie i zapisywanie plików indeksu w ramach nadrzędnych klas CheckpointReader i CheckpointWriter . Te ostatnie wykorzystują pliki indeksu jako podstawę do określenia, z czego czytać i zapisywać w strukturalnie prostszych fragmentach binarnych zawierających dane tensora.