TF2 워크플로에서 TF1.x 모델 사용하기

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 보기 노트북 다운로드하기

이 가이드는 즉시 실행, tf.function, 모델링 코드 변경을 최소화하는 배포 전략과 같이 TF2 워크플로에서 기존 TF1.x 모델을 사용하는 데 사용할 수 있는 모델링 코드 shim의 개요와 예제를 제공합니다.

사용 범위

이 가이드에 설명된 shim은 다음에 의존하는 TF1.x 모델에서 사용할 수 있게 설계되어 있습니다.

  1. 변수 생성 및 재사용을 제어하는 tf.compat.v1.get_variabletf.compat.v1.variable_scope
  2. 가중치 및 정규화 손실을 추적하는 tf.compat.v1.global_variables(), tf.compat.v1.trainable_variables, tf.compat.v1.losses.get_regularization_losses()tf.compat.v1.get_collection()

여기에는 tf.compat.v1.layer, tf.contrib.layers API, TensorFlow-Slim을 기반으로 구축된 대부분의 모델이 포함됩니다.

다음 TF1.x 모델에는 shim이 필요하지 않습니다.

  1. model.trainable_weightsmodel.losses를 통해 이미 훈련할 수 있는 모든 가중치와 정규화 손실을 추적하는 독립형 Keras 모델
  2. module.trainable_variables를 통해 훈련할 수 있는 모든 가중치를 이미 추적하고 아직 생성되지 않은 경우에만 가중치를 생성하는 tf.Module

이러한 모델들은 즉시 실행 및 tf.function을 사용하여 TF2에서 작동할 가능성이 높습니다.

설치하기

TensorFlow 및 기타 종속성을 가져옵니다.

pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8

pip install -q tf-nightly
import tensorflow as tf
import tensorflow.compat.v1 as v1
import sys
import numpy as np

from contextlib import contextmanager
2022-12-14 20:38:31.956458: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay

track_tf1_style_variables 데코레이터

이 가이드에서 설명하는 주요 shim은 TF1.x-style 가중치를 추적하고 정규화 손실을 캡처하는 tf.keras.layers.Layertf.Module에 속한 메서드 내에서 사용할 수 있는 데코레이터인 tf.compat.v1.keras.utils.track_tf1_style_variables입니다.

tf.compat.v1.keras.utils.track_tf1_style_variables를 사용하여 tf.keras.layers.Layer 또는 tf.Module의 호출 메서드를 데코레이팅하면 호출될 때마다 항상 새 변수를 생성하는 대신 데코레이팅된 메서드 내에서 올바르게 작동하는 tf.compat.v1.get_variable(및 확장자 tf.compat.v1.layers)를 통해 변수를 생성하고 재사용할 수 있습니다. 또한 레이어 또는 모듈이 데코레이팅된 메서드 내에서 get_variable을 통해 생성되거나 액세스된 가중치를 암시적으로 추적하도록 합니다.

표준 layer.variable/module.variable/etc 속성에서 가중치 자체를 추적하는 것 외에도 메서드가 tf.keras.layers.Layer에 속하면 get_variable 또는 tf.compat.v1.layers 정규화 인수를 통해 지정된 정규화 손실은 표준 layer.losses 속성 아래의 레이어에서 추적됩니다.

이 추적 메커니즘을 사용하면 TF2 작동이 사용 설정된 경우에도 Keras 레이어 또는 TF2의 tf.Module 내부에서 TF1.x-style 모델 순방향 전달 코드의 큰 클래스를 사용할 수 있습니다.

사용 예제

아래의 사용 예제는 tf.keras.layers.Layer 메서드를 데코레이팅할 때 사용하는 모델링 shim을 보여주지만, Keras 특성과 구체적으로 상호작용하는 경우를 제외하고 tf.Module을 데코레이팅할 때에도 적용할 수 있습니다.

tf.compat.v1.get_variable를 사용하여 빌드한 레이어

다음과 같이 tf.compat.v1.get_variable를 기반으로 직접 구현한 레이어가 있다고 상상해 보겠습니다.

def dense(self, inputs, units):
  out = inputs
  with tf.compat.v1.variable_scope("dense"):
    # The weights are created with a `regularizer`,
    kernel = tf.compat.v1.get_variable(
        shape=[out.shape[-1], units],
        regularizer=tf.keras.regularizers.L2(),
        initializer=tf.compat.v1.initializers.glorot_normal,
        name="kernel")
    bias = tf.compat.v1.get_variable(
        shape=[units,],
        initializer=tf.compat.v1.initializers.zeros,
        name="bias")
    out = tf.linalg.matmul(out, kernel)
    out = tf.compat.v1.nn.bias_add(out, bias)
  return out

Shim을 레이어로 변환한 후 입력에서 호출합니다.

class DenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

