정확성 및 수치적 등가성 검증하기

TensorFlow.org에서보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기

TensorFlow 코드를 TF1.x에서 TF2로 마이그레이션할 때에는 마이그레이션된 코드가 TF1.x에서와 동일한 방식으로 TF2에서 동작하는지 확인하는 것이 좋습니다.

이 가이드에서는 tf.keras.layers.Layer 메서드에 적용된 tf.compat.v1.keras.utils.track_tf1_style_variables 모델링 shim을 사용한 마이그레이션 코드 예제를 다룹니다. TF2 모델링 shim에 대한 자세한 내용은 모델 매핑 가이드를 읽어보세요.

이 가이드에서는 다음에 사용할 수 있는 접근 방식을 자세히 설명합니다.

  • 마이그레이션된 코드를 사용하여 훈련 모델에서 얻은 결과의 정확성 검증
  • TensorFlow 버전 전체에서 코드의 수치적 등가성 검증

설치하기

pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8
pip install -q tf-nightly
pip install -q tf_slim
import tensorflow as tf
import tensorflow.compat.v1 as v1

import numpy as np
import tf_slim as slim
import sys


from contextlib import contextmanager
2022-12-14 20:25:50.619809: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
!git clone --depth=1 https://github.com/tensorflow/models.git
import models.research.slim.nets.inception_resnet_v2 as inception
Cloning into 'models'...
remote: Enumerating objects: 3590, done.[K
remote: Counting objects: 100% (3590/3590), done.[K
remote: Compressing objects: 100% (3005/3005), done.[K
remote: Total 3590 (delta 943), reused 1501 (delta 531), pack-reused 0[K
Receiving objects: 100% (3590/3590), 47.08 MiB | 33.50 MiB/s, done.
Resolving deltas: 100% (943/943), done.

중요한 순방향 전달 코드 청크를 shim에 넣을 경우 TF1.x에서와 동일한 방식으로 동작하는지 알고 싶을 것입니다. 예를 들어 전체 TF-Slim Inception-Resnet-v2 모델을 다음과 같이 shim에 넣는 것을 고려할 수 있습니다.

# TF1 Inception resnet v2 forward pass based on slim layers
def inception_resnet_v2(inputs, num_classes, is_training):
  with slim.arg_scope(
    inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
    return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)
class InceptionResnetV2(tf.keras.layers.Layer):
  """Slim InceptionResnetV2 forward pass as a Keras layer"""

  def __init__(self, num_classes, **kwargs):
    super().__init__(**kwargs)
    self.num_classes = num_classes

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    is_training = training or False 

    # Slim does not accept `None` as a value for is_training,
    # Keras will still pass `None` to layers to construct functional models
    # without forcing the layer to always be in training or in inference.
    # However, `None` is generally considered to run layers in inference.

    with slim.arg_scope(
        inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
      return inception.inception_resnet_v2(
          inputs, self.num_classes, is_training=is_training)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_35328/2131234657.py:8: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.

실제로 이 레이어는 즉시 잘 작동합니다(정확한 정규화 손실 추적 기능 포함).

그러나 이것은 여러분이 당연히 여기고 싶은 일이 아닙니다. 아래의 단계를 따라 실제로 TF1.x과 동일하게 동작하는지 확인하고 완벽한 수치적 등가성을 관찰을 진행합니다. 이러한 단계들은 순방향 전달의 어떤 부분이 TF1.x에서 발산을 일으키는지 삼각측량하는 데 도움이 될 수 있습니다(모델의 다른 부분과 반대로 모델 순방향 전달에서 발산이 발생하는지 식별).

1단계: 변수가 한 번만 생성되는지 확인

매번 실수로 새 변수를 만들고 사용하는 대신 각 호출에서 변수를 재사용하는 방식으로 모델을 올바르게 구축했는지 가장 먼저 확인해야 합니다. 예를 들어 모델이 새 Keras 레이어를 생성하거나 각 순방향 전달 호출에서 tf.Variable을 호출하는 경우 변수 캡처에 실패하고 매번 새 변수를 생성할 가능성이 큽니다.

다음은 모델이 새 변수를 생성하는 시기를 감지하고 모델의 어느 부분이 이를 수행하는지 디버그하는 데 사용할 수 있는 두 가지 컨텍스트 관리자 범위입니다.

@contextmanager
def assert_no_variable_creations():
  """Assert no variables are created in this context manager scope."""
  def invalid_variable_creator(next_creator, **kwargs):
    raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs))

  with tf.variable_creator_scope(invalid_variable_creator):
    yield

@contextmanager
def catch_and_raise_created_variables():
  """Raise all variables created within this context manager scope (if any)."""
  created_vars = []
  def variable_catcher(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_vars.append(var)
    return var

  with tf.variable_creator_scope(variable_catcher):
    yield
  if created_vars:
    raise ValueError("Created vars:", created_vars)

첫 번째 범위(assert_no_variable_creations())는 범위 내에서 변수를 생성하려고 하면 즉시 오류를 발생시킵니다. 이렇게 하면 스택 추적(stacktrace)을 검사하고 대화형 디버깅을 사용하여 기존 변수를 재사용하는 대신 변수를 생성한 코드 줄이 무엇인지 정확하게 파악할 수 있습니다.

두 번째 범위(catch_and_raise_created_variables())는 변수가 생성된 경우 범위의 끝단에서 예외를 발생시킵니다. 이 예외에는 범위에서 생성된 모든 변수 목록이 포함됩니다. 이는 일반적인 패턴을 발견할 수 있는 경우 모델이 생성하는 모든 가중치 세트가 무엇인지 파악하는 데 유용합니다. 그러나 이러한 변수가 생성된 정확한 코드 줄을 식별하는 데는 덜 유용합니다.

아래의 두 범위를 모두 사용하여 shim 기반 InceptionResnetV2 레이어가 첫 번째 호출 이후 새 변수를 생성하지 않는지(아마도 변수를 재사용) 확인합니다.

model = InceptionResnetV2(1000)
height, width = 299, 299
num_classes = 1000

inputs = tf.ones( (1, height, width, 3))
# Create all weights on the first call
model(inputs)

# Verify that no new weights are created in followup calls
with assert_no_variable_creations():
  model(inputs)
with catch_and_raise_created_variables():
  model(inputs)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:332: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  warnings.warn('`tf.layers.flatten` is deprecated and '

아래의 예제에서 이러한 데코레이터가 기존 가중치를 재사용하는 대신 매번 새 가중치를 잘못 생성하는 레이어에서 어떻게 작동하는지 관찰합니다.

class BrokenScalingLayer(tf.keras.layers.Layer):
  """Scaling layer that incorrectly creates new weights each time:"""

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    var = tf.Variable(initial_value=2.0)
    bias = tf.Variable(initial_value=2.0, name='bias')
    return inputs * var + bias
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)

try:
  with assert_no_variable_creations():
    model(inputs)
except ValueError as err:
  import traceback
  traceback.print_exc()
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_35328/1128777590.py", line 7, in <module>
    model(inputs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmpfs/tmp/ipykernel_35328/3224979076.py", line 6, in call
    var = tf.Variable(initial_value=2.0)
  File "/tmpfs/tmp/ipykernel_35328/1829430118.py", line 5, in invalid_variable_creator
    raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs))
ValueError: Exception encountered when calling layer 'broken_scaling_layer' (type BrokenScalingLayer).

Attempted to create a new variable instead of reusing an existing one. Args: {'initial_value': 2.0, 'trainable': None, 'validate_shape': True, 'caching_device': None, 'name': None, 'variable_def': None, 'dtype': None, 'import_scope': None, 'constraint': None, 'synchronization': <VariableSynchronization.AUTO: 0>, 'aggregation': <VariableAggregation.NONE: 0>, 'shape': None, 'experimental_enable_variable_lifting': None}

Call arguments received by layer 'broken_scaling_layer' (type BrokenScalingLayer):
  • inputs=tf.Tensor(shape=(1, 299, 299, 3), dtype=float32)
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)

try:
  with catch_and_raise_created_variables():
    model(inputs)
except ValueError as err:
  print(err)
('Created vars:', [<tf.Variable 'broken_scaling_layer_1/Variable:0' shape=() dtype=float32, numpy=2.0>, <tf.Variable 'broken_scaling_layer_1/bias:0' shape=() dtype=float32, numpy=2.0>])

가중치를 한 번만 생성한 다음 매번 다시 사용하도록 레이어를 수정할 수 있습니다.

class FixedScalingLayer(tf.keras.layers.Layer):
  """Scaling layer that incorrectly creates new weights each time:"""
  def __init__(self):
    super().__init__()
    self.var = None
    self.bias = None

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    if self.var is None:
      self.var = tf.Variable(initial_value=2.0)
      self.bias = tf.Variable(initial_value=2.0, name='bias')
    return inputs * self.var + self.bias

model = FixedScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)

with assert_no_variable_creations():
  model(inputs)
with catch_and_raise_created_variables():
  model(inputs)

문제 해결

모델이 기존 가중치를 재사용하는 대신 실수로 새 가중치를 생성하는 몇 가지 일반적인 이유는 다음과 같습니다.

  1. 이미 생성된 tf.Variables를 재사용하지 않고 명시적인 tf.Variable 호출을 사용합니다. 생성되지 않았는지 먼저 확인한 다음 기존 항목을 재사용하여 이 문제를 해결해야 합니다.
  2. 이는 tf.compat.v1.layers와 달리 매번 순방향 전달에서 직접 Keras 레이어 또는 모델을 생성합니다. 생성되지 않았는지 먼저 확인한 다음 기존 항목을 재사용하여 이 문제를 해결해야 합니다.
  3. tf.compat.v1.layers을 기반으로 구축되었지만 모든 compat.v1.layers에 명시적 이름을 할당하거나 이름이 지정된 variable_scope 내에서 compat.v1.layer 사용을 래핑하는 데 실패하여 각 모델 호출에서 자동 생성된 레이어 이름이 증가되도록 했습니다. 이 문제를 해결하려면 모든 tf.compat.v1.layers 사용을 래핑하는 shim으로 데코레이팅된 메서드 안에 이름이 지정된 tf.compat.v1.variable_scope를 넣어야 합니다.

2단계: 변수 개수, 이름 및 형상이 일치하는지 확인

두 번째 단계는 TF2에서 실행하는 레이어가 TF1.x에서 해당 코드와 동일한 형상과 수로 가중치를 생성하는지 확인하는 것입니다.

일치하는 수동으로 확인하는 방식과 아래와 같이 단위 테스트에서 프로그래밍 방식으로 확인하는 방식을 혼합하여 이 작업을 수행할 수 있습니다.

# Build the forward pass inside a TF1.x graph, and 
# get the counts, shapes, and names of the variables
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
  height, width = 299, 299
  num_classes = 1000
  inputs = tf.ones( (1, height, width, 3))

  out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)

  tf1_variable_names_and_shapes = {
      var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}
  num_tf1_variables = len(tf.compat.v1.global_variables())
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '

