체크포인트 훈련하기

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

"텐서플로 모델 저장하기" 라는 문구는 보통 둘중 하나를 의미합니다:

  1. Checkpoints, 혹은
  2. SavedModel.

Checkpoint는 모델이 사용한 모든 매개변수(tf.Variable 객체들)의 정확한 값을 캡처합니다. Chekcpoint는 모델에 의해 정의된 연산에 대한 설명을 포함하지 않으므로 일반적으로 저장된 매개변수 값을 사용할 소스 코드를 사용할 수 있을 때만 유용합니다.

반면 SavedModel 형식은 매개변수 값(체크포인트) 외에 모델에 의해 정의된 연산에 대한 일련화된 설명을 포함합니다. 이 형식의 모델은 모델을 만든 소스 코드와 독립적입니다. 따라서 TensorFlow Serving, TensorFlow Lite, TensorFlow.js 또는 다른 프로그래밍 언어(C, C++, Java, Go, Rust, C# 등. TensorFlow APIs)로 배포하기에 적합합니다.

이 가이드는 체크포인트 쓰기 및 읽기를 위한 API들을 다룹니다.

설치

import tensorflow as tf
2022-12-14 21:16:31.009549: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:16:31.009664: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:16:31.009675: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras 훈련 API들로부터 저장하기

See the tf.keras guide on saving and restoring .

tf.keras.Model.save_weights 가 텐서플로 CheckPoint를 저장합니다.

net.save_weights('easy_checkpoint')

Checkpoints 작성하기

텐서플로 모델의 지속적인 상태는 tf.Variable 객체에 저장되어 있습니다. 이들은 직접으로 구성할 수 있지만, tf.keras.layers 혹은 tf.keras.Model와 같은 고수준 API들로 만들어 지기도 합니다.

변수를 관리하는 가장 쉬운 방법은 Python 객체에 변수를 연결한 다음 해당 객체를 참조하는 것입니다.

tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model의 하위클래스들은 해당 속성에 할당된 변수를 자동 추적합니다. 다음 예시는 간단한 선형 model을 구성하고, 모든 model 변수의 값을 포합하는 checkpoint를 씁니다.

Model.save_weights를 사용해 손쉽게 model-checkpoint를 저장할 수 있습니다.

직접 Checkpoint작성하기

설치

tf.train.Checkpoint의 모든 특성을 입증하기 위해서 toy dataset과 optimization step을 정의해야 합니다.

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Checkpoint객체 생성

인위적으로 checkpoint를 만드려면 tf.train.Checkpoint 객체가 필요합니다. Checkpoint하고 싶은 객체의 위치는 객체의 특성으로 설정이 되어 있습니다.

tf.train.CheckpointManager도 다수의 checkpoint를 관리할때 도움이 됩니다

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

훈련하고 model checkpoint작성하기

다음 훈련 루프는 model과 optimizer의 인스턴스를 만든 후 tf.train.Checkpoint 객체에 수집합니다. 이것은 각 데이터 배치에 있는 루프의 훈련 단계를 호출하고, 주기적으로 디스크에 checkpoint를 작성합니다.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 26.76
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 20.18
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 13.62
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 7.20
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 2.25

복구하고 훈련 계속하기

첫 번째 과정 이후 새로운 model과 매니저를 전달할 수 있지만, 일을 마무리 한 정확한 지점에서 훈련을 가져와야 합니다:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.72
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.56
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.45
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.41
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.26

tf.train.CheckpointManager 객체가 이전 checkpoint들을 제거합니다. 위는 가장 최근의 3개 checkpoint만 유지하도록 구성되어 있습니다.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

예를 들어, './tf_ckpts/ckpt-10'같은 경로들은 디스크에 있는 파일이 아닙니다. 대신에 이 경로들은 index 파일과 변수 값들을 담고있는 파일들의 전위 표기입니다. 이 전위 표기들은 CheckpointManager가 상태를 저장하는 하나의 checkpoint 파일 ('./tf_ckpts/checkpoint')에 그룹으로 묶여있습니다.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

작동 원리

텐서플로는 로드되는 객체에서 시작하여 명명된 엣지가 있는 방향 그래프를 통과시켜 변수를 checkpoint된 값과 일치시킵니다. 엣지의 이름들은 특히 기여한 객체의 이름에서 따왔습니다. 예를들면, self.l1 = tf.keras.layers.Dense(5)안의 "l1". tf.train.Checkpoint 이것의 키워드 전달인자 이름을 사용했습니다, 여기에서는 "step" in tf.train.Checkpoint(step=...).

위의 예에서 나온 종속성 그래프는 다음과 같습니다.:

훈련 반복 예시의 의존 그래프 시각화

optimizer는 빨간색으로, regular 변수는 파란색으로, optimizer 슬롯 변수는 주황색으로 표시합니다. 다른 nodes는, 예를 들면 tf.train.Checkpoint, 이 검은색임을 나타냅니다.

슬롯 변수는 옵티마이저 상태의 일부이지만 특정 변수에 대해 생성됩니다. 예를 들어 위의 '' 엣지는 Adam 옵티마이저가 각 변수에 대해 추적하는 모멘텀에 해당합니다. 슬롯 변수는 변수와 옵티마이저가 모두 저장되어 점선으로 된 엣지인 경우에만 체크포인트에 저장됩니다.

tf.train.Checkpoint로 불러온 restore() 오브젝트 큐는그Checkpoint 개체에서 일치하는 방법이 있습니다. 변수 값 복원을 요청한 복원 작업 대기 행렬로 정리합니다. 예를 들어, 우리는 네트워크와 계층을 통해 그것에 대한 하나의 경로를 재구성함으로서 위에서 정의한 모델에서 커널만 로드할 수 있습니다.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.800037  1.7855548 3.146778  3.5596383 3.8949137]

