モデルの状態を保存および復元する機能は、転移学習や事前トレーニング済みモデルを使用した推論の実行など、多くのアプリケーションにとって不可欠です。これを達成する 1 つの方法は、モデルのパラメーター (重み、バイアスなど) をチェックポイント ファイルまたはディレクトリに保存することです。
このモジュールは、 TensorFlow v2 形式のチェックポイントをロードおよび保存するための高レベルのインターフェイスと、このファイル形式への書き込みおよび読み取りを行う下位レベルのコンポーネントを提供します。
単純なモデルのロードと保存
Checkpointable
プロトコルに準拠することで、コードを追加することなく、多くの単純なモデルをチェックポイントにシリアル化できます。
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
次に、次を使用して同じチェックポイントを読み取ることができます。
try model.readCheckpoint(from: directory, name: "LeNet")
モデルの読み込みと保存のためのこのデフォルトの実装では、モデル構造内のプロパティの名前に基づく、モデル内の各テンソルのパスベースの命名スキームが使用されます。たとえば、 LeNet-5 モデルの最初の畳み込み内の重みとバイアスは、それぞれconv1/filter
とconv1/bias
という名前で保存されます。ロード時に、チェックポイント リーダーはこれらの名前のテンソルを検索します。
モデルの読み込みと保存のカスタマイズ
どの tensor を保存およびロードするか、またはそれらの tensor の名前付けをより詳細に制御したい場合、 Checkpointable
プロトコルはいくつかのカスタマイズ ポイントを提供します。
特定の型のプロパティを無視するには、 Type.property
の形式で文字列の Set を返す、 ignoredTensorPaths
の実装をモデルに提供します。たとえば、すべてのアテンション レイヤーのscale
プロパティを無視するには、 ["Attention.scale"]
を返すことができます。
デフォルトでは、モデル内の各より深いレベルを区切るためにスラッシュが使用されます。これは、モデルにcheckpointSeparator
を実装し、このセパレータに使用する新しい文字列を提供することでカスタマイズできます。
最後に、テンソルの命名を最大限にカスタマイズするには、 tensorNameMap
実装し、モデル内のテンソルに対して生成されたデフォルトの文字列名をチェックポイント内の目的の文字列名にマッピングする関数を提供できます。最も一般的に、これは他のフレームワークで生成されたチェックポイントと相互運用するために使用されます。各フレームワークには独自の命名規則とモデル構造があります。カスタム マッピング関数を使用すると、これらのテンソルの名前の付け方を最大限にカスタマイズできます。
デフォルトのCheckpointWriter.identityMap
(チェックポイントに自動的に生成されたテンソル パス名を使用する) や、ディクショナリからマッピングを構築できるCheckpointWriter.lookupMap(table:)
関数など、いくつかの標準ヘルパー関数が提供されています。
カスタム マッピングを実現する方法の例については、 GPT-2 モデルを参照してください。このモデルでは、OpenAI のチェックポイントに使用される正確な命名スキームと一致するマッピング関数が使用されています。
CheckpointReader コンポーネントと CheckpointWriter コンポーネント
チェックポイント書き込みの場合、 Checkpointable
プロトコルによって提供される拡張機能は、リフレクションとキーパスを使用してモデルのプロパティを反復し、文字列テンソル パスをテンソル値にマップする辞書を生成します。このディクショナリは、チェックポイントを書き込むディレクトリとともに、基礎となるCheckpointWriter
に提供されます。そのCheckpointWriter
は、そのディクショナリからディスク上のチェックポイントを生成するタスクを処理します。
このプロセスの逆は読み取りであり、 CheckpointReader
にディスク上のチェックポイント ディレクトリの場所が与えられます。次に、そのチェックポイントから読み取り、チェックポイント内のテンソルの名前とその保存された値をマップする辞書を形成します。このディクショナリは、モデル内の現在のテンソルをこのディクショナリ内のテンソルに置き換えるために使用されます。
ロードと保存の両方で、 Checkpointable
プロトコルは、上記のマッピング関数を使用して、テンソルへの文字列パスを、対応するディスク上のテンソル名にマッピングします。
Checkpointable
プロトコルに必要な機能が欠けている場合、またはチェックポイントのロードおよび保存プロセスをより詳細に制御する必要がある場合は、 CheckpointReader
クラスとCheckpointWriter
クラスを単独で使用できます。
TensorFlow v2 チェックポイント形式
このヘッダーで簡単に説明されているように、TensorFlow v2 チェックポイント形式は、TensorFlow モデル チェックポイントの第 2 世代形式です。この第 2 世代形式は 2016 年末から使用されており、v1 チェックポイント形式に比べて多くの点が改善されています。 TensorFlow SavedModel は、その内部で v2 チェックポイントを使用してモデル パラメーターを保存します。
TensorFlow v2 チェックポイントは、次のような構造のディレクトリで構成されます。
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
最初のファイルはチェックポイントのメタデータを保存し、残りのファイルはモデルのシリアル化されたパラメーターを保持するバイナリ シャードです。
インデックス メタデータ ファイルには、シャードに含まれるすべてのシリアル化されたテンソルのタイプ、サイズ、場所、文字列名が含まれています。そのインデックス ファイルはチェックポイントの構造的に最も複雑な部分であり、 tensorflow::table
に基づいており、これ自体は SSTable / LevelDB に基づいています。このインデックス ファイルは一連のキーと値のペアで構成されており、キーは文字列、値はプロトコル バッファーです。文字列はソートされ、プレフィックス圧縮されます。たとえば、最初のエントリがconv1/weight
で、次がconv1/bias
の場合、2 番目のエントリはbias
部分のみを使用します。
この全体的なインデックス ファイルは、 Snappy 圧縮を使用して圧縮される場合があります。 SnappyDecompression.swift
ファイルは、圧縮データ インスタンスからの Snappy 解凍のネイティブ Swift 実装を提供します。
インデックス ヘッダー メタデータとテンソル メタデータはプロトコル バッファーとしてエンコードされ、 Swift Protobufを介して直接エンコード/デコードされます。
CheckpointIndexReader
クラスとCheckpointIndexWriter
クラスは、包括的なCheckpointReader
とCheckpointWriter
クラスの一部として、これらのインデックス ファイルの読み込みと保存を処理します。後者は、テンソル データを含む構造的に単純なバイナリ シャードに対して何を読み書きするかを決定するための基礎としてインデックス ファイルを使用します。