다음으로 TF2의 Shim으로 래핑한 레이어에 대해 동일한 작업을 수행합니다. 가중치를 가져오기 전에 모델이 여러 번 호출된다는 점에 유의해야 합니다. 이는 변수 재사용을 효과적으로 테스트하기 위해 수행됩니다.

height, width = 299, 299
num_classes = 1000

model = InceptionResnetV2(num_classes)
# The weights will not be created until you call the model

inputs = tf.ones( (1, height, width, 3))
# Call the model multiple times before checking the weights, to verify variables
# get reused rather than accidentally creating additional variables
out, endpoints = model(inputs, training=False)
out, endpoints = model(inputs, training=False)

# Grab the name: shape mapping and the total number of variables separately,
# because in TF2 variables can be created with the same name
num_tf2_variables = len(model.variables)
tf2_variable_names_and_shapes = {
    var.name: (var.trainable, var.shape) for var in model.variables}
# Verify that the variable counts, names, and shapes all match:
assert num_tf1_variables == num_tf2_variables
assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes

Shim 기반 InceptionResnetV2 레이어는 이 테스트를 통과합니다. 그러나 일치하지 않는 경우 diff(텍스트 또는 기타)를 통해 실행하여 차이점을 확인할 수 있습니다.

이는 모델의 어떤 부분이 예상대로 작동하지 않는지에 대한 단서를 제공할 수 있습니다. 즉시 실행을 통해 pdb, 대화형 디버깅 및 중단점을 사용하고 모델에서 의심스러워 보이는 부분을 깊게 살펴보고 잘못된 부분을 더 깊이 디버깅할 수 있습니다.