layer = DenseLayer(10)
x = tf.random.normal(shape=(8, 20))
layer(x)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46200/795621215.py:7: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[-1.0201198 , -0.951724  ,  1.2094985 , -0.2987774 ,  0.31393754,
         1.35114   ,  0.6462465 ,  0.8913398 ,  0.75100744, -1.2132128 ],
       [-0.68590415, -0.26175082,  0.64078116, -2.912399  , -2.222181  ,
        -2.292143  , -2.0610387 , -1.4036045 , -0.43879175, -0.69157904],
       [ 0.13097352, -0.49666354, -0.122164  , -1.4062216 ,  0.42534038,
         0.20121956,  0.17584828,  0.37265995, -0.05870482, -0.30198455],
       [-1.3509794 , -1.1531386 ,  0.08235171, -1.276886  ,  0.4874733 ,
         0.12147932,  1.0546646 , -0.5473106 , -2.1247218 , -0.77802217],
       [-0.21143234, -0.7500431 ,  1.2885294 , -1.0960779 ,  1.1789135 ,
        -0.22484559, -2.3605824 , -1.1531962 ,  0.9950639 , -0.34410197],
       [ 3.238175  , -0.396873  ,  0.27031243,  1.3871925 ,  0.49264675,
         0.04602268,  0.2495502 ,  0.12468082,  0.7785794 , -0.13779987],
       [ 0.15282826, -0.47604153, -0.6094171 ,  1.450929  , -0.10152841,
         0.26042965,  0.6113905 , -0.8555389 , -0.43506438,  0.45744914],
       [-1.3396955 , -0.9795493 ,  0.3935601 ,  1.34605   ,  0.7668753 ,
         0.92987084,  2.042585  , -0.40245664, -0.85167456,  1.0894947 ]],
      dtype=float32)>

표준 Keras 레이어와 같은 추적된 변수와 캡처된 정규화 손실에 액세스합니다.

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.14263782>]

레이어를 호출할 때마다 가중치가 재사용되는지 확인하려면 모든 가중치를 0으로 설정하고 레이어를 다시 호출합니다.

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

# Note: layer.losses is not a live view and
# will get reset only at each layer call
print("layer.losses:", layer.losses)
print("calling layer again.")
out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['dense/bias:0', 'dense/kernel:0']
layer.losses: [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
calling layer again.
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

Keras 함수 모델 구성에서도 변환한 레이어를 직접 사용할 수 있습니다.

inputs = tf.keras.Input(shape=(20))
outputs = DenseLayer(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 20))
model(x)

# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.13125566>]

tf.compat.v1.layers로 빌드한 모델

다음과 같이 tf.compat.v1.layers를 기반으로 직접 구현한 레이어가 있다고 상상해 보겠습니다.

def model(self, inputs, units):
  with tf.compat.v1.variable_scope('model'):
    out = tf.compat.v1.layers.conv2d(
        inputs, 3, 3,
        kernel_regularizer="l2")
    out = tf.compat.v1.layers.flatten(out)
    out = tf.compat.v1.layers.dense(
        out, units,
        kernel_regularizer="l2")
    return out

Shim을 레이어로 변환한 후 입력에서 호출합니다.

class CompatV1LayerModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

layer = CompatV1LayerModel(10)
x = tf.random.normal(shape=(8, 5, 5, 5))
layer(x)
/tmpfs/tmp/ipykernel_46200/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_46200/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_46200/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[-0.06327415,  2.092562  ,  0.57418495, -0.8743783 , -1.1326195 ,
        -1.8447869 ,  0.04119575, -1.7533609 ,  0.8054693 ,  0.44668588],
       [-0.8397758 , -1.8044966 , -3.0734472 ,  1.2611016 ,  2.8427548 ,
        -1.17187   , -0.1264826 , -0.8966931 , -1.3744891 , -0.4927774 ],
       [ 1.0648903 , -1.3715969 ,  1.1525033 ,  1.086575  ,  0.38834083,
         2.0468204 ,  1.4457653 ,  0.17835286, -0.8083112 , -1.7415471 ],
       [ 1.0040369 ,  4.5253763 , -0.8186129 , -0.02587175,  0.4320113 ,
         2.3105392 ,  0.38452315, -3.5626667 ,  0.06980598, -0.08916754],
       [-1.4197118 ,  0.5707685 , -0.8055949 ,  0.9788049 , -0.81847227,
        -0.6336094 ,  0.5286823 , -0.41103727,  0.08750917, -0.10323423],
       [ 1.8055147 , -0.29969662,  0.09428084,  2.7879486 ,  2.4568238 ,
         2.5910087 , -0.5912935 ,  2.662509  , -0.39896762, -0.89899695],
       [-0.8941417 ,  0.9635209 , -1.5840545 , -1.0645733 , -1.7501976 ,
        -2.8524106 ,  0.50043744, -1.5798345 , -2.3090618 ,  1.0171978 ],
       [ 0.16678604, -1.6925997 ,  0.23552456, -0.22428381, -1.9627607 ,
        -1.7819912 ,  0.01840904,  0.8170755 , -2.5374982 , -0.22768609]],
      dtype=float32)>

경고: 안전상의 이유로 모든 tf.compat.v1.layers를 비어 있지 않은 문자열 variable_scope 안에 넣어야 합니다. 자동 생성된 이름이 있는 tf.compat.v1.layers는 항상 변수 범위 밖에서 이름을 자동으로 늘리기 때문입니다. 즉, 레이어/모듈을 호출할 때마다 요청된 변수 이름이 일치하지 않습니다. 따라서 이미 만든 가중치를 재사용하는 대신 호출할 때마다 새로운 변수 세트를 생성합니다.

