모델 체크포인트

모델 상태를 저장하고 복원하는 기능은 전이 학습이나 사전 훈련된 모델을 사용한 추론 수행과 같은 다양한 애플리케이션에 필수적입니다. 이를 수행하는 한 가지 방법은 모델의 매개변수(가중치, 편향 등)를 체크포인트 파일이나 디렉터리에 저장하는 것입니다.

이 모듈은 TensorFlow v2 형식 체크포인트를 로드하고 저장하기 위한 상위 수준 인터페이스와 이 파일 형식에 쓰고 읽는 하위 수준 구성요소를 제공합니다.

간단한 모델 로드 및 저장

Checkpointable 프로토콜을 준수하면 추가 코드 없이 많은 간단한 모델을 체크포인트로 직렬화할 수 있습니다.

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

그런 다음 다음을 사용하여 동일한 체크포인트를 읽을 수 있습니다.

try model.readCheckpoint(from: directory, name: "LeNet")

모델 로드 및 저장을 위한 이 기본 구현은 모델 구조체 내의 속성 이름을 기반으로 하는 모델의 각 텐서에 대해 경로 기반 명명 체계를 사용합니다. 예를 들어, LeNet-5 모델 의 첫 번째 컨볼루션 내의 가중치와 편향은 각각 conv1/filterconv1/bias 라는 이름으로 저장됩니다. 로드할 때 체크포인트 리더는 이러한 이름을 가진 텐서를 검색합니다.

모델 로드 및 저장 사용자 정의

어떤 텐서가 저장되고 로드되는지 또는 해당 텐서의 이름을 지정하는 것에 대해 더 큰 제어권을 갖고 싶다면 Checkpointable 프로토콜은 몇 가지 사용자 정의 지점을 제공합니다.

특정 유형의 속성을 무시하려면 Type.property 형식으로 문자열 집합을 반환하는 ignoredTensorPaths 구현을 모델에 제공할 수 있습니다. 예를 들어 모든 Attention 레이어의 scale 속성을 무시하려면 ["Attention.scale"] 반환할 수 있습니다.

기본적으로 슬래시는 모델의 각 더 깊은 수준을 구분하는 데 사용됩니다. 이는 모델에 checkpointSeparator 구현하고 이 구분 기호에 사용할 새 문자열을 제공하여 사용자 정의할 수 있습니다.

마지막으로, 텐서 이름 지정을 최대한 맞춤화하려면 tensorNameMap 구현하고 모델의 텐서에 대해 생성된 기본 문자열 이름을 체크포인트의 원하는 문자열 이름으로 매핑하는 함수를 제공할 수 있습니다. 가장 일반적으로 이는 다른 프레임워크로 생성된 체크포인트와 상호 운용하는 데 사용되며, 각 프레임워크에는 고유한 명명 규칙과 모델 구조가 있습니다. 사용자 정의 매핑 기능은 이러한 텐서의 이름을 지정하는 방법에 대해 가장 높은 수준의 사용자 정의를 제공합니다.

기본 CheckpointWriter.identityMap (검사점에 대해 자동으로 생성된 텐서 경로 이름을 사용함) 또는 사전에서 매핑을 작성할 수 있는 CheckpointWriter.lookupMap(table:) 함수와 같은 일부 표준 도우미 함수가 제공됩니다.

사용자 정의 매핑을 수행하는 방법에 대한 예는 OpenAI의 체크포인트에 사용되는 정확한 명명 체계와 일치하도록 매핑 기능을 사용하는 GPT-2 모델을 참조하세요.

CheckpointReader 및 CheckpointWriter 구성 요소

체크포인트 작성의 경우 Checkpointable 프로토콜에서 제공하는 확장은 반사 및 키 경로를 사용하여 모델의 속성을 반복하고 문자열 텐서 경로를 Tensor 값에 매핑하는 사전을 생성합니다. 이 사전은 체크포인트를 쓸 디렉터리와 함께 기본 CheckpointWriter 에 제공됩니다. 해당 CheckpointWriter 해당 사전에서 온디스크 체크포인트를 생성하는 작업을 처리합니다.

이 프로세스의 반대는 읽기입니다. 여기서 CheckpointReader 에는 디스크상의 체크포인트 디렉터리 위치가 제공됩니다. 그런 다음 해당 체크포인트에서 읽고 체크포인트 내의 텐서 이름을 저장된 값과 매핑하는 사전을 형성합니다. 이 사전은 모델의 현재 텐서를 이 사전의 텐서로 바꾸는 데 사용됩니다.

로드 및 저장 모두에 대해 Checkpointable 프로토콜은 위에 설명된 매핑 기능을 사용하여 텐서에 대한 문자열 경로를 해당 온디스크 텐서 이름에 매핑합니다.

Checkpointable 프로토콜에 필요한 기능이 부족하거나 체크포인트 로드 및 저장 프로세스에 대해 더 많은 제어가 필요한 경우 CheckpointReaderCheckpointWriter 클래스를 단독으로 사용할 수 있습니다.

TensorFlow v2 체크포인트 형식

이 헤더 에 간략하게 설명된 대로 TensorFlow v2 체크포인트 형식은 TensorFlow 모델 체크포인트의 2세대 형식입니다. 이 2세대 형식은 2016년 말부터 사용되었으며 v1 체크포인트 형식에 비해 여러 가지 개선 사항이 있습니다. TensorFlow SavedModels는 v2 체크포인트를 사용하여 모델 매개변수를 저장합니다.

TensorFlow v2 체크포인트는 다음과 같은 구조의 디렉터리로 구성됩니다.

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

여기서 첫 번째 파일은 체크포인트에 대한 메타데이터를 저장하고 나머지 파일은 모델에 대한 직렬화된 매개변수를 보유하는 바이너리 샤드입니다.

인덱스 메타데이터 파일에는 샤드에 포함된 모든 직렬화된 텐서의 유형, 크기, 위치 및 문자열 이름이 포함되어 있습니다. 해당 인덱스 파일은 체크포인트에서 구조적으로 가장 복잡한 부분이며 SSTable / LevelDB를 기반으로 하는 tensorflow::table 기반으로 합니다. 이 인덱스 파일은 일련의 키-값 쌍으로 구성됩니다. 여기서 키는 문자열이고 값은 프로토콜 버퍼입니다. 문자열은 정렬되고 접두사로 압축됩니다. 예를 들어 첫 번째 항목이 conv1/weight 이고 다음 conv1/bias 인 경우 두 번째 항목은 bias 부분만 사용합니다.

이 전체 인덱스 파일은 때때로 Snappy 압축을 사용하여 압축됩니다. SnappyDecompression.swift 파일은 압축된 Data 인스턴스에서 Snappy 압축 해제의 기본 Swift 구현을 제공합니다.

인덱스 헤더 메타데이터와 텐서 메타데이터는 프로토콜 버퍼로 인코딩되고 Swift Protobuf를 통해 직접 인코딩/디코딩됩니다.

CheckpointIndexReaderCheckpointIndexWriter 클래스는 중요한 CheckpointReaderCheckpointWriter 클래스의 일부로 이러한 인덱스 파일 로드 및 저장을 처리합니다. 후자는 텐서 데이터가 포함된 구조적으로 단순한 바이너리 샤드에서 무엇을 읽고 쓸지 결정하기 위한 기초로 인덱스 파일을 사용합니다.