문제 해결

  • 명시적 tf.Variable 호출 및 Keras 레이어/모델에 의해 직접 생성된 모든 변수의 이름에 세심한 주의를 기울여야 합니다. 변수 이름 생성 의미 체계는 다른 모든 것이 제대로 작동하더라도 즉시 실행과 tf.function과 같이 TF1.x 그래프와 TF2 기능 사이에 약간 다른 부분이 있을 수 있기 때문입니다. 이 경우 약간 다른 이름 지정 의미 체계를 고려하여 테스트를 조정하도록 합니다.

  • 때로는 훈련 루프의 순방향 전달에서 생성한 tf.Variable, tf.keras.layers.Layer 또는 tf.keras.Model이 TF1.x의 변수 모음에 의해 캡처되었음에도 불구하고 TF2 변수 목록에서 누락된 것을 발견할 수 있습니다. 순방향 전달이 생성하는 변수/레이어/모델을 모델의 인스턴스 속성에 할당하여 이 문제를 해결해야 합니다. 자세한 정보는 여기를 참조합니다.

3단계: 모든 변수 재설정, 모든 임의성이 비활성화된 수치적 등가성 확인

다음 단계는 난수 생성이 포함되지 않도록 모델을 수정할 때(예: 추론 작업 수행하는 동안) 실제 출력과 정규화 손실 추적 모두에 대한 수치적 동등성을 확인하는 것입니다.