표준 Keras 레이어와 같은 추적된 변수와 캡처된 정규화 손실에 액세스합니다.

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.04080253>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.14297499>]

레이어를 호출할 때마다 가중치가 재사용되는지 확인하려면 모든 가중치를 0으로 설정하고 레이어를 다시 호출합니다.

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['model/conv2d/bias:0', 'model/conv2d/kernel:0', 'model/dense/bias:0', 'model/dense/kernel:0']
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
/tmpfs/tmp/ipykernel_46200/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_46200/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_46200/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

Keras 함수 모델 구성에서도 변환한 레이어를 직접 사용할 수 있습니다.

inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1LayerModel(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 5, 5, 5))
model(x)
/tmpfs/tmp/ipykernel_46200/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
/tmpfs/tmp/ipykernel_46200/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_46200/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[-0.10423124, -0.49252534, -0.10436222,  0.068404  , -0.9778502 ,
         1.3258257 ,  1.9239843 , -0.5209205 , -0.08098727,  1.3616567 ],
       [ 2.008305  , -0.63370496, -1.4405237 ,  0.4341396 , -1.971505  ,
        -0.16152357,  0.8991833 ,  1.0047848 , -1.751168  , -0.4906541 ],
       [ 2.6769912 , -0.5353887 , -1.0216954 ,  0.545533  , -0.5864865 ,
         0.6349678 ,  1.0974748 , -0.3922951 , -1.3166703 , -0.62223816],
       [-0.8650606 ,  0.7834866 , -0.77969337, -1.1913569 , -0.7170765 ,
        -1.4810982 ,  1.5137848 ,  0.6127273 , -0.9004924 ,  3.4849586 ],
       [-1.235096  , -0.35509142,  1.5101688 ,  1.4356166 , -0.60496473,
         0.992941  , -0.24767148,  0.36325622,  0.43800712, -1.8807107 ],
       [-1.0707847 ,  1.3324004 ,  2.0317771 , -0.5796515 , -1.2188568 ,
        -2.9009778 ,  1.8013899 ,  1.3486978 ,  0.1742219 ,  0.5256916 ],
       [ 1.9507825 , -0.35828122, -2.191814  ,  0.8309845 , -1.41339   ,
        -0.5650949 , -0.36482388,  0.14574707, -1.3613335 ,  0.99965405],
       [ 1.09396   , -0.8510579 ,  0.09120522,  2.5301988 , -0.98921394,
         0.48626742,  2.064007  ,  0.18146089,  1.0849928 , -1.0466281 ]],
      dtype=float32)>
# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.03315928>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.14079392>]

배치 정규화 업데이트 및 모델 training 인수 캡처하기

TF1.x에서는 다음과 같이 배치 정규화를 수행합니다.

  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)
  train_op = optimizer.minimize(loss)
  train_op = tf.group([train_op, update_ops])

참고:

  1. 배치 정규화 이동 평균 업데이트는 레이어와는 별개로 호출된 get_collection에 의해 추적됩니다.
  2. tf.compat.v1.layers.batch_normalization에는 training 인수가 필요합니다(TF-Slim 배치 정규화 레이어를 사용하는 경우 일반적으로 is_training이라고 함).

TF2에서는 즉시 실행과 자동 제어 종속성으로 인해 배치 정규화 이동 평균 업데이트가 즉시 실행됩니다. 업데이트 컬렉션에서 별도로 수집하고 명시적 제어 종속성으로 추가할 필요가 없습니다.

또한 tf.keras.layers.Layer의 순방향 전달 메서드에 training 인수를 지정하면 Keras는 다른 레이어에서 하는 것처럼 현재 훈련 단계와 모든 중첩 레이어를 전달할 수 있게 됩니다. Keras가 training 인수를 처리하는 방법에 대한 자세한 정보는 tf.keras.Model용 API 문서를 참조하세요.

tf.Module 메서드를 데코레이팅하는 경우에는 필요에 따라 모든 training 인수를 수동으로 전달해야 합니다. 그러나 배치 정규화 이동 평균 업데이트는 명시적인 제어 종속성이 없어도 여전히 자동으로 적용됩니다.

다음 코드 조각은 shim에 배치 정규화 레이어를 삽입하는 방법과 Keras 모델에서 이를 사용하는 방법을 보여줍니다(tf.keras.layers.Layer에 적용 가능).

