このページは Cloud Translation API によって翻訳されました。
Switch to English

自然グラフを使用したドキュメント分類のグラフ正則化

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する

概観

グラフの正則化は、ニューラルグラフ学習( Bui et al。、2018 )のより広いパラダイムの下での特定の手法です。核となるアイデアは、ラベル付きデータとラベルなしデータの両方を利用して、グラフを正規化した目的でニューラルネットワークモデルをトレーニングすることです。

このチュートリアルでは、グラフの正規化を使用して、自然な(有機的な)グラフを形成するドキュメントを分類する方法を探ります。

ニューラル構造化学習(NSL)フレームワークを使用してグラフ正則化モデルを作成するための一般的なレシピは、次のとおりです。

  1. 入力グラフとサンプル特徴からトレーニングデータを生成します。グラフのノードはサンプルに対応し、グラフのエッジはサンプルのペア間の類似性に対応します。結果のトレーニングデータには、元のノードフィーチャに加えて隣接フィーチャが含まれます。
  2. Keras順次、機能、またはサブクラスAPIを使用して、ニューラルネットワークを基本モデルとして作成します。
  3. NSLフレームワークによって提供されるGraphRegularizationラッパークラスで基本モデルをラップして、新しいグラフKerasモデルを作成します。この新しいモデルには、トレーニング目標の正規化項としてグラフの正規化損失が含まれます。
  4. グラフKerasモデルをKeras評価します。

セットアップ

ニューラル構造化学習パッケージをインストールします。

pip install --quiet neural-structured-learning

依存関係とインポート

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.2.0
Eager mode:  True
GPU is NOT AVAILABLE

コーラデータセット

Coraデータセットは引用グラフで、ノードは機械学習論文を表し、エッジは論文のペア間の引用を表します。関係するタスクは、ドキュメントの分類で、各論文を7つのカテゴリのいずれかに分類することを目的としています。言い換えれば、これは7つのクラスを持つマルチクラス分類問題です。

グラフ

元のグラフが指示されます。ただし、この例では、このグラフの無向バージョンを検討します。したがって、論文Aが論文Bを引用している場合、論文BもAを引用していると見なします。これは必ずしも真実ではありませんが、この例では、引用は類似性の代用と見なされます。

特徴

入力の各ペーパーには、実質的に2つの機能が含まれています。

  1. 単語 :紙のテキストの高密度で複数のホットバッグオブワード表現。 Coraデータセットの語彙には、1433の一意の単語が含まれています。したがって、この特徴の長さは1433であり、位置「i」の値は0/1で、語彙の単語「i」が特定の論文に存在するかどうかを示します。

  2. ラベル :論文のクラスID(カテゴリ)を表す単一の整数。

Coraデータセットをダウンロードする

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

CoraデータをNSL形式に変換する