이를 수행하는 정확한 방법은 특정 모델에 따라 다를 수 있지만 대부분의 모델(예: 이 모델)에서 다음과 같이 수행할 수 있습니다.

  1. 임의성이 없는 동일한 값으로 가중치를 초기화합니다. 값을 생성한 후에 고정 값으로 재설정하여 이 작업을 수행할 수 있습니다.
  2. 임의성의 소스가 될 수 있는 드롭아웃 레이어가 트리거되지 않도록 모델을 추론 모드에서 실행합니다.

다음 코드는 이러한 방식으로 TF1.x와 TF2 결과를 비교하는 방법을 설명합니다.

graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
  height, width = 299, 299
  num_classes = 1000
  inputs = tf.ones( (1, height, width, 3))

  out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)

  # Rather than running the global variable initializers,
  # reset all variables to a constant value
  var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])
  sess.run(var_reset)

  # Grab the outputs & regularization loss
  reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
  tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
  tf1_output = sess.run(out)

print("Regularization loss:", tf1_regularization_loss)
tf1_output[0][:5]
Regularization loss: 0.001182976
array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837],
      dtype=float32)

TF2 결과를 구합니다.

height, width = 299, 299
num_classes = 1000

model = InceptionResnetV2(num_classes)

inputs = tf.ones((1, height, width, 3))
# Call the model once to create the weights
out, endpoints = model(inputs, training=False)

# Reset all variables to the same fixed value as above, with no randomness
for var in model.variables:
  var.assign(tf.ones_like(var) * 0.001)
tf2_output, endpoints = model(inputs, training=False)

# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)

print("Regularization loss:", tf2_regularization_loss)
tf2_output[0][:5]
Regularization loss: tf.Tensor(0.0011829757, shape=(), dtype=float32)
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837],
      dtype=float32)>
# Create a dict of tolerance values
tol_dict={'rtol':1e-06, 'atol':1e-05}
# Verify that the regularization loss and output both match
# when we fix the weights and avoid randomness by running inference:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)