class CompatV1BatchNorm(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    print("Forward pass called with `training` =", training)
    with v1.variable_scope('batch_norm_layer'):
      return v1.layers.batch_normalization(x, training=training)
print("Constructing model")
inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1BatchNorm()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

print("Calling model in inference mode")
x = tf.random.normal(shape=(8, 5, 5, 5))
model(x, training=False)

print("Moving average variables before training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})

# Notice that when running TF2 and eager execution, the batchnorm layer directly
# updates the moving averages while training without needing any extra control
# dependencies
print("calling model in training mode")
model(x, training=True)

print("Moving average variables after training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})
Constructing model
Forward pass called with `training` = None
/tmpfs/tmp/ipykernel_46200/3053504896.py:7: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  return v1.layers.batch_normalization(x, training=training)
Calling model in inference mode
Forward pass called with `training` = False
Moving average variables before training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 0., 0.], dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([1., 1., 1., 1., 1.], dtype=float32)>}
calling model in training mode
Forward pass called with `training` = True
Moving average variables after training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([ 2.6681734e-04,  9.5353404e-05,  1.0631174e-03, -2.0060694e-04,
       -2.9556750e-06], dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([1.0000558 , 0.99979514, 1.0010005 , 1.0000939 , 0.9999507 ],
      dtype=float32)>}

변수 범위 기반 변수 재사용하기

get_variable을 기반으로 하는 순방향 전달에서 생성한 모든 변수는 TF1.x의 변수 범위와 동일한 변수 이름 지정과 재사용 의미 체계를 유지합니다. 위에서 언급한 것처럼 자동 생성한 이름을 가진 모든 tf.compat.v1.layers에 대해 비어 있지 않은 외부 범위가 하나 이상 갖는 한 참입니다.

참고: 이름 지정 및 재사용은 단일 레이어/모듈 인스턴스 내로 범위가 지정됩니다. 하나의 shim 데코레이션 레이어 또는 모듈 내에서 get_variable을 호출하면 레이어나 모듈 내에서 생성한 변수를 참조할 수 없습니다. 필요한 경우 get_variable을 통해 변수에 액세스하는 대신 Python 참조를 다른 변수에 직접 사용하여 이 문제를 해결할 수 있습니다.

즉시 실행과 tf.function

위에서 보았듯이 tf.keras.layers.Layertf.Module로 데코레이팅된 메서드는 즉시 실행 내부에서 실행되며 tf.function과도 호환됩니다. 즉, pdb 및 기타 대화형 도구를 사용하여 실행 중인 순방향 전달을 단계별로 실행할 수 있음을 의미합니다.

경고: tf.function내부에서 shim으로 데코레이트한 레이어/모듈 메서드를 호출하는 것은 완벽하게 안전하지만 tf.functionsget_variable 호출이 포함된 경우 tf.function를 넣는 것은 안전하지 않습니다. tf.function을 입력하면 variable_scope가 재설정됩니다. 즉, shim이 모방하는 TF1.x 스타일 변수 범위 기반의 변수 재사용이 이 설정에서 중단됩니다.

분산 전략

@track_tf1_style_variables로 데코레이트한 레이어 또는 모듈 메서드 안에서 이루어지는 get_variable 호출은 내부에서 표준 tf.Variable 변수 생성을 사용합니다. 즉, MirroredStrategyTPUStrategy와 같은 tf.distribute에서 사용할 수 있는 다양한 분산 전략과 함께 사용할 수 있습니다.

데코레이션 호출에서 tf.Variable, tf.Module, tf.keras.layerstf.keras.models 중첩하기

tf.compat.v1.keras.utils.track_tf1_style_variables에서 레이어 호출을 데코레이팅하면 tf.compat.v1.get_variable을 통해 생성(및 재사용)한 변수의 자동 암시 추적만 추가됩니다. 일반적인 Keras 레이어 및 대부분의 tf.Module에서 사용하는 것과 같이 tf.Variable 호출로 직접 생성한 가중치는 캡처하지 않습니다. 이 섹션에서는 이러한 중첩 사례를 처리하는 방법을 설명합니다.

(기존 사용법) tf.keras.layerstf.keras.models

중첩 Keras 레이어 및 모델의 기존 사용법에서는 tf.compat.v1.keras.utils.get_or_create_layer를 사용합니다. 이는 기존 TF1.x 중첩 Keras 사용을 쉽게 마이그레이션할 때만 권장됩니다. 새 코드는 tf.Variables 및 tf.Modules에 대해 아래에 설명한 대로 명시적 속성 설정을 사용해야 합니다.

tf.compat.v1.keras.utils.get_or_create_layer를 사용하려면 중첩 모델을 구성하는 코드를 메서드로 래핑한 후 메서드에 전달해야 합니다. 예제:

class NestedModel(tf.keras.Model):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Get or create a nested model without assigning it as an explicit property
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

layer = NestedModel(10)
layer(tf.ones(shape=(5,5)))
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]], dtype=float32)>

이 메서드는 이러한 중첩 레이어가 TensorFlow에서 올바르게 재사용되고 추적되도록 합니다. @track_tf1_style_variables 데코레이터는 적절한 메서드에서 여전히 필요합니다. get_or_create_layer(이 경우 self.build_model)로 전달된 모델 빌더 메서드는 인수를 사용하면 안 됩니다.

가중치는 다음과 같이 추적합니다.

assert len(layer.weights) == 2
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"dense/bias:0", "dense/kernel:0"}