Coraデータセットを前処理して、Neural Structured Learningで必要な形式に変換するために、NSL githubリポジトリに含まれている「preprocess_cora_dataset.py」スクリプトを実行します。このスクリプトは次のことを行います。

  1. 元のノードフィーチャとグラフを使用して隣接フィーチャを生成します。
  2. tf.train.Exampleインスタンスを含むtf.train.Exampleデータとテストデータの分割をtf.train.Exampleます。
  3. 結果のトレインとテストデータをTFRecord形式でTFRecordます。
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2020-07-01 11:15:33--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2020-07-01 11:15:33 (84.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.06 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.38 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.

グローバル変数

トレインおよびテストデータへのファイルパスは、上記の「preprocess_cora_dataset.py」スクリプトを呼び出すために使用されるコマンドラインフラグ値に基づいています。

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

ハイパーパラメータ

HParamsインスタンスを使用して、トレーニングと評価に使用されるさまざまなハイパーHParamsと定数を含めます。以下に、それぞれについて簡単に説明します。

  • num_classes :合計7つのクラスがあります

  • max_seq_length :これは語彙のサイズであり、入力内のすべてのインスタンスは密なマルチホットのバッグオブワード表現を持っています。つまり、単語の値が1の場合、その単語は入力に存在することを示し、値が0の場合は存在しないことを示します。

  • distance_type :これは、サンプルをその近傍と正規化するために使用される距離メトリックです。

  • graph_regularization_multiplier :これは、全体的な損失関数におけるグラフ正則化項の相対的な重みを制御します。

  • num_neighbors :グラフの正則化に使用されるネイバーの数。この値未満またはそれに等しくなるように有するmax_nbrs実行時には、上記使用コマンドライン引数preprocess_cora_dataset.py

  • num_fc_units :ニューラルネットワークで完全に接続されたレイヤーの数。

  • train_epochs :トレーニングエポックの数。

  • batch_size :トレーニングと評価に使用されるバッチサイズ。

  • dropout_rate :完全に接続された各レイヤーに続くドロップアウトのレートを制御します

  • eval_steps :評価が完了したと見なす前に処理するバッチの数。 Noneに設定すると、テストセットのすべてのインスタンスが評価されます。

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

ロードトレインとテストデータ

このノートブックで前に説明したように、入力トレーニングデータとテストデータは'preprocess_cora_dataset.py'によって作成されています。それらを2つのtf.data.Datasetオブジェクトにロードします。1つはトレーニング用、もう1つはテスト用です。

モデルの入力レイヤーでは、各サンプルから「単語」と「ラベル」の特徴だけでなく、 hparams.num_neighbors値に基づいて対応する隣接特徴もhparams.num_neighborsます。 hparams.num_neighborsよりもhparams.num_neighborsが少ないインスタンスには、存在しないネイバーフィーチャのダミー値が割り当てられます。

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

trainデータセットを覗いて、その内容を見てみましょう。

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[4 3 1 2 1 6 2 5 6 2 2 6 5 0 2 2 1 6 2 2 2 2 5 4 2 0 2 1 1 2 0 5 2 2 2 0 2
 2 0 6 1 1 0 2 1 2 3 2 0 0 0 4 1 3 3 1 2 5 3 3 1 1 6 0 0 4 6 5 6 0 3 4 2 2
 2 3 3 2 4 0 2 3 2 2 3 1 2 2 1 0 6 1 2 1 6 2 1 0 4 3 2 5 2 3 1 0 3 4 3 4 1
 0 5 6 4 2 1 1 2 5 3 4 3 1 3 2 6 3], shape=(128,), dtype=int64)

テストデータセットを調べて、その内容を見てみましょう。

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

モデル定義

グラフの正則化の使用法を示すために、最初にこの問題の基本モデルを構築します。 2つの非表示レイヤーとその間のドロップアウトを持つ単純なフィードフォワードニューラルネットワークを使用します。 tf.Kerasフレームワークでサポートされているすべてのモデルタイプ(シーケンシャル、ファンクション、サブクラス)を使用したベースモデルの作成について説明します。

シーケンシャルベースモデル

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

機能ベースモデル

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

サブクラスのベースモデル

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

基本モデルを作成する

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