이 새로운 개체에 대한 의존도 그래프는 우리가 위에 적은 더 큰 checkpoint보다 작은 하위 그래프입니다. 이것은 오직 tf.train.Checkpoint에서 checkpoints 셀때 편향과 저장 카운터만 포함합니다.

편향 변수의 서브그래프 시각화

restore() 함수는 선택적으로 확인을 거친 객체의 상태를 반환합니다. 새로 만든 Checkpoint에서 우리가 만든 모든 개체가 복원되어 status.assert_existing_objects_matched가 통과합니다.

status.assert_existing_objects_matched()
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fb780023460>

checkpoint에는 계층의 커널과 optimizer의 변수를 포함하여 일치하지 않는 많은 개체가 있습니다. status.assert_consumed는 checkpoint와 프로그램이 정확히 일치할 경우에만 통과하고 여기에 예외를 둘 것입니다.

지연된 복원

텐서플로우의 Layer 객체는 입력 형상을 이용할 수 있을 때 변수 생성을 첫 번째 호출로 지연시킬 수 있습니다. 예를 들어, Dense 층의 커널의 모양은 계층의 입력과 출력 형태 모두에 따라 달라지기 때문에, 생성자 인수로 필요한 출력 형태는 그 자체로 변수를 만들기에 충분한 정보가 아닙니다. 예를 들어, Layer 층의 커널의 모양은 계층의 입력과 출력 형태 모두에 따라 달라지기 때문에, 생성자 인수로 필요한 출력 형태는 그 자체로 변수를 만들기에 충분한 정보가 아닙니다.

이 관용구를 지원하기 위해 tf.train.Checkpoint는 아직 일치하는 변수가 없는 복원을 연기합니다.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.575237  4.8758802 4.761698  4.9702144 4.997744 ]]

checkpoints 수동 검사

tf.train.load_checkpoint는 체크포인트 내용에 대한 낮은 수준의 액세스를 제공 CheckpointReader를 반환합니다. 여기에는 각 변수의 키에서 검사점의 각 변수에 대한 모양 및 dtype으로의 매핑이 포함됩니다. 변수의 키는 위에 표시된 그래프와 같이 객체 경로입니다.

참고: 체크포인트에는 더 높은 수준의 구조가 없습니다. 변수의 경로와 값만 알고 있으며 models, layers 또는 이들이 연결되는 방식에 대한 개념이 없습니다.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_variables/1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

따라서 net.l1.kernel의 값에 관심이 있다면 다음 코드로 이 값을 얻을 수 있습니다.

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

변수 값을 검사할 수 있는 get_tensor 메서드도 제공합니다.

reader.get_tensor(key)
array([[4.575237 , 4.8758802, 4.761698 , 4.9702144, 4.997744 ]],
      dtype=float32)

목록 및 딕셔너리 추적

체크포인트는 속성 중 하나에 설정된 모든 변수 또는 추적 가능한 개체를 '추적'하여 tf.Variable 개체의 값을 저장하고 복원합니다. 저장을 실행할 때 도달 가능한 모든 추적 개체로부터 변수가 재귀적으로 수집됩니다.

self.l1 = tf.keras.layer.Dense(5),와 같은 직접적인 속성 할당은 목록과 사전적 속성에 할당하면 내용이 추적됩니다.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

당신은 래퍼(wrapper) 객체를 목록과 사전에 있음을 알아차릴겁니다. 이러한 래퍼는 기본 데이터 구조의 checkpoint 가능한 버전입니다. 속성 기반 로딩과 마찬가지로, 이러한 래퍼들은 변수의 값이 용기에 추가되는 즉시 복원됩니다.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

추적 가능한 개체는 tf.train.Checkpoint, tf.Module 및 해당 서브 클래스(예: keras.layers.Layerkeras.Model)와 인식된 Python 컨테이너를 포함합니다.

  • dict(및 collections.OrderedDict)
  • list
  • tuple(및 collections.namedtuple, typing.NamedTuple)

다음을 포함하여 기타 컨테이너 유형은 지원하지 않습니다.

  • collections.defaultdict
  • set

다음을 포함한 기타 Python 개체는 무시됩니다.

  • int
  • string
  • float

요약

텐서프로우 객체는 사용하는 변수의 값을 저장하고 복원할 수 있는 쉬운 자동 메커니즘을 제공합니다.