layer.weights
[<tf.Variable 'dense/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]

정규화 손실도 다음과 같습니다.

tf.add_n(layer.losses)
<tf.Tensor: shape=(), dtype=float32, numpy=0.5>

증분 마이그레이션: tf.Variablestf.Modules

데코레이팅한 메서드에 tf.Variable 호출 또는 tf.Module을 삽입해야 하는 경우(예: 이 가이드의 뒷부분에 설명된 레거시가 아닌 TF2 API로의 증분 마이그레이션을 따르는 경우) 다음 요구 사항에 따라 명시적으로 추적해야 합니다.

  • 변수/모듈/레이어가 한 번만 생성되었는지 명시적으로 확인
  • 일반 모듈 또는 레이어를 정의할 때와 마찬가지로 인스턴스 속성으로 명시적으로 연결
  • 후속 호출에서 이미 생성한 객체를 명시적으로 재사용

이렇게 하면 호출할 때마다 가중치를 새롭게 생성하지 않고 올바르게 재사용합니다. 또한 이를 통해 기존 가중치 및 정규화 손실을 추적할 수 있습니다.

다음은 이러한 작동 방식을 설명한 예제입니다.

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("inner_dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

class WrappedDenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    # Only create the nested tf.variable/module/layer/model
    # once, and then reuse it each time!
    self._dense_layer = NestedLayer(self.units)

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('outer'):
      outputs = tf.compat.v1.layers.dense(inputs, 3)
      outputs = tf.compat.v1.layers.dense(inputs, 4)
      return self._dense_layer(outputs)

layer = WrappedDenseLayer(10)

layer(tf.ones(shape=(5, 5)))
/tmpfs/tmp/ipykernel_46200/2765428776.py:38: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 3)
/tmpfs/tmp/ipykernel_46200/2765428776.py:39: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 4)
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[ 0.02090518,  0.33420768, -0.25673977,  0.3621331 ,  0.6460784 ,
         0.5631576 , -0.26076365, -0.11238655,  0.57047373,  0.31484437],
       [ 0.02090518,  0.33420768, -0.25673977,  0.3621331 ,  0.6460784 ,
         0.5631576 , -0.26076365, -0.11238655,  0.57047373,  0.31484437],
       [ 0.02090518,  0.33420768, -0.25673977,  0.3621331 ,  0.6460784 ,
         0.5631576 , -0.26076365, -0.11238655,  0.57047373,  0.31484437],
       [ 0.02090518,  0.33420768, -0.25673977,  0.3621331 ,  0.6460784 ,
         0.5631576 , -0.26076365, -0.11238655,  0.57047373,  0.31484437],
       [ 0.02090518,  0.33420768, -0.25673977,  0.3621331 ,  0.6460784 ,
         0.5631576 , -0.26076365, -0.11238655,  0.57047373,  0.31484437]],
      dtype=float32)>

track_tf1_style_variables 데코레이터로 데코레이팅한 경우에도 중첩 모듈을 명시적으로 추적해야 합니다. 이는 데코레이팅한 메서드가 있는 각 모듈/레이어에 연결된 자체 변수 저장소가 있기 때문입니다.

다음과 같은 경우 가중치를 올바르게 추적합니다.

assert len(layer.weights) == 6
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"outer/inner_dense/bias:0",
                               "outer/inner_dense/kernel:0",
                               "outer/dense/bias:0",
                               "outer/dense/kernel:0",
                               "outer/dense_1/bias:0",
                               "outer/dense_1/kernel:0"}

layer.trainable_weights
[<tf.Variable 'outer/inner_dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/inner_dense/kernel:0' shape=(4, 10) dtype=float32, numpy=
 array([[ 0.37943438,  0.53132665, -0.08674298, -0.23320225, -0.13015583,
         -0.4591074 ,  0.25781223,  0.36397132,  0.24894506, -0.45869645],
        [ 0.04810307, -0.5566843 , -0.32545373, -0.56785744, -0.30109942,
          0.03485163,  0.18217006, -0.38029072,  0.31521076,  0.02151082],
        [-0.00936828, -0.25193396,  0.09320661, -0.24430814, -0.31924888,
         -0.5837658 ,  0.11786742,  0.61163   , -0.4844167 ,  0.14803399],
        [ 0.08637422, -0.4133725 , -0.63585305, -0.40947035,  0.19987728,
          0.33251715, -0.00949872, -0.20081694,  0.7561279 ,  0.47479796]],
       dtype=float32)>,
 <tf.Variable 'outer/dense/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense/kernel:0' shape=(5, 3) dtype=float32, numpy=
 array([[-0.85065323, -0.0477919 ,  0.69770926],
        [-0.27828598,  0.55636925, -0.6969459 ],
        [ 0.51618105, -0.63780004, -0.6297898 ],
        [-0.21282446,  0.7884807 , -0.2784927 ],
        [ 0.1040144 ,  0.7652603 ,  0.56572074]], dtype=float32)>,
 <tf.Variable 'outer/dense_1/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense_1/kernel:0' shape=(5, 4) dtype=float32, numpy=
 array([[ 0.47555745, -0.33051252, -0.01657987,  0.31172693],
        [-0.13664114,  0.05349594, -0.60820794, -0.67012715],
        [ 0.4469068 , -0.7925951 , -0.6421313 ,  0.5783521 ],
        [-0.15768999,  0.21155965,  0.08365291, -0.1433317 ],
        [-0.6506794 , -0.15860069,  0.66006434,  0.7738956 ]],
       dtype=float32)>]

정규화 손실도 마찬가지입니다.

layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.052549247>]