임의성 소스를 제거하면 TF1.x와 TF2 사이의 숫자가 일치하게 되고, TF2 호환 InceptionResnetV2 레이어가 테스트를 통과합니다.

자신의 모델에 대해 결과가 분기되는 것을 관찰하는 경우 인쇄 또는 pdb 및 대화형 디버깅을 사용하여 결과가 분기되기 시작하는 위치와 이유를 식별할 수 있습니다. 즉시 실행을 통해 훨씬 쉽게 이 작업을 수행할 수 있습니다. 또한, 고정된 중간 입력에서 모델의 작은 부분만 실행하고 분기가 발생하는 위치를 분리하기 위해 제거 접근 방식을 사용할 수 있습니다.

편리상 많은 슬림형 네트워크(및 기타 모델)는 프로브할 수 있는 중간 엔드포인트도 노출합니다.

4단계: 난수 생성 정렬, 훈련 및 추론 모두에서 수치적 동가성 확인

마지막 단계는 변수 초기화 및 순방향 전달 자체(예: 순방향 전달 중에 드롭아웃 레이어)에서 난수 생성을 고려할 때에도 TF2 모델이 수치적으로 TF1.x 모델과 일치하는지 확인하는 것입니다.

아래의 테스트 도구를 사용하여 TF1.x 그래프/세션과 즉시 실행 간에 난수 생성 의미 체계가 일치하도록 함으로써 이 작업을 수행할 수 있습니다.

TF1 레거시 그래프/세션 및 TF2 즉시 실행은 서로 다른 상태 저장 난수 생성 의미 체계를 사용합니다.

tf.compat.v1.Session에서는 시드가 지정되지 않은 경우, 무작위 연산이 추가되는 시점에 그래프에 있는 연산의 수와 그래프의 실행 횟수에 따라 난수 생성이 달라집니다. 즉시 실행에서 상태 저장 난수 생성은 전역 시드, 연산 래덤 시드 및 제공된 임의 시드가 있는 연산이 실행되는 횟수에 따라 달라집니다. 자세한 정보는 tf.random.set_seed를 참조하세요.

다음 v1.keras.utils.DeterministicRandomTestTool 클래스는 상태 저장 임의 작업이 TF1 그래프/세션과 즉시 실행 모두에서 동일한 시드를 사용하도록 할 수 있는 컨텍스트 관리자 scope()를 제공합니다.

이 도구는 두 가지 테스트 모드를 제공합니다.

  1. 호출 횟수에 관계없이 모든 단일 연산에 동일한 시드를 사용하는 constant
  2. 이전에 관찰된 상태 저장 임의 작업 수를 작업 시드로 사용하는 num_random_ops

이는 변수 생성 및 초기화에 사용되는 상태 저장 임의 작업과 계산에 사용되는 상태 저장 임의 작업(예: 드롭아웃 레이어)에 모두 적용됩니다.

세 개의 임의 텐서를 생성하여 이 도구를 사용할 경우 세션과 즉시 실행 간에 상태 저장 난수 생성이 일치되도록 하는 방법을 보여줍니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    a = tf.random.uniform(shape=(3,1))
    a = a * 3
    b = tf.random.uniform(shape=(3,3))
    b = b * 3
    c = tf.random.uniform(shape=(3,3))
    c = c * 3
    graph_a, graph_b, graph_c = sess.run([a, b, c])

graph_a, graph_b, graph_c
(array([[2.5063772],
        [2.7488918],
        [1.4839486]], dtype=float32),
 array([[2.5063772, 2.7488918, 1.4839486],
        [1.5633398, 2.1358476, 1.3693532],
        [0.3598416, 1.8287641, 2.5314465]], dtype=float32),
 array([[2.5063772, 2.7488918, 1.4839486],
        [1.5633398, 2.1358476, 1.3693532],
        [0.3598416, 1.8287641, 2.5314465]], dtype=float32))
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  c = tf.random.uniform(shape=(3,3))
  c = c * 3

