カスタムレイヤー

コレクションでコンテンツを整理 必要に応じて、コンテンツの保存と分類を行います。

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

ニューラルネットワークの構築には、高レベルの API である tf.keras を使うことを推奨しますが、TensorFlow API のほとんどは、eager execution でも使用可能です。

import tensorflow as tf
2022-08-08 21:09:22.219990: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-08 21:09:22.996476: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:09:22.996740: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:09:22.996753: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
print(tf.config.list_physical_devices('GPU'))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]

レイヤー:有用な演算の共通セット

機械学習モデルのコーディングでは、個々の演算やひとつひとつの変数を操作するよりは、より高度に抽象化された演算を実行することが望ましい場合が多くあります。

多くの機械学習モデルは、比較的単純なレイヤーの組み合わせや積み重ねによって表現可能です。TensorFlow では、多くの一般的なレイヤーのセットに加えて、アプリケーションに特有なレイヤーを最初から記述したり、既存のレイヤーの組み合わせによって作るための、簡単な方法が提供されています。

TensorFlow では、tf.keras パッケージに Keras API のすべてが含まれています。Keras のレイヤーは、独自のモデルを構築する際に大変便利です。

# In the tf.keras.layers package, layers are objects. To construct a layer,
# simply construct the object. Most layers take as a first argument the number
# of output dimensions / channels.
layer = tf.keras.layers.Dense(100)
# The number of input dimensions is often unnecessary, as it can be inferred
# the first time the layer is used, but it can be provided if you want to
# specify it manually, which is useful in some complex models.
layer = tf.keras.layers.Dense(10, input_shape=(None, 5))

既存のレイヤーのすべての一覧は、ドキュメントを参照してください。Dense(全結合レイヤー)、Conv2D、LSTM、BatchNormalization、Dropoutなどのたくさんのレイヤーが含まれています。

# To use a layer, simply call it.
layer(tf.zeros([10, 5]))
<tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>
# Layers have many useful methods. For example, you can inspect all variables
# in a layer using `layer.variables` and trainable variables using
# `layer.trainable_variables`. In this case a fully-connected layer
# will have variables for weights and biases.
layer.variables
[<tf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[ 0.17876357, -0.5827236 ,  0.49352247, -0.01141733, -0.59646183,
          0.11011076, -0.3342917 ,  0.48972362, -0.12729347,  0.11885774],
        [-0.61036223,  0.5851261 ,  0.26638865, -0.6307735 , -0.20873049,
          0.27125835, -0.5119239 ,  0.6118571 , -0.27632383, -0.2322486 ],
        [ 0.05983281,  0.60423213,  0.46980757,  0.12847233, -0.127895  ,
          0.34509546,  0.19301343,  0.31446487, -0.2130616 ,  0.27265364],
        [ 0.37125975,  0.488661  ,  0.0010255 , -0.31465158, -0.30079782,
          0.3839764 , -0.24688557, -0.2938036 ,  0.10192806, -0.3865546 ],
        [ 0.21430486, -0.555739  , -0.6182138 , -0.22973299,  0.02168632,
          0.6148847 , -0.29062256,  0.3071459 , -0.14586619, -0.06083453]],
       dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]
# The variables are also accessible through nice accessors
layer.kernel, layer.bias
(<tf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[ 0.17876357, -0.5827236 ,  0.49352247, -0.01141733, -0.59646183,
          0.11011076, -0.3342917 ,  0.48972362, -0.12729347,  0.11885774],
        [-0.61036223,  0.5851261 ,  0.26638865, -0.6307735 , -0.20873049,
          0.27125835, -0.5119239 ,  0.6118571 , -0.27632383, -0.2322486 ],
        [ 0.05983281,  0.60423213,  0.46980757,  0.12847233, -0.127895  ,
          0.34509546,  0.19301343,  0.31446487, -0.2130616 ,  0.27265364],
        [ 0.37125975,  0.488661  ,  0.0010255 , -0.31465158, -0.30079782,
          0.3839764 , -0.24688557, -0.2938036 ,  0.10192806, -0.3865546 ],
        [ 0.21430486, -0.555739  , -0.6182138 , -0.22973299,  0.02168632,
          0.6148847 , -0.29062256,  0.3071459 , -0.14586619, -0.06083453]],
       dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>)

カスタムレイヤーの実装

独自のレイヤーを実装する最良の方法は、tf.keras.Layer クラスを拡張し、下記のメソッドを実装することです。

  1. __init__:入力に依存しないすべての初期化を実行できます
  2. build:入力テンソルの形状を知っている場合、残りの初期化を行うことができます。
  3. call:フォワード計算を行います。