NestedLayer가 Keras가 아닌 tf.Module인 경우 변수는 계속 추적되지만 정규화 손실은 자동으로 추적되지 않으므로 명시적으로 따로따로 추적해야 할 수 있습니다.

변수 이름 안내

명시적 tf.Variable 호출 및 Keras 레이어는 get_variablevariable_scopes의 조합에서 사용하는 것과는 다른 레이어 이름/변수 이름 자동 생성 메커니즘을 사용합니다. shim은 TF1.x 그래프에서 TF2 즉시 실행 및 tf.function으로 이동하는 경우에도 get_variable으로 생성한 변수 이름이 일치되게 만들지만, tf.Variable 호출과 메서드 데코레이터 내 임베딩한 Keras 레이어를 대상으로 생성한 변수 이름과의 일치 여부는 보장할 수 없습니다. 여러 변수가 TF2 즉시 실행 및 tf.function에서 동일한 이름을 공유할 수도 있습니다.

이 가이드의 뒷부분에서 정확성 검증 및 TF1.x 체크포인트 매핑에 대한 섹션을 진행할 때 특히 주의해야 합니다.

데코레이팅된 메서드에서 tf.compat.v1.make_template 사용하기

tf.compat.v1.make_template을 사용하는 대신 TF2에서 레이어가 더 얇은 tf.compat.v1.keras.utils.track_tf1_style_variables를 직접 사용하는 것을 권장합니다..

이미 tf.compat.v1.make_template에 의존하고 있던 이전 TF1.x 코드의 내용은 이 섹션의 안내를 따르세요.

tf.compat.v1.make_templateget_variable을 사용하는 코드를 래핑하므로 track_tf1_style_variables 데코레이터를 사용하면 레이어 호출에서 이러한 템플릿을 사용하고 가중치 및 정규화 손실을 추적할 수 있습니다.

단, make_template을 한 번만 호출한 다음 각 레이어 호출에서 동일한 템플릿을 재사용해야 합니다. 그렇지 않으면 새 변수 세트로 레이어를 호출할 때마다 새 템플릿이 생성됩니다.

예제는 다음과 같습니다.

