사용자 정의 층

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

신경망을 구축하기 위해서 고수준 API인 tf.keras를 사용하길 권합니다. 대부분의 텐서플로 API는 즉시 실행(eager execution)과 함께 사용할 수 있습니다.

import tensorflow as tf
2022-12-14 22:21:57.768732: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:21:57.768835: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:21:57.768845: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
print(tf.config.list_physical_devices('GPU'))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]

층: 유용한 연산자 집합

머신러닝을 위한 코드를 작성하는 대부분의 경우에 개별적인 연산과 변수를 조작하는 것보다는 높은 수준의 추상화 도구를 사용할 것입니다.

많은 머신러닝 모델은 비교적 단순한 레이어를 조합하고 쌓아서 표현 가능합니다. 또한 TensorFlow는 여러 표준 레이어를 제공하므로 사용자 고유의 애플리케이션에 특화된 레이어를 처음부터 작성하거나, 기존 레이어의 조합으로 쉽게 만들 수 있습니다.

텐서플로는 케라스의 모든 API를 tf.keras 패키지에 포함하고 있습니다. 케라스 층은 모델을 구축하는데 매우 유용합니다.

# In the tf.keras.layers package, layers are objects. To construct a layer,
# simply construct the object. Most layers take as a first argument the number
# of output dimensions / channels.
layer = tf.keras.layers.Dense(100)
# The number of input dimensions is often unnecessary, as it can be inferred
# the first time the layer is used, but it can be provided if you want to
# specify it manually, which is useful in some complex models.
layer = tf.keras.layers.Dense(10, input_shape=(None, 5))

미리 구성되어있는 층은 다음 문서에서 확인할 수 있습니다. Dense(완전 연결 층), Conv2D, LSTM, BatchNormalization, Dropout, 등을 포함하고 있습니다.

# To use a layer, simply call it.
layer(tf.zeros([10, 5]))
<tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>
# Layers have many useful methods. For example, you can inspect all variables
# in a layer using `layer.variables` and trainable variables using
# `layer.trainable_variables`. In this case a fully-connected layer
# will have variables for weights and biases.
layer.variables
[<tf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[ 0.0378511 ,  0.08353084,  0.05098164, -0.17555702, -0.37695765,
          0.32880777, -0.20828038, -0.22777048,  0.39718527, -0.17805666],
        [-0.04605299, -0.6141876 , -0.4751832 , -0.25455546, -0.07558787,
          0.43408698,  0.0988934 , -0.2672614 , -0.6142836 , -0.213664  ],
        [-0.59526837, -0.02855283, -0.39548466,  0.22484756,  0.57591397,
         -0.1012519 , -0.5177439 ,  0.5687112 ,  0.15866697, -0.61348975],
        [-0.25201526, -0.4648751 , -0.3889778 ,  0.45341247,  0.01259494,
         -0.074085  , -0.42965168,  0.21058083, -0.5529883 ,  0.06371927],
        [-0.4116999 ,  0.17359102,  0.1640541 , -0.35296482,  0.07844955,
         -0.06330568, -0.4817894 , -0.45749474, -0.5498795 ,  0.46872085]],
       dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]
# The variables are also accessible through nice accessors
layer.kernel, layer.bias
(<tf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[ 0.0378511 ,  0.08353084,  0.05098164, -0.17555702, -0.37695765,
          0.32880777, -0.20828038, -0.22777048,  0.39718527, -0.17805666],
        [-0.04605299, -0.6141876 , -0.4751832 , -0.25455546, -0.07558787,
          0.43408698,  0.0988934 , -0.2672614 , -0.6142836 , -0.213664  ],
        [-0.59526837, -0.02855283, -0.39548466,  0.22484756,  0.57591397,
         -0.1012519 , -0.5177439 ,  0.5687112 ,  0.15866697, -0.61348975],
        [-0.25201526, -0.4648751 , -0.3889778 ,  0.45341247,  0.01259494,
         -0.074085  , -0.42965168,  0.21058083, -0.5529883 ,  0.06371927],
        [-0.4116999 ,  0.17359102,  0.1640541 , -0.35296482,  0.07844955,
         -0.06330568, -0.4817894 , -0.45749474, -0.5498795 ,  0.46872085]],
       dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>)

사용자 정의 층 구현

사용자 정의 층을 구현하는 가장 좋은 방법은 tf.keras.Layer 클래스를 상속하고 다음과 같이 구현하는 것입니다.

  1. __init__: 모든 입력 독립적 초기화를 수행할 수 있습니다.
  2. build: 입력 텐서의 형상을 알고 나머지 초기화 작업을 수행할 수 있습니다.
  3. call: 순방향 계산을 수행합니다.