a, b, c
(<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
 array([[2.5063772],
        [2.7488918],
        [1.4839486]], dtype=float32)>,
 <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
 array([[2.5063772, 2.7488918, 1.4839486],
        [1.5633398, 2.1358476, 1.3693532],
        [0.3598416, 1.8287641, 2.5314465]], dtype=float32)>,
 <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
 array([[2.5063772, 2.7488918, 1.4839486],
        [1.5633398, 2.1358476, 1.3693532],
        [0.3598416, 1.8287641, 2.5314465]], dtype=float32)>)
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)

그러나 constant 모드에서는 bc가 동일한 시드로 생성되었고 동일한 형상을 갖기 때문에 정확히 같은 값을 갖게 됩니다.

np.testing.assert_allclose(b.numpy(), c.numpy(), **tol_dict)

주문 추적

constant 모드에서 일치하는 일부 난수가 수치적 등가성 테스트의 신뢰도를 감소시키는 것이 걱정되는 경우(예: 여러 가중치가 동일한 초기화를 수행하는 경우) num_random_ops 모드를 사용하여 이를 방지할 수 있습니다. num_random_ops 모드에서 생성한 난수는 프로그램의 임의 연산 순서에 따라 달라집니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    a = tf.random.uniform(shape=(3,1))
    a = a * 3
    b = tf.random.uniform(shape=(3,3))
    b = b * 3
    c = tf.random.uniform(shape=(3,3))
    c = c * 3
    graph_a, graph_b, graph_c = sess.run([a, b, c])

graph_a, graph_b, graph_c
(array([[2.5063772],
        [2.7488918],
        [1.4839486]], dtype=float32),
 array([[0.45038545, 1.9197761 , 2.4536333 ],
        [1.0371652 , 2.9898582 , 1.924583  ],
        [0.25679827, 1.6579313 , 2.8418403 ]], dtype=float32),
 array([[2.9634383 , 1.0862181 , 2.6042497 ],
        [0.70099247, 2.3920312 , 1.0470468 ],
        [0.18173039, 0.8359269 , 1.0508587 ]], dtype=float32))
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  c = tf.random.uniform(shape=(3,3))
  c = c * 3

a, b, c
(<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
 array([[2.5063772],
        [2.7488918],
        [1.4839486]], dtype=float32)>,
 <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
 array([[0.45038545, 1.9197761 , 2.4536333 ],
        [1.0371652 , 2.9898582 , 1.924583  ],
        [0.25679827, 1.6579313 , 2.8418403 ]], dtype=float32)>,
 <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
 array([[2.9634383 , 1.0862181 , 2.6042497 ],
        [0.70099247, 2.3920312 , 1.0470468 ],
        [0.18173039, 0.8359269 , 1.0508587 ]], dtype=float32)>)
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict )
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
# Demonstrate that with the 'num_random_ops' mode,
# b & c took on different values even though
# their generated shape was the same
assert not np.allclose(b.numpy(), c.numpy(), **tol_dict)

그러나 이 모드로 진행하는 난수 생성은 프로그램 순서에 민감하므로 다음에 생성되는 난수는 일치하지 않습니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3

assert not np.allclose(a.numpy(), a_prime.numpy())
assert not np.allclose(b.numpy(), b_prime.numpy())

추적 순서로 인한 디버깅 변형을 허용하기 위해 num_random_ops 모드의 DeterministicRandomTestTool을 사용하여 operation_seed 속성으로 추적한 임의 연산의 수를 확인할 수 있습니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  print(random_tool.operation_seed)
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  print(random_tool.operation_seed)
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  print(random_tool.operation_seed)
0
1
2

테스트에서 다양한 추적 순서를 고려해야 하는 경우 자동 증가 operation_seed를 명시적으로 설정할 수도 있습니다. 예를 들어, 이를 사용하여 서로 다른 두 프로그램 순서에서 난수 생성이 일치하도록 할 수 있습니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  print(random_tool.operation_seed)
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  print(random_tool.operation_seed)
  b = tf.random.uniform(shape=(3,3))
  b = b * 3

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  random_tool.operation_seed = 1
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  random_tool.operation_seed = 0
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3