build が呼ばれるまで変数の生成を待つ必要はなく、__init__ で作成できることに注意してください。しかしながら、build で変数を生成することの優位な点は、レイヤーが演算する入力の形状に基づいて、後から定義できる点です。これに対して、__init__ で変数を生成するには、必要な形状を明示的に指定する必要があります。

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

layer = MyDenseLayer(10)
_ = layer(tf.zeros([10, 5])) # Calling the layer `.builds` it.
print([var.name for var in layer.trainable_variables])
['my_dense_layer/kernel:0']

できるだけ標準のレイヤーを使ったほうが、概してコードは読みやすく保守しやすくなります。コードを読む人は標準的なレイヤーの振る舞いに慣れているからです。tf.keras.layers にはないレイヤーを使うことを希望する場合には、github の課題を作成するか、プルリクエスト (推薦) を送ってください。

モデル:レイヤーの組み合わせ

機械学習では、多くのレイヤーに類するものが、既存のレイヤーを組み合わせることで実装されています。例えば、ResNet の残差ブロックは、畳込み、バッチ正規化とショートカットの組み合わせです。レイヤーは他のレイヤー内にネストできます。

通常、Model.fitModel.evaluate、および、Model.save などのモデルメソッドが必要な場合は、keras.Model から継承します。

keras.Modelにより提供されるもう 1 つの機能(keras.layers.Layerの代わりに)として、変数の追跡に加えて、keras.Modelもその内部レイヤーを追跡し、検査を容易にします。

たとえば、ResNet ブロックは次のとおりです。

class ResnetIdentityBlock(tf.keras.Model):
  def __init__(self, kernel_size, filters):
    super(ResnetIdentityBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
    self.bn2a = tf.keras.layers.BatchNormalization()

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
    self.bn2b = tf.keras.layers.BatchNormalization()

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
    self.bn2c = tf.keras.layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    x += input_tensor
    return tf.nn.relu(x)


block = ResnetIdentityBlock(1, [1, 2, 3])
_ = block(tf.zeros([1, 2, 3, 3]))
block.layers
[<keras.layers.convolutional.conv2d.Conv2D at 0x7f2379d7cee0>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f2484478700>,
 <keras.layers.convolutional.conv2d.Conv2D at 0x7f230c075520>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f2484478640>,
 <keras.layers.convolutional.conv2d.Conv2D at 0x7f230c075730>,
 <keras.layers.normalization.batch_normalization.BatchNormalization at 0x7f2379d41bb0>]
len(block.variables)
18
block.summary()
Model: ""
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             multiple                  4         
                                                                 
 batch_normalization (BatchN  multiple                 4         
 ormalization)                                                   
                                                                 
 conv2d_1 (Conv2D)           multiple                  4         
                                                                 
 batch_normalization_1 (Batc  multiple                 8         
 hNormalization)                                                 
                                                                 
 conv2d_2 (Conv2D)           multiple                  9         
                                                                 
 batch_normalization_2 (Batc  multiple                 12        
 hNormalization)                                                 
                                                                 
=================================================================
Total params: 41
Trainable params: 29
Non-trainable params: 12
_________________________________________________________________

しかし、ほとんどの場合には、モデルはレイヤーを次々に呼び出すことで構成されます。tf.keras.Sequential クラスを使うことで、これをかなり短いコードで実装できます。

my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1),
                                                    input_shape=(
                                                        None, None, 3)),
                             tf.keras.layers.BatchNormalization(),
                             tf.keras.layers.Conv2D(2, 1,
                                                    padding='same'),
                             tf.keras.layers.BatchNormalization(),
                             tf.keras.layers.Conv2D(3, (1, 1)),
                             tf.keras.layers.BatchNormalization()])
my_seq(tf.zeros([1, 2, 3, 3]))
<tf.Tensor: shape=(1, 2, 3, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype=float32)>
my_seq.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_3 (Conv2D)           (None, None, None, 1)     4         
                                                                 
 batch_normalization_3 (Batc  (None, None, None, 1)    4         
 hNormalization)                                                 
                                                                 
 conv2d_4 (Conv2D)           (None, None, None, 2)     4         
                                                                 
 batch_normalization_4 (Batc  (None, None, None, 2)    8         
 hNormalization)                                                 
                                                                 
 conv2d_5 (Conv2D)           (None, None, None, 3)     9         
                                                                 
 batch_normalization_5 (Batc  (None, None, None, 3)    12        
 hNormalization)                                                 
                                                                 
=================================================================
Total params: 41
Trainable params: 29
Non-trainable params: 12
_________________________________________________________________

次のステップ

それでは、前のノートブックに戻り、レイヤーとモデルを使って、線形回帰の例をより構造化された形で実装してみてください。