변수를 생성하기 위해 build가 호출되길 기다릴 필요가 없다는 것에 주목하세요. 또한 변수를 __init__에 생성할 수도 있습니다. 그러나 build에 변수를 생성하는 유리한 점은 층이 작동할 입력의 크기를 기준으로 나중에 변수를 만들 수 있다는 것입니다. 반면에, __init__에 변수를 생성하는 것은 변수 생성에 필요한 크기가 명시적으로 지정되어야 함을 의미합니다.

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

layer = MyDenseLayer(10)
_ = layer(tf.zeros([10, 5])) # Calling the layer `.builds` it.
print([var.name for var in layer.trainable_variables])
['my_dense_layer/kernel:0']

코드를 읽는 사람이 표준형 층의 동작을 잘 알고 있을 것이므로, 가능한 경우 표준형 층을 사용하는것이 전체 코드를 읽고 유지하는데 더 쉽습니다. 만약 tf.keras.layers 에 없는 층을 사용하기 원하면 깃허브에 이슈화하거나, 풀 리퀘스트(pull request)를 보내세요.

모델: 층 구성

머신러닝 모델에서 대부분의 재미있는 많은 것들은 기존의 층을 조합하여 구현됩니다. 예를 들어, 레즈넷(resnet)의 각 잔여 블록(residual block)은 합성곱(convolution), 배치 정규화(batch normalization), 쇼트컷(shortcut) 등으로 구성되어 있습니다.

Model.fit, Model.evaluateModel.save와 같은 모델 메서드가 필요할 때 일반적으로 keras.Model에서 상속합니다(자세한 내용은 사용자 지정 Keras 레이어 및 모델 참조).

keras.Model(keras.layers.Layer가 아니라)에 의해 제공되는 또 다른 특성은 변수를 추적하는 외에 keras.Model이 내부 레이어도 추적하여 검사하기 더 쉽게 해준다는 점입니다.

예를 들어 다음은 ResNet 블록입니다.

class ResnetIdentityBlock(tf.keras.Model):
  def __init__(self, kernel_size, filters):
    super(ResnetIdentityBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
    self.bn2a = tf.keras.layers.BatchNormalization()

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
    self.bn2b = tf.keras.layers.BatchNormalization()

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
    self.bn2c = tf.keras.layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    x += input_tensor
    return tf.nn.relu(x)


block = ResnetIdentityBlock(1, [1, 2, 3])
_ = block(tf.zeros([1, 2, 3, 3]))
block.layers
[<keras.layers.convolutional.conv2d.Conv2D at 0x7f75dc5f7a00>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f76e8377d00>,
 <keras.layers.convolutional.conv2d.Conv2D at 0x7f75d00371c0>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f75d0037550>,
 <keras.layers.convolutional.conv2d.Conv2D at 0x7f75500941f0>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f7550094220>]
len(block.variables)
18
block.summary()
Model: ""
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             multiple                  4         
                                                                 
 batch_normalization (BatchN  multiple                 4         
 ormalization)                                                   
                                                                 
 conv2d_1 (Conv2D)           multiple                  4         
                                                                 
 batch_normalization_1 (Batc  multiple                 8         
 hNormalization)                                                 
                                                                 
 conv2d_2 (Conv2D)           multiple                  9         
                                                                 
 batch_normalization_2 (Batc  multiple                 12        
 hNormalization)                                                 
                                                                 
=================================================================
Total params: 41
Trainable params: 29
Non-trainable params: 12
_________________________________________________________________

그러나 대부분의 경우에, 많은 층으로 구성된 모델은 단순하게 순서대로 층을 하나씩 호출합니다. 이는 tf.keras.Sequential 사용하여 간단한 코드로 구현 가능합니다.

my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1),
                                                    input_shape=(
                                                        None, None, 3)),
                             tf.keras.layers.BatchNormalization(),
                             tf.keras.layers.Conv2D(2, 1,
                                                    padding='same'),
                             tf.keras.layers.BatchNormalization(),
                             tf.keras.layers.Conv2D(3, (1, 1)),
                             tf.keras.layers.BatchNormalization()])
my_seq(tf.zeros([1, 2, 3, 3]))
<tf.Tensor: shape=(1, 2, 3, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype=float32)>
my_seq.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_3 (Conv2D)           (None, None, None, 1)     4         
                                                                 
 batch_normalization_3 (Batc  (None, None, None, 1)    4         
 hNormalization)                                                 
                                                                 
 conv2d_4 (Conv2D)           (None, None, None, 2)     4         
                                                                 
 batch_normalization_4 (Batc  (None, None, None, 2)    8         
 hNormalization)                                                 
                                                                 
 conv2d_5 (Conv2D)           (None, None, None, 3)     9         
                                                                 
 batch_normalization_5 (Batc  (None, None, None, 3)    12        
 hNormalization)                                                 
                                                                 
=================================================================
Total params: 41
Trainable params: 29
Non-trainable params: 12
_________________________________________________________________

다음 단계

이제 이전 노트북으로 돌아가서 선형 회귀 예제에 층과 모델을 사용하여 좀 더 나은 구조를 적용할 수 있습니다.