np.testing.assert_allclose(a.numpy(), a_prime.numpy(), **tol_dict)
np.testing.assert_allclose(b.numpy(), b_prime.numpy(), **tol_dict)
0
1

다만, DeterministicRandomTestTool은 이미 사용한 연산 시드의 재사용을 허용하지 않으므로 자동 증가된 시퀀스를 겹치지 않도록 해야 합니다. 이는 즉시 실행이 동일한 연산 시드를 연속 사용할 경우 다른 숫자를 생성하는 반면 TF1 그래프 및 세션은 그렇지 않기 때문에, 오류를 발생시키면 세션 및 Eager 상태 저장 난수 생성을 유지하는 데 도움이 됩니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  random_tool.operation_seed = 1
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  random_tool.operation_seed = 0
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3
  try:
    c = tf.random.uniform(shape=(3,1))
    raise RuntimeError("An exception should have been raised before this, " +
                     "because the auto-incremented operation seed will " +
                     "overlap an already-used value")
  except ValueError as err:
    print(err)
This `DeterministicRandomTestTool` object is trying to re-use the already-used operation seed 1. It cannot guarantee random numbers will match between eager and sessions when an operation seed is reused. You most likely set `operation_seed` explicitly but used a value that caused the naturally-incrementing operation seed sequences to overlap with an already-used seed.

추론 확인

이제 DeterministicRandomTestTool을 사용하여 임의 가중치 초기화를 사용하는 경우에도 InceptionResnetV2 모델이 추론에서 일치하는지 확인할 수 있습니다. 일치하는 프로그램 순서로 인해 더 강력해진 테스트 조건에는 num_random_ops 모드를 사용합니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))

    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)

    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())

    # Grab the outputs & regularization loss
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)

  print("Regularization loss:", tf1_regularization_loss)
Regularization loss: 1.2254326
height, width = 299, 299
num_classes = 1000

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  model = InceptionResnetV2(num_classes)

  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=False)

  # Grab the regularization loss as well
  tf2_regularization_loss = tf.math.add_n(model.losses)

print("Regularization loss:", tf2_regularization_loss)
Regularization loss: tf.Tensor(1.2254325, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)

훈련 확인하기

DeterministicRandomTestTool모든 상태 저장 임의 연산(가중치 초기화 및 드롭아웃 레이어와 같은 계산 모두 포함)에서 작동하므로 이를 사용하여 훈련 모드에서도 모델이 일치하는지 확인할 수 있습니다. 상태 저장 연산의 프로그램 순서가 일치하기 때문에 num_random_ops 모드를 다시 사용할 수 있습니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))

    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)

    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())

    # Grab the outputs & regularization loss
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)

  print("Regularization loss:", tf1_regularization_loss)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Regularization loss: 1.22548
height, width = 299, 299
num_classes = 1000

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  model = InceptionResnetV2(num_classes)

  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=True)

  # Grab the regularization loss as well
  tf2_regularization_loss = tf.math.add_n(model.losses)

print("Regularization loss:", tf2_regularization_loss)
Regularization loss: tf.Tensor(1.2254798, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)

이제 tf.keras.layers.Layer 주위의 데코레이터로 즉시 실행되는 InceptionResnetV2 모델이 TF1 그래프 및 세션에서 실행되는 슬림 네트워크와 수치적으로 일치함을 확인했습니다.

참고: num_random_ops 모드에서 DeterministicRandomTestTool을 사용할 때 수치적 등가성을 테스트할 경우 tf.keras.layers.Layer 메서드 데코레이터를 직접 사용하고 호출하는 것이 좋습니다. Keras 함수형 모델 또는 다른 Keras 모델에 임베딩하면 TF1.x 그래프/세션과 즉시 실행을 비교할 때 추론하거나 정확히 일치시키기가 까다로울 수 있는 상태 저장 임의 연산 추적 순서에서 차이가 발생할 수 있습니다.