ベースMLPモデルのトレーニング

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
17/17 [==============================] - 0s 11ms/step - loss: 1.9256 - accuracy: 0.1870
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8410 - accuracy: 0.2835
Epoch 3/100
17/17 [==============================] - 0s 9ms/step - loss: 1.7479 - accuracy: 0.3374
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6384 - accuracy: 0.3884
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5086 - accuracy: 0.4390
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3606 - accuracy: 0.5016
Epoch 7/100
17/17 [==============================] - 0s 9ms/step - loss: 1.2165 - accuracy: 0.5791
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0783 - accuracy: 0.6311
Epoch 9/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9552 - accuracy: 0.6947
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.8680 - accuracy: 0.7090
Epoch 11/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7915 - accuracy: 0.7425
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7124 - accuracy: 0.7773
Epoch 13/100
17/17 [==============================] - 0s 9ms/step - loss: 0.6582 - accuracy: 0.7907
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6021 - accuracy: 0.8065
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5416 - accuracy: 0.8325
Epoch 16/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5042 - accuracy: 0.8473
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4433 - accuracy: 0.8761
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4310 - accuracy: 0.8640
Epoch 19/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3894 - accuracy: 0.8840
Epoch 20/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3676 - accuracy: 0.8891
Epoch 21/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3576 - accuracy: 0.8812
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3132 - accuracy: 0.9067
Epoch 23/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3058 - accuracy: 0.9142
Epoch 24/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2924 - accuracy: 0.9155
Epoch 25/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2769 - accuracy: 0.9197
Epoch 26/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2636 - accuracy: 0.9244
Epoch 27/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2429 - accuracy: 0.9313
Epoch 28/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2324 - accuracy: 0.9323
Epoch 29/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2285 - accuracy: 0.9346
Epoch 30/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2039 - accuracy: 0.9374
Epoch 31/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1943 - accuracy: 0.9471
Epoch 32/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1898 - accuracy: 0.9439
Epoch 33/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1879 - accuracy: 0.9425
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1828 - accuracy: 0.9443
Epoch 35/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1635 - accuracy: 0.9541
Epoch 36/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1648 - accuracy: 0.9476
Epoch 37/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1603 - accuracy: 0.9499
Epoch 38/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1428 - accuracy: 0.9624
Epoch 39/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1483 - accuracy: 0.9601
Epoch 40/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1352 - accuracy: 0.9582
Epoch 41/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1379 - accuracy: 0.9555
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1410 - accuracy: 0.9582
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1198 - accuracy: 0.9684
Epoch 44/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1148 - accuracy: 0.9731
Epoch 45/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1228 - accuracy: 0.9657
Epoch 46/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1135 - accuracy: 0.9703
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1134 - accuracy: 0.9661
Epoch 48/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1175 - accuracy: 0.9619
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1002 - accuracy: 0.9703
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1143 - accuracy: 0.9671
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0923 - accuracy: 0.9777
Epoch 52/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1068 - accuracy: 0.9731
Epoch 53/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0972 - accuracy: 0.9712
Epoch 54/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0828 - accuracy: 0.9796
Epoch 55/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1036 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0954 - accuracy: 0.9745
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0883 - accuracy: 0.9768
Epoch 58/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0859 - accuracy: 0.9777
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0856 - accuracy: 0.9759
Epoch 60/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0858 - accuracy: 0.9754
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9726
Epoch 62/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0840 - accuracy: 0.9763
Epoch 63/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0770 - accuracy: 0.9805
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0823 - accuracy: 0.9745
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0665 - accuracy: 0.9828
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0788 - accuracy: 0.9777
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0690 - accuracy: 0.9800
Epoch 68/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0683 - accuracy: 0.9805
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0615 - accuracy: 0.9838
Epoch 70/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0618 - accuracy: 0.9833
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0659 - accuracy: 0.9810
Epoch 72/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0704 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0645 - accuracy: 0.9814
Epoch 74/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0645 - accuracy: 0.9791
Epoch 75/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9791
Epoch 76/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0648 - accuracy: 0.9814
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0591 - accuracy: 0.9838
Epoch 78/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0699 - accuracy: 0.9814
Epoch 80/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0603 - accuracy: 0.9828
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0629 - accuracy: 0.9828
Epoch 82/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0596 - accuracy: 0.9828
Epoch 83/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0542 - accuracy: 0.9828
Epoch 84/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0452 - accuracy: 0.9893
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0551 - accuracy: 0.9838
Epoch 86/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0555 - accuracy: 0.9842
Epoch 87/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0514 - accuracy: 0.9824
Epoch 88/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0553 - accuracy: 0.9847
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9884
Epoch 90/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0476 - accuracy: 0.9893
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0427 - accuracy: 0.9903
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0475 - accuracy: 0.9847
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0423 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0473 - accuracy: 0.9865
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0560 - accuracy: 0.9819
Epoch 96/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0547 - accuracy: 0.9810
Epoch 97/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0576 - accuracy: 0.9814
Epoch 98/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0429 - accuracy: 0.9893
Epoch 99/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0440 - accuracy: 0.9875
Epoch 100/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0513 - accuracy: 0.9838

<tensorflow.python.keras.callbacks.History at 0x7fc47a3c78d0>

ベースMLPモデルを評価する

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.3380 - accuracy: 0.7740


Eval accuracy for  Base MLP model :  0.7739602327346802
Eval loss for  Base MLP model :  1.3379606008529663

グラフの正則化を使用してMLPモデルをトレーニングする

グラフの正則化を既存のtf.Keras.Model損失項にtf.Keras.Modelは、数行のコードが必要です。基本モデルはラップされて、新しいtf.Kerasサブクラスモデルを作成します。その損失にはグラフの正則化が含まれます。

グラフの正則化の段階的な利点を評価するために、新しいベースモデルインスタンスを作成します。これは、 base_modelがすでにいくつかの反復でトレーニングされており、このトレーニングされたモデルを再利用してグラフ正規化モデルを作成することは、 base_model公平な比較にならないためbase_model

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100

/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