class CompatV1TemplateScaleByY(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    def my_op(x, scalar_name):
      var1 = tf.compat.v1.get_variable(scalar_name,
                            shape=[],
                            regularizer=tf.compat.v1.keras.regularizers.L2(),
                            initializer=tf.compat.v1.constant_initializer(1.5))
      return x * var1
    self.scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('layer'):
      # Using a scope ensures the `scale_by_y` name will not be incremented
      # for each instantiation of the layer.
      return self.scale_by_y(inputs)

layer = CompatV1TemplateScaleByY()

out = layer(tf.ones(shape=(2, 3)))
print("weights:", layer.weights)
print("regularization loss:", layer.losses)
print("output:", out)
weights: [<tf.Variable 'layer/scale_by_y/y:0' shape=() dtype=float32, numpy=1.5>]
regularization loss: [<tf.Tensor: shape=(), dtype=float32, numpy=0.022499999>]
output: tf.Tensor(
[[1.5 1.5 1.5]
 [1.5 1.5 1.5]], shape=(2, 3), dtype=float32)

경고: shim 데코레이터의 변수 및 정규화 손실 추적 메커니즘이 손상될 수 있으므로 여러 레이어 인스턴스에서 동일한 make_template 생성 템플릿을 공유하지 않는 것이 좋습니다. 또한 여러 레이어 인스턴스 내에서 동일한 make_template 이름을 사용하려면 variable_scope 내에서 생성한 템플릿을 중첩해서 사용해야 합니다. 그렇지 않으면 템플릿의 variable_scope에 대해 생성한 이름이 레이어의 새 인스턴스마다 증가합니다. 이로 인해 예상치 못한 방식으로 가중치 이름이 변경될 수 있습니다.

네이티브 TF2로 증분 마이그레이션

앞서 언급했듯이 track_tf1_style_variables를 사용하면 TF2 스타일 객체 지향 tf.Variable/tf.keras.layers.Layer/tf.Module 사용을 동일한 데코레이션 모듈/레이어 내부에서 레거시 tf.compat.v1.get_variable/tf.compat.v1.layers 스타일 사용과 혼합할 수 있습니다.

즉, TF1.x 모델이 TF2와 완전히 호환되도록 만든 후 네이티브(tf.compat.v1이 아닌) TF2 API를 사용하여 모든 새 모델 구성 요소를 작성하고 이전 코드를 상호 운용할 수 있습니다.

다만, 이전 모델 구성 요소를 계속 수정하는 경우 레거시 스타일 tf.compat.v1 사용을 새로 작성한 TF2 코드에 권장되는 순수 네이티브 객체 지향 API로 점진적으로 전환하도록 선택할 수도 있습니다.

Keras 레이어/모델을 데코레이팅하는 경우에는 self.add_weight 호출로, Keras 객체 또는 tf.Module을 데코레이팅하는 경우 tf.Variable 호출로 tf.compat.v1.get_variable 사용을 교체할 수 있습니다.

함수형 스타일 및 객체 지향 tf.compat.v1.layers는 일반적으로 인수를 변경하지 않아도 동등한 tf.keras.layers 레이어로 교체할 수 있습니다.

또한 track_tf1_style_variables를 사용할 수 있는 순수 네이티브 API로 점진적으로 이동하는 동안 모델이나 일반 패턴의 청크 부분을 개별 레이어/모듈로 변경하는 방안을 고려할 수도 있습니다.

Slim 및 contrib.layers에 대한 노트

다수의 기존 TF 1.x 코드는 TF 1.x와 함께 tf.contrib.layers로 패키징된 Slim 라이브러리를 사용합니다. Slim을 사용하여 코드를 네이티브 TF 2로 변환하는 작업은 v1.layers를 변환하는 작업보다 더 복잡합니다. 사실 Slim 코드를 먼저 v1.layers로 변환한 다음 Keras로 변환하는 것이 합리적일 수 있습니다. 다음은 Slim 코드를 변환하는 몇 가지 일반적인 가이드입니다.

  • 모든 인수가 명시적인지 확인합니다. 가능한 경우 arg_scopes를 제거합니다. 계속 사용해야 하는 경우 normalizer_fnactivation_fn을 자체 레이어로 분할합니다.
  • 분리할 수 있는 전환 레이어는 하나 이상의 다른 Keras 레이어(깊이별, 포인트별 및 분리 가능한 Keras 레이어)에 매핑됩니다.
  • Slim과 v1.layers는 인수 이름과 기본값이 다릅니다.
  • 일부 인수는 다른 행렬을 사용합니다.

체크포인트 호환성을 무시하고 네이티브 TF2로 마이그레이션하기

다음 코드 샘플은 체크포인트 호환성을 고려하지 않고 모델을 순수 네이티브 API로 점진적으로 이동하는 방법을 보여줍니다.

class CompatModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

그런 다음 compat.v1 API를 해당하는 네이티브 객체 지향 API로 부분적으로 교체합니다. 컨볼루션 레이어를 레이어 생성자에서 생성한 Keras 객체로 전환하여 시작합니다.

class PartiallyMigratedModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

v1.keras.utils.DeterministicRandomTestTool 클래스를 사용하여 이 증분을 변경해도 모델이 이전과 동일하게 작동하는지 확인합니다.

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = CompatModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  original_output = layer(inputs)

  # Grab the regularization loss as well
  original_regularization_loss = tf.math.add_n(layer.losses)

print(original_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_46200/355611412.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_46200/355611412.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_46200/355611412.py:14: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_46200/355611412.py:15: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = PartiallyMigratedModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_46200/3237389364.py:14: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_46200/3237389364.py:15: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_46200/3237389364.py:16: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

이제 모든 개별 compat.v1.layers를 네이티브 Keras 레이어로 교체했습니다.

class NearlyFullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = self.flatten_layer(out)
      out = self.dense_layer(out)
      return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = NearlyFullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

마지막으로, 남아 있는(더 이상 필요하지 않은) variable_scope 사용과 track_tf1_style_variables 데코레이터를 모두 제거합니다.

이제 완전히 네이티브 API를 사용하는 모델 버전만 남았습니다.

class FullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  def call(self, inputs):
    out = self.conv_layer(inputs)
    out = self.flatten_layer(out)
    out = self.dense_layer(out)
    return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = FullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

네이티브 TF2로 마이그레이션하는 동안 체크포인트 호환성 유지하기

위의 네이티브 TF2 API로의 마이그레이션 프로세스는 변수 이름(Keras API가 매우 다른 가중치 이름을 생성하기 때문에)과 모델의 다른 가중치를 가리키는 객체 지향 경로를 모두 변경했습니다. 이러한 변경의 영향으로 기존 TF1 스타일 이름 기반 체크포인트 또는 TF2 스타일 객체 지향 체크포인트가 모두 망가질 수 있습니다.

그러나 경우에 따라 TF1.x 체크포인트 가이드 재사용하기에 자세히 설명된 것과 같은 접근 방식을 사용하여 원래 이름 기반 체크포인트를 사용하고 새 이름에 해당하는 변수 매핑을 찾을 수 있습니다.

이를 실현하는 몇 가지 팁은 다음과 같습니다.

  • 여전히 변수에는 모두 설정할 수 있는 name 인수가 있습니다.
  • 또한 Keras 모델은 변수의 접두사로 설정하는 name 인수를 사용합니다.
  • v1.name_scope 함수를 변수 이름의 접두어를 지정하는데 사용할 수 있습니다. 이 함수는 tf.variable_scope와는 매우 다릅니다. 이름에만 영향을 미치며 변수를 추적하거나 재사용을 관장하지 않습니다.

위의 포인트터를 기반으로 다음 코드 샘플은 체크포인트를 업데이트하는 동시에 코드에 적용하여 모델의 일부를 점진적으로 업데이트할 수 있는 워크플로를 보여줍니다.

참고: Keras 레이어의 변수 이름 지정은 복잡하기 때문에 일부 사용 사례에서 작동하지 않을 수 있습니다.

  1. 함수형 스타일 tf.compat.v1.layers를 객체 지향 버전으로 전환하는 작업으로 시작합니다.
class FunctionalStyleCompatModel(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 4, 4,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = FunctionalStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_46200/1716504801.py:6: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_46200/1716504801.py:9: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_46200/1716504801.py:12: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/bias:0',
 'model/conv2d/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. 그 다음에는 compat.v1.layer 객체와 compat.v1.get_variable로 생성한 모든 변수를 track_tf1_style_variables으로 메서드를 데코레이팅한 {tf.keras.layers.Layer/tf.Module 객체의 속성으로 할당합니다(모든 객체 지향 TF2 스타일 체크포인트는 이제 변수 이름을 사용한 경로와 새로운 객체 지향 경로를 모두 저장).
class OOStyleCompatModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.compat.v1.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.compat.v1.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = OOStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_46200/1693875107.py:17: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/kernel:0',
 'model/conv2d/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. 이 시점에서 로드된 체크포인트를 다시 저장하여 변수 이름(compat.v1.layers의 경우) 또는 객체 지향 개체 그래프로 경로를 모두 저장합니다.
weights = {v.name: v for v in layer.weights}
assert weights['model/conv2d/kernel:0'] is layer.conv_1.kernel
assert weights['model/conv2d_1/bias:0'] is layer.conv_2.bias
  1. 이제 최근에 저장한 체크포인트를 계속 로드하는 동안에도 네이티브 Keras 레이어에서 객체 지향 compat.v1.layers를 교체할 수 있습니다. 교체된 레이어의 자동 생성된 variable_scopes를 계속 기록하여 남은 compat.v1.layers의 변수 이름을 보존해야 합니다. 이렇게 전환된 레이어/변수는 이제 변수 이름 경로 대신 체크포인트 변수의 객체 속성 경로만 사용합니다.

일반적으로 속성에 연결된 변수에서 compat.v1.get_variable의 사용을 다음과 같이 교체할 수 있습니다.

  • tf.Variable을 사용하도록 전환, 또는
  • tf.keras.layers.Layer.add_weight를 사용하여 업데이트. 한 번에 모든 레이어를 전환하지 않으면 name 인수가 누락된 남은 compat.v1.layers의 자동 생성된 레이어/변수 이름이 변경될 수 있습니다. 이 경우 제거된 compat.v1.layer의 생성된 범위 이름에 해당하는 variable_scope를 수동으로 열고 닫으며 나머지 compat.v1.layers의 변수 이름을 동일하게 유지해야 합니다. 그렇지 않으면 기존 체크포인트의 경로가 충돌하고 체크포인트 로드가 올바르지 않게 작동할 수 있습니다.
def record_scope(scope_name):
  """Record a variable_scope to make sure future ones get incremented."""
  with tf.compat.v1.variable_scope(scope_name):
    pass

class PartiallyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      record_scope('conv2d') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = self.conv_2(out)
      record_scope('conv2d_1') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = PartiallyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_46200/3143218429.py:24: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['partially_native_keras_layers_model/model/conv2d_13/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_13/bias:0',
 'partially_native_keras_layers_model/model/conv2d_14/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_14/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']

변수를 구성한 후 이 단계에서 체크포인트를 저장하면 현재 사용할 수 있는 객체 경로만 포함됩니다.

남은 compat.v1.layers의 자동 생성된 가중치 이름을 보존하려면 제거된 compat.v1.layers의 범위를 기록해야 합니다.

weights = set(v.name for v in layer.weights)
assert 'model/conv2d_2/kernel:0' in weights
assert 'model/conv2d_2/bias:0' in weights
  1. 모델의 모든 compat.v1.layerscompat.v1.get_variable을 완전한 네이티브 항목으로 교체할 때까지 위의 단계를 반복합니다.
class FullyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")
    self.conv_3 = tf.keras.layers.Conv2D(
          5, 5,
          kernel_regularizer="l2")


  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = self.conv_3(out)
      return out

layer = FullyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
['fully_native_keras_layers_model/model/conv2d_16/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_16/bias:0',
 'fully_native_keras_layers_model/model/conv2d_17/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_17/bias:0',
 'fully_native_keras_layers_model/model/conv2d_18/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_18/bias:0']

새로 업데이트한 체크포인트가 계속 예상대로 작동하는지 테스트해야 합니다. 마이그레이션한 코드가 올바르게 실행되도록 이 프로세스의 모든 증분 단계에서 수치 정확성 검증 가이드에 설명된 기술을 적용합니다.

모델링 shim에서 다루지 않는 TF1.x에서 TF2 작업 변경 처리하기

이 가이드에 설명된 모델링 shim은 get_variable, tf.compat.v1.layers, variable_scope 의미 체계로 생성한 변수, 레이어 및 정규화 손실이 즉시 실행 및 tf.function을 사용할 때 컬렉션에 의존하지 않고 이전처럼 계속 작동하도록 할 수 있습니다.

여기에는 모델 순방향 전달이 의존할 수 있는 모든 TF1.x에 특화된 의미 체계가 포함되지 않습니다. 경우에 따라 shim이 자체적으로 TF2에서 실행되는 모델 순방향 전달을 가져오기에 충분하지 않을 수 있습니다. TF1.x와 TF2의 동작 차이점에 대해 자세히 알아보려면 TF1.x과 TF2 동작 차이 가이드를 읽어보세요.