예를 들어, training=True로 직접 InceptionResnetV2 레이어를 호출하면 네트워크 생성 순서에 따라 드롭아웃 순서로 변수 초기화가 인터리브됩니다.

반면에, 먼저 tf.keras.layers.Layer 데코레이터를 Keras 함수형 모델에 넣은 다음 training=True로 모델을 호출하는 것은 모든 변수를 초기화한 후 드롭아웃 레이어를 사용하는 것과 같습니다. 이렇게 하면 다른 추적 순서와 다른 난수 세트가 생성됩니다.

그런데 기본 mode='constant'는 추적 순서의 이러한 차이에 민감하지 않으며 Keras 함수형 모델에 레이어를 임베딩할 때에도 추가 작업 없이 통과합니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))

    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)

    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())

    # Get the outputs & regularization losses
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)

  print("Regularization loss:", tf1_regularization_loss)
Regularization loss: 1.2239965
height, width = 299, 299
num_classes = 1000

random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  keras_input = tf.keras.Input(shape=(height, width, 3))
  layer = InceptionResnetV2(num_classes)
  model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))

  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=True)

  # Get the regularization loss
  tf2_regularization_loss = tf.math.add_n(model.losses)

print("Regularization loss:", tf2_regularization_loss)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
Regularization loss: tf.Tensor(1.2239964, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)

3b 또는 4b단계(선택 사항): 기존 체크포인트로 테스트하기

위에서 3단계 또는 4단계 후에, 기존 이름 기반 체크포인트가 있는 경우 시작할 때 수치 등가성 테스트를 수행하면 도움이 될 수 있습니다. 이를 통해 레거시 체크포인트 로딩이 올바르게 작동하고 모델 자체가 올바르게 작동하는지 테스트할 수 있습니다. TF1.x 체크포인트 재사용 가이드에서는 기존 TF1.x 체크포인트를 재사용하고 이를 TF2 체크포인트로 전송하는 방법을 다룹니다.

추가 테스트 및 문제 해결

더 많은 수치적 등가성 테스트를 추가하면 그래디언트 계산(또는 옵티마이저 프로그램 업데이트)이 일치하는지 확인하는 테스트를 추가하도록 선택할 수도 있습니다.

역전파 및 그래디언트 계산은 모델 순방향 전달보다 부동 소수점 수치적 불안정성에 더 취약합니다. 이것은 등가성 테스트가 훈련에서 고립되지 않은 부분을 더 많이 다루기 때문에 전체를 즉시 실행하는 것과 TF1 그래프 사이에 비자명한 수치적 차이를 보기 시작할 수 있음을 의미합니다. TensorFlow의 그래프 옵티마이저가 더 적은 수의 수학적 연산으로 그래프의 하위 표현식을 대체하는 등의 작업을 수행하기 때문에 이러한 현상이 나타날 수 있습니다.

이것의 실현 가능성을 구분하기 위해 TF1 코드를 순수한 즉시 계산이 아닌 tf.function 내부(TF1 그래프와 같은 그래프 최적화 패스를 적용)에서 발생하는 TF2 계산과 비교할 수 있습니다. 또는, tf.config.optimizer.set_experimental_options를 사용하여 TF1 계산 전에 "arithmetic_optimization"와 같은 최적화 패스를 비활성화하여 결과가 수치적으로 TF2 계산 결과에 가깝게 나오는지 확인할 수 있습니다. 실제 훈련 실행에서는 성능상의 이유로 최적화 패스가 활성화된 상태에서 tf.function을 사용하는 것이 좋지만 수치 등가성 단위 테스트에서는 비활성화하는 것이 유용할 수 있습니다.

마찬가지로 tf.compat.v1.train 옵티마이저와 TF2 옵티마이저는 그들이 나타내는 수학 공식은 동일하더라도 TF2 옵티마이저와 부동 소수점 숫자 속성이 약간 다를 수 있습니다. 훈련 실행에서는 이렇게 되어도 문제가 될 가능성이 적지만 등가성 단위 테스트에서는 더 높은 수치 허용 오차가 필요할 수 있습니다.