Graph regularization for document classification using natural graphs


Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
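
Concretely, graph regularization augments the standard supervised loss with a term that penalizes differences between a sample's output and those of its graph neighbors. A rough form of the objective is shown below; the notation here is assumed for illustration rather than quoted verbatim from the paper:

$$\mathcal{L} \;=\; \sum_{i} \ell\big(y_i, f(x_i)\big) \;+\; \alpha \sum_{(i,j) \in E} w_{ij}\, d\big(f(x_i), f(x_j)\big)$$

where $\ell$ is the supervised loss, $\alpha$ is the graph regularization multiplier, $E$ is the set of graph edges, $w_{ij}$ is the edge weight between samples $i$ and $j$, and $d$ is a distance metric such as squared L2.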

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows (a minimal sketch of steps 2-4 appears right after this list):

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras Sequential, Functional, or Subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
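
The sketch below illustrates steps 2-4 with a made-up input shape and class count; it is not the Cora model itself, whose full, runnable version appears later in this tutorial.

import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: any Keras model can serve as the base model (shapes here are
# illustrative placeholders).
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# Step 3: wrap the base model so that a graph regularization loss is added
# to the training objective.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1,
                                                     multiplier=0.1)
graph_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile, then train/evaluate on data that includes neighbor features.
graph_model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])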

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.4.1
Eager mode:  True
GPU is available

The Cora dataset

The Cora dataset is a citation graph whose nodes represent machine learning papers and whose edges represent citations between pairs of papers. The task is document classification: the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

The graph

The original graph is directed. However, for the purpose of this example, we consider its undirected version: if paper A cites paper B, we also treat paper B as having cited paper A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
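
The following is a small sketch of this symmetrization using a hypothetical edge list; the actual step is performed internally by the preprocessing script used below (see its "Making all edges bi-directional..." log line).

# Hypothetical edge list; a pair (src, dst) means "src cites dst".
edges = {('A', 'B'), ('B', 'C')}
# Treat citation as symmetric by adding the reverse of every edge.
undirected_edges = edges | {(dst, src) for (src, dst) in edges}
print(sorted(undirected_edges))  # [('A', 'B'), ('B', 'A'), ('B', 'C'), ('C', 'B')]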

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense multi-hot bag-of-words (BoW) representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0 or 1, indicating whether word 'i' in the vocabulary exists in the given paper. (An illustrative encoding sketch follows this list.)

  2. Label: A single integer representing the class ID (category) of the paper.
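
To make the 'words' feature concrete, here is a minimal sketch of multi-hot encoding using a hypothetical 5-word vocabulary; Cora's actual vocabulary has 1433 words.

import numpy as np

vocab = ['bayesian', 'network', 'neural', 'genetic', 'markov']  # hypothetical
paper_words = {'neural', 'network'}
multi_hot = np.array([1 if w in paper_words else 0 for w in vocab])
print(multi_hot)  # [0 1 1 0 0]; position i is 1 iff vocabulary word i occurs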

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2021-02-12 22:29:54--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0.001s  

2021-02-12 22:29:54 (19.3 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2021-02-12 22:29:55.197371: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.
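
As a quick sanity check (not part of the original tutorial flow), the sketch below reads a single merged tf.train.Example from the training output and lists its feature names. Neighbor features follow the 'NL_nbr_<i>_...' naming convention introduced in the next section.

import tensorflow as tf

raw_record = next(iter(
    tf.data.TFRecordDataset(['/tmp/cora/train_merged_examples.tfr'])))
example = tf.train.Example.FromString(raw_record.numpy())
print(sorted(example.features.feature.keys()))
# e.g. ['NL_nbr_0_weight', 'NL_nbr_0_words', ..., 'label', 'words']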

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to include various hyperparameters and constants used for training and evaluation. Each of them is briefly described below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary, and all instances in the input have a dense multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: A list with the number of units in each fully connected layer of the neural network.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects -- one for training and one for testing.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the value of hparams.num_neighbors. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for the non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[1 3 3 6 3 3 2 0 3 2 6 3 2 1 2 3 2 3 1 1 2 0 5 2 2 2 1 2 2 2 0 2 0 1 1 6 6
 5 6 5 6 5 1 0 1 4 1 5 1 3 3 0 6 1 1 2 6 5 0 3 6 4 2 6 2 3 2 3 2 0 3 1 2 2
 0 2 3 4 1 2 0 4 6 2 4 3 3 4 0 1 3 3 3 6 2 6 1 1 2 0 3 3 2 5 4 4 1 3 1 3 5
 3 3 5 2 6 2 3 5 3 0 3 1 6 1 1 3 3], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem: a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework -- Sequential, Functional, and Subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create the base model

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

Train the base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:595: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  [n for n in tensors.keys() if n not in ref_input_names])
17/17 [==============================] - 1s 11ms/step - loss: 1.9608 - accuracy: 0.1465
Epoch 2/100
17/17 [==============================] - 0s 11ms/step - loss: 1.8578 - accuracy: 0.2793
Epoch 3/100
17/17 [==============================] - 0s 11ms/step - loss: 1.7744 - accuracy: 0.3468
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6850 - accuracy: 0.3542
Epoch 5/100
17/17 [==============================] - 0s 11ms/step - loss: 1.5511 - accuracy: 0.4065
Epoch 6/100
17/17 [==============================] - 0s 11ms/step - loss: 1.3826 - accuracy: 0.5161
Epoch 7/100
17/17 [==============================] - 0s 11ms/step - loss: 1.2052 - accuracy: 0.5874
Epoch 8/100
17/17 [==============================] - 0s 11ms/step - loss: 1.0876 - accuracy: 0.6437
Epoch 9/100
17/17 [==============================] - 0s 10ms/step - loss: 0.9621 - accuracy: 0.6866
Epoch 10/100
17/17 [==============================] - 0s 11ms/step - loss: 0.8881 - accuracy: 0.7042
Epoch 11/100
17/17 [==============================] - 0s 11ms/step - loss: 0.8042 - accuracy: 0.7365
Epoch 12/100
17/17 [==============================] - 0s 11ms/step - loss: 0.7164 - accuracy: 0.7680
Epoch 13/100
17/17 [==============================] - 0s 11ms/step - loss: 0.6374 - accuracy: 0.8080
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5826 - accuracy: 0.8164
Epoch 15/100
17/17 [==============================] - 0s 12ms/step - loss: 0.5169 - accuracy: 0.8426
Epoch 16/100
17/17 [==============================] - 0s 11ms/step - loss: 0.5486 - accuracy: 0.8348
Epoch 17/100
17/17 [==============================] - 0s 11ms/step - loss: 0.4695 - accuracy: 0.8565
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4333 - accuracy: 0.8688
Epoch 19/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4156 - accuracy: 0.8735
Epoch 20/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3798 - accuracy: 0.8881
Epoch 21/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3469 - accuracy: 0.9021
Epoch 22/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3103 - accuracy: 0.9090
Epoch 23/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3284 - accuracy: 0.8891
Epoch 24/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2758 - accuracy: 0.9196
Epoch 25/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2780 - accuracy: 0.9124
Epoch 26/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2244 - accuracy: 0.9427
Epoch 27/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2555 - accuracy: 0.9215
Epoch 28/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2261 - accuracy: 0.9410
Epoch 29/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2545 - accuracy: 0.9228
Epoch 30/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2161 - accuracy: 0.9354
Epoch 31/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2065 - accuracy: 0.9445
Epoch 32/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2176 - accuracy: 0.9336
Epoch 33/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2013 - accuracy: 0.9421
Epoch 34/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1651 - accuracy: 0.9513
Epoch 35/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1848 - accuracy: 0.9514
Epoch 36/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1634 - accuracy: 0.9558
Epoch 37/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9598
Epoch 38/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1438 - accuracy: 0.9651
Epoch 39/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1602 - accuracy: 0.9569
Epoch 40/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1487 - accuracy: 0.9576
Epoch 41/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1324 - accuracy: 0.9742
Epoch 42/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1173 - accuracy: 0.9698
Epoch 43/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1148 - accuracy: 0.9690
Epoch 44/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1214 - accuracy: 0.9672
Epoch 45/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1289 - accuracy: 0.9645
Epoch 46/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1255 - accuracy: 0.9628
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1151 - accuracy: 0.9697
Epoch 48/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1153 - accuracy: 0.9672
Epoch 49/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1074 - accuracy: 0.9681
Epoch 50/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1201 - accuracy: 0.9616
Epoch 51/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1033 - accuracy: 0.9784
Epoch 52/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0972 - accuracy: 0.9701
Epoch 53/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1056 - accuracy: 0.9733
Epoch 54/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1073 - accuracy: 0.9707
Epoch 55/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0989 - accuracy: 0.9705
Epoch 56/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0868 - accuracy: 0.9787
Epoch 57/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0956 - accuracy: 0.9745
Epoch 58/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0833 - accuracy: 0.9805
Epoch 59/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0848 - accuracy: 0.9805
Epoch 60/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1015 - accuracy: 0.9743
Epoch 61/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0894 - accuracy: 0.9735
Epoch 62/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0755 - accuracy: 0.9780
Epoch 63/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0736 - accuracy: 0.9793
Epoch 64/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0867 - accuracy: 0.9751
Epoch 65/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0708 - accuracy: 0.9783
Epoch 66/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0712 - accuracy: 0.9784
Epoch 67/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0678 - accuracy: 0.9816
Epoch 68/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0697 - accuracy: 0.9771
Epoch 69/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0960 - accuracy: 0.9764
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0707 - accuracy: 0.9809
Epoch 71/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0818 - accuracy: 0.9771
Epoch 72/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0739 - accuracy: 0.9775
Epoch 73/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0710 - accuracy: 0.9796
Epoch 74/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0621 - accuracy: 0.9824
Epoch 75/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0567 - accuracy: 0.9881
Epoch 76/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0496 - accuracy: 0.9890
Epoch 77/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0609 - accuracy: 0.9837
Epoch 78/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0779 - accuracy: 0.9812
Epoch 79/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0591 - accuracy: 0.9837
Epoch 80/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0673 - accuracy: 0.9791
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0568 - accuracy: 0.9839
Epoch 82/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0568 - accuracy: 0.9830
Epoch 83/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0635 - accuracy: 0.9830
Epoch 84/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0570 - accuracy: 0.9846
Epoch 85/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0617 - accuracy: 0.9854
Epoch 86/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0624 - accuracy: 0.9831
Epoch 87/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0462 - accuracy: 0.9884
Epoch 88/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0520 - accuracy: 0.9884
Epoch 89/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0468 - accuracy: 0.9875
Epoch 90/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0600 - accuracy: 0.9806
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0507 - accuracy: 0.9823
Epoch 92/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0528 - accuracy: 0.9841
Epoch 93/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0483 - accuracy: 0.9865
Epoch 94/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0615 - accuracy: 0.9832
Epoch 95/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0470 - accuracy: 0.9856
Epoch 96/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0419 - accuracy: 0.9900
Epoch 97/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0407 - accuracy: 0.9942
Epoch 98/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0573 - accuracy: 0.9826
Epoch 99/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0465 - accuracy: 0.9877
Epoch 100/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0442 - accuracy: 0.9880
<tensorflow.python.keras.callbacks.History at 0x7f7860e3d048>

Evaluate the base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 9ms/step - loss: 1.2943 - accuracy: 0.7939


Eval accuracy for  Base MLP model :  0.7938517332077026
Eval loss for  Base MLP model :  1.2943289279937744

Train the MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:437: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 11ms/step - loss: 1.9542 - accuracy: 0.1519 - scaled_graph_loss: 7.4008e-04
Epoch 2/100
17/17 [==============================] - 0s 11ms/step - loss: 1.8780 - accuracy: 0.2452 - scaled_graph_loss: 8.6243e-04
Epoch 3/100
17/17 [==============================] - 0s 11ms/step - loss: 1.7961 - accuracy: 0.3197 - scaled_graph_loss: 0.0016
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6863 - accuracy: 0.3774 - scaled_graph_loss: 0.0032
Epoch 5/100
17/17 [==============================] - 0s 11ms/step - loss: 1.5712 - accuracy: 0.3973 - scaled_graph_loss: 0.0054
Epoch 6/100
17/17 [==============================] - 0s 11ms/step - loss: 1.4242 - accuracy: 0.4789 - scaled_graph_loss: 0.0087
Epoch 7/100
17/17 [==============================] - 0s 11ms/step - loss: 1.3093 - accuracy: 0.5452 - scaled_graph_loss: 0.0125
Epoch 8/100
17/17 [==============================] - 0s 11ms/step - loss: 1.1419 - accuracy: 0.6088 - scaled_graph_loss: 0.0169
Epoch 9/100
17/17 [==============================] - 0s 11ms/step - loss: 1.0283 - accuracy: 0.6588 - scaled_graph_loss: 0.0207
Epoch 10/100
17/17 [==============================] - 0s 11ms/step - loss: 0.9211 - accuracy: 0.7076 - scaled_graph_loss: 0.0243
Epoch 11/100
17/17 [==============================] - 0s 11ms/step - loss: 0.8022 - accuracy: 0.7699 - scaled_graph_loss: 0.0262
Epoch 12/100
17/17 [==============================] - 0s 11ms/step - loss: 0.7787 - accuracy: 0.7628 - scaled_graph_loss: 0.0284
Epoch 13/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6991 - accuracy: 0.7949 - scaled_graph_loss: 0.0298
Epoch 14/100
17/17 [==============================] - 0s 11ms/step - loss: 0.6366 - accuracy: 0.8353 - scaled_graph_loss: 0.0298
Epoch 15/100
17/17 [==============================] - 0s 11ms/step - loss: 0.5447 - accuracy: 0.8312 - scaled_graph_loss: 0.0316
Epoch 16/100
17/17 [==============================] - 0s 11ms/step - loss: 0.5165 - accuracy: 0.8604 - scaled_graph_loss: 0.0295
Epoch 17/100
17/17 [==============================] - 0s 11ms/step - loss: 0.4780 - accuracy: 0.8717 - scaled_graph_loss: 0.0307
Epoch 18/100
17/17 [==============================] - 0s 11ms/step - loss: 0.4786 - accuracy: 0.8763 - scaled_graph_loss: 0.0304
Epoch 19/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4446 - accuracy: 0.8762 - scaled_graph_loss: 0.0328
Epoch 20/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3954 - accuracy: 0.8953 - scaled_graph_loss: 0.0322
Epoch 21/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3739 - accuracy: 0.8967 - scaled_graph_loss: 0.0320
Epoch 22/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3835 - accuracy: 0.9009 - scaled_graph_loss: 0.0329
Epoch 23/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3242 - accuracy: 0.9201 - scaled_graph_loss: 0.0330
Epoch 24/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3034 - accuracy: 0.9214 - scaled_graph_loss: 0.0310
Epoch 25/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2909 - accuracy: 0.9281 - scaled_graph_loss: 0.0345
Epoch 26/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2921 - accuracy: 0.9249 - scaled_graph_loss: 0.0347
Epoch 27/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2439 - accuracy: 0.9483 - scaled_graph_loss: 0.0335
Epoch 28/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2524 - accuracy: 0.9445 - scaled_graph_loss: 0.0330
Epoch 29/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2310 - accuracy: 0.9424 - scaled_graph_loss: 0.0319
Epoch 30/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2389 - accuracy: 0.9388 - scaled_graph_loss: 0.0334
Epoch 31/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2204 - accuracy: 0.9523 - scaled_graph_loss: 0.0355
Epoch 32/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2159 - accuracy: 0.9525 - scaled_graph_loss: 0.0334
Epoch 33/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2022 - accuracy: 0.9561 - scaled_graph_loss: 0.0345
Epoch 34/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1926 - accuracy: 0.9601 - scaled_graph_loss: 0.0345
Epoch 35/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2049 - accuracy: 0.9493 - scaled_graph_loss: 0.0343
Epoch 36/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1732 - accuracy: 0.9627 - scaled_graph_loss: 0.0335
Epoch 37/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1914 - accuracy: 0.9573 - scaled_graph_loss: 0.0327
Epoch 38/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1781 - accuracy: 0.9578 - scaled_graph_loss: 0.0332
Epoch 39/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1650 - accuracy: 0.9730 - scaled_graph_loss: 0.0324
Epoch 40/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1650 - accuracy: 0.9621 - scaled_graph_loss: 0.0328
Epoch 41/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1721 - accuracy: 0.9644 - scaled_graph_loss: 0.0339
Epoch 42/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1672 - accuracy: 0.9687 - scaled_graph_loss: 0.0356
Epoch 43/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1642 - accuracy: 0.9600 - scaled_graph_loss: 0.0343
Epoch 44/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1469 - accuracy: 0.9735 - scaled_graph_loss: 0.0334
Epoch 45/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1478 - accuracy: 0.9708 - scaled_graph_loss: 0.0340
Epoch 46/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1537 - accuracy: 0.9640 - scaled_graph_loss: 0.0367
Epoch 47/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1513 - accuracy: 0.9691 - scaled_graph_loss: 0.0355
Epoch 48/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1252 - accuracy: 0.9768 - scaled_graph_loss: 0.0327
Epoch 49/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1443 - accuracy: 0.9722 - scaled_graph_loss: 0.0352
Epoch 50/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1339 - accuracy: 0.9731 - scaled_graph_loss: 0.0333
Epoch 51/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1385 - accuracy: 0.9741 - scaled_graph_loss: 0.0362
Epoch 52/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1347 - accuracy: 0.9732 - scaled_graph_loss: 0.0333
Epoch 53/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1222 - accuracy: 0.9785 - scaled_graph_loss: 0.0353
Epoch 54/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1258 - accuracy: 0.9738 - scaled_graph_loss: 0.0354
Epoch 55/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1209 - accuracy: 0.9771 - scaled_graph_loss: 0.0352
Epoch 56/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1279 - accuracy: 0.9787 - scaled_graph_loss: 0.0352
Epoch 57/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1273 - accuracy: 0.9719 - scaled_graph_loss: 0.0312
Epoch 58/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1206 - accuracy: 0.9747 - scaled_graph_loss: 0.0332
Epoch 59/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1109 - accuracy: 0.9814 - scaled_graph_loss: 0.0342
Epoch 60/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1168 - accuracy: 0.9778 - scaled_graph_loss: 0.0338
Epoch 61/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1125 - accuracy: 0.9820 - scaled_graph_loss: 0.0341
Epoch 62/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1055 - accuracy: 0.9824 - scaled_graph_loss: 0.0359
Epoch 63/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1183 - accuracy: 0.9771 - scaled_graph_loss: 0.0361
Epoch 64/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1063 - accuracy: 0.9835 - scaled_graph_loss: 0.0343
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1117 - accuracy: 0.9786 - scaled_graph_loss: 0.0306
Epoch 66/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1091 - accuracy: 0.9783 - scaled_graph_loss: 0.0343
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9882 - scaled_graph_loss: 0.0340
Epoch 68/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1077 - accuracy: 0.9842 - scaled_graph_loss: 0.0366
Epoch 69/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1092 - accuracy: 0.9767 - scaled_graph_loss: 0.0353
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1159 - accuracy: 0.9777 - scaled_graph_loss: 0.0338
Epoch 71/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0875 - accuracy: 0.9881 - scaled_graph_loss: 0.0325
Epoch 72/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0905 - accuracy: 0.9864 - scaled_graph_loss: 0.0337
Epoch 73/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1021 - accuracy: 0.9767 - scaled_graph_loss: 0.0321
Epoch 74/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1047 - accuracy: 0.9773 - scaled_graph_loss: 0.0328
Epoch 75/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0958 - accuracy: 0.9812 - scaled_graph_loss: 0.0338
Epoch 76/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0997 - accuracy: 0.9802 - scaled_graph_loss: 0.0335
Epoch 77/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0853 - accuracy: 0.9877 - scaled_graph_loss: 0.0314
Epoch 78/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1016 - accuracy: 0.9810 - scaled_graph_loss: 0.0346
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9809 - scaled_graph_loss: 0.0317
Epoch 80/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0908 - accuracy: 0.9864 - scaled_graph_loss: 0.0329
Epoch 81/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0742 - accuracy: 0.9902 - scaled_graph_loss: 0.0332
Epoch 82/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0910 - accuracy: 0.9875 - scaled_graph_loss: 0.0345
Epoch 83/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0908 - accuracy: 0.9848 - scaled_graph_loss: 0.0345
Epoch 84/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9831 - scaled_graph_loss: 0.0328
Epoch 85/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0997 - accuracy: 0.9804 - scaled_graph_loss: 0.0345
Epoch 86/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0901 - accuracy: 0.9859 - scaled_graph_loss: 0.0326
Epoch 87/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0825 - accuracy: 0.9873 - scaled_graph_loss: 0.0334
Epoch 88/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0805 - accuracy: 0.9885 - scaled_graph_loss: 0.0332
Epoch 89/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0776 - accuracy: 0.9885 - scaled_graph_loss: 0.0330
Epoch 90/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0960 - accuracy: 0.9799 - scaled_graph_loss: 0.0341
Epoch 91/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0787 - accuracy: 0.9888 - scaled_graph_loss: 0.0337
Epoch 92/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0736 - accuracy: 0.9914 - scaled_graph_loss: 0.0348
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0806 - accuracy: 0.9892 - scaled_graph_loss: 0.0347
Epoch 94/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0723 - accuracy: 0.9912 - scaled_graph_loss: 0.0314
Epoch 95/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0671 - accuracy: 0.9887 - scaled_graph_loss: 0.0295
Epoch 96/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0780 - accuracy: 0.9887 - scaled_graph_loss: 0.0327
Epoch 97/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0843 - accuracy: 0.9871 - scaled_graph_loss: 0.0331
Epoch 98/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0796 - accuracy: 0.9901 - scaled_graph_loss: 0.0333
Epoch 99/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0871 - accuracy: 0.9847 - scaled_graph_loss: 0.0329
Epoch 100/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0787 - accuracy: 0.9859 - scaled_graph_loss: 0.0335
<tensorflow.python.keras.callbacks.History at 0x7f786083d518>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 8ms/step - loss: 1.2281 - accuracy: 0.8103


Eval accuracy for  MLP + graph regularization :  0.8155515193939209
Eval loss for  MLP + graph regularization :  1.2275753021240234

The graph-regularized model's accuracy is about 2-3% higher than that of the base model (base_model).

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
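
For reference, graph building in that setting can be done with the nsl.tools.build_graph utility; the sketch below uses illustrative placeholder paths and an assumed similarity threshold.

import neural_structured_learning as nsl

# Build a similarity graph from tf.train.Examples that contain embeddings.
# An edge is created between each pair of samples whose embedding similarity
# exceeds the threshold.
nsl.tools.build_graph(['/tmp/embeddings.tfr'],  # placeholder input path
                      '/tmp/graph.tsv',         # placeholder output edge list
                      similarity_threshold=0.8)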

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.