17/17 [==============================] - 0s 10ms/step - loss: 1.9454 - accuracy: 0.1652 - graph_loss: 0.0076
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8517 - accuracy: 0.2956 - graph_loss: 0.0117
Epoch 3/100
17/17 [==============================] - 0s 10ms/step - loss: 1.7589 - accuracy: 0.3151 - graph_loss: 0.0261
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6714 - accuracy: 0.3392 - graph_loss: 0.0476
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5607 - accuracy: 0.4037 - graph_loss: 0.0622
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.4486 - accuracy: 0.4807 - graph_loss: 0.0921
Epoch 7/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3135 - accuracy: 0.5383 - graph_loss: 0.1236
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.1902 - accuracy: 0.5912 - graph_loss: 0.1616
Epoch 9/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0647 - accuracy: 0.6575 - graph_loss: 0.1920
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9416 - accuracy: 0.7067 - graph_loss: 0.2181
Epoch 11/100
17/17 [==============================] - 0s 10ms/step - loss: 0.8601 - accuracy: 0.7378 - graph_loss: 0.2470
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7968 - accuracy: 0.7462 - graph_loss: 0.2565
Epoch 13/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6881 - accuracy: 0.7912 - graph_loss: 0.2681
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6548 - accuracy: 0.8139 - graph_loss: 0.2941
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5874 - accuracy: 0.8376 - graph_loss: 0.3010
Epoch 16/100
17/17 [==============================] - 0s 9ms/step - loss: 0.5537 - accuracy: 0.8348 - graph_loss: 0.3014
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5123 - accuracy: 0.8529 - graph_loss: 0.3097
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4771 - accuracy: 0.8640 - graph_loss: 0.3192
Epoch 19/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4294 - accuracy: 0.8826 - graph_loss: 0.3182
Epoch 20/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4109 - accuracy: 0.8854 - graph_loss: 0.3169
Epoch 21/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3901 - accuracy: 0.8965 - graph_loss: 0.3250
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3700 - accuracy: 0.8956 - graph_loss: 0.3349
Epoch 23/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3716 - accuracy: 0.8974 - graph_loss: 0.3408
Epoch 24/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3258 - accuracy: 0.9202 - graph_loss: 0.3361
Epoch 25/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3043 - accuracy: 0.9253 - graph_loss: 0.3351
Epoch 26/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2919 - accuracy: 0.9253 - graph_loss: 0.3361
Epoch 27/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3005 - accuracy: 0.9202 - graph_loss: 0.3249
Epoch 28/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2629 - accuracy: 0.9336 - graph_loss: 0.3442
Epoch 29/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2617 - accuracy: 0.9401 - graph_loss: 0.3302
Epoch 30/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2510 - accuracy: 0.9383 - graph_loss: 0.3436
Epoch 31/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2452 - accuracy: 0.9411 - graph_loss: 0.3364
Epoch 32/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2397 - accuracy: 0.9466 - graph_loss: 0.3333
Epoch 33/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2239 - accuracy: 0.9466 - graph_loss: 0.3373
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2084 - accuracy: 0.9513 - graph_loss: 0.3330
Epoch 35/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2075 - accuracy: 0.9499 - graph_loss: 0.3383
Epoch 36/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2064 - accuracy: 0.9513 - graph_loss: 0.3394
Epoch 37/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1857 - accuracy: 0.9568 - graph_loss: 0.3371
Epoch 38/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1799 - accuracy: 0.9601 - graph_loss: 0.3477
Epoch 39/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1844 - accuracy: 0.9573 - graph_loss: 0.3385
Epoch 40/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1823 - accuracy: 0.9592 - graph_loss: 0.3445
Epoch 41/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1713 - accuracy: 0.9615 - graph_loss: 0.3451
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9624 - graph_loss: 0.3398
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1692 - accuracy: 0.9671 - graph_loss: 0.3483
Epoch 44/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9647 - graph_loss: 0.3437
Epoch 45/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1485 - accuracy: 0.9703 - graph_loss: 0.3338
Epoch 46/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1467 - accuracy: 0.9717 - graph_loss: 0.3405
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1492 - accuracy: 0.9694 - graph_loss: 0.3466
Epoch 48/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1577 - accuracy: 0.9666 - graph_loss: 0.3338
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1363 - accuracy: 0.9773 - graph_loss: 0.3424
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1511 - accuracy: 0.9694 - graph_loss: 0.3402
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1366 - accuracy: 0.9759 - graph_loss: 0.3385
Epoch 52/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1254 - accuracy: 0.9777 - graph_loss: 0.3474
Epoch 53/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1289 - accuracy: 0.9740 - graph_loss: 0.3469
Epoch 54/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1410 - accuracy: 0.9689 - graph_loss: 0.3475
Epoch 55/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1356 - accuracy: 0.9703 - graph_loss: 0.3483
Epoch 56/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1283 - accuracy: 0.9773 - graph_loss: 0.3412
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1264 - accuracy: 0.9745 - graph_loss: 0.3473
Epoch 58/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1242 - accuracy: 0.9740 - graph_loss: 0.3443
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1144 - accuracy: 0.9782 - graph_loss: 0.3440
Epoch 60/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1250 - accuracy: 0.9735 - graph_loss: 0.3357
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1190 - accuracy: 0.9787 - graph_loss: 0.3400
Epoch 62/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1141 - accuracy: 0.9814 - graph_loss: 0.3419
Epoch 63/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1085 - accuracy: 0.9787 - graph_loss: 0.3395
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1148 - accuracy: 0.9768 - graph_loss: 0.3504
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1137 - accuracy: 0.9791 - graph_loss: 0.3360
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1121 - accuracy: 0.9745 - graph_loss: 0.3469
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1046 - accuracy: 0.9810 - graph_loss: 0.3476
Epoch 68/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1112 - accuracy: 0.9791 - graph_loss: 0.3431
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1075 - accuracy: 0.9787 - graph_loss: 0.3455
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0986 - accuracy: 0.9875 - graph_loss: 0.3403
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1141 - accuracy: 0.9782 - graph_loss: 0.3508
Epoch 72/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1012 - accuracy: 0.9814 - graph_loss: 0.3453
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9833 - graph_loss: 0.3430
Epoch 74/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9842 - graph_loss: 0.3447
Epoch 75/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0988 - accuracy: 0.9842 - graph_loss: 0.3430
Epoch 76/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0915 - accuracy: 0.9856 - graph_loss: 0.3475
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0960 - accuracy: 0.9833 - graph_loss: 0.3353
Epoch 78/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0916 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0979 - accuracy: 0.9800 - graph_loss: 0.3476
Epoch 80/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9782 - graph_loss: 0.3400
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9838 - graph_loss: 0.3386
Epoch 82/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9805 - graph_loss: 0.3416
Epoch 83/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0957 - accuracy: 0.9838 - graph_loss: 0.3398
Epoch 84/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9879 - graph_loss: 0.3379
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 86/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0906 - accuracy: 0.9847 - graph_loss: 0.3445
Epoch 87/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9852 - graph_loss: 0.3506
Epoch 88/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0821 - accuracy: 0.9898 - graph_loss: 0.3448
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0803 - accuracy: 0.9865 - graph_loss: 0.3370
Epoch 90/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9828 - graph_loss: 0.3428
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3505
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0882 - accuracy: 0.9847 - graph_loss: 0.3396
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0807 - accuracy: 0.9879 - graph_loss: 0.3473
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0820 - accuracy: 0.9861 - graph_loss: 0.3367
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0864 - accuracy: 0.9838 - graph_loss: 0.3353
Epoch 96/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0786 - accuracy: 0.9889 - graph_loss: 0.3392
Epoch 97/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0735 - accuracy: 0.9912 - graph_loss: 0.3443
Epoch 98/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0861 - accuracy: 0.9842 - graph_loss: 0.3381
Epoch 99/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0850 - accuracy: 0.9833 - graph_loss: 0.3376
Epoch 100/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0841 - accuracy: 0.9879 - graph_loss: 0.3510

<tensorflow.python.keras.callbacks.History at 0x7fc3d853ce10>

グラフの正則化を使用してMLPモデルを評価する

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 6ms/step - loss: 1.2475 - accuracy: 0.8192


Eval accuracy for  MLP + graph regularization :  0.8191681504249573
Eval loss for  MLP + graph regularization :  1.2474583387374878

グラフ正則化モデルの精度は、ベースモデル( base_model )の精度よりも約2〜3%高くなります。

結論

ニューラル構造化学習(NSL)フレームワークを使用して、自然引用グラフ(Cora)のドキュメント分類にグラフの正規化を使用する方法を示しました。 高度なチュートリアルでは、グラフの正則化を使用してニューラルネットワークをトレーニングする前に、サンプルの埋め込みに基づいてグラフを合成します。このアプローチは、入力に明示的なグラフが含まれていない場合に役立ちます。

監視の量を変えたり、グラフの正則化のためにさまざまなニューラルアーキテクチャを試したりして、さらに実験することをお勧めします。