Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
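
In rough notation (ours, not NSL's exact formulation), the training objective has the form

$$\mathcal{L} = \mathcal{L}_{\text{supervised}} + \alpha \sum_{(i,j) \in \mathcal{E}} w_{ij} \, d\big(h(x_i), h(x_j)\big)$$

where $\mathcal{E}$ is the set of graph edges, $w_{ij}$ is the weight of the edge between samples $x_i$ and $x_j$, $d$ is a distance metric (L2 in this tutorial), $h(x)$ is the model's prediction (or an embedding) for sample $x$, and $\alpha$ is the graph regularization multiplier. Minimizing the second term encourages neighboring samples to receive similar outputs.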

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model (a minimal code sketch of steps 2-4 follows this list).
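
As a preview, here is a minimal sketch of steps 2 through 4, assuming training data with neighbor features (step 1) is already available. base_model here is a toy placeholder, and the fit/evaluate calls are commented out because train_dataset and test_dataset are constructed later in this tutorial:

import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: any Keras model can serve as the base model (a toy example here).
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(7, activation='softmax'),
])

# Step 3: wrap the base model so that a graph regularization loss term is
# added to the training objective.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1)
graph_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile, train, and evaluate like any other Keras model.
graph_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
# graph_model.fit(train_dataset, epochs=5)
# graph_model.evaluate(test_dataset)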

Setup

Install the Neural Structured Learning package.

!pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.4.0
Eager mode:  True
GPU is available

The Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

The graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
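
As a purely illustrative sketch (the preprocess_cora_dataset.py script used later handles this step for us), making a directed edge list undirected amounts to adding the reverse of every edge:

def make_undirected(edges):
  """Returns the given directed (source, target) edges plus all reverse edges."""
  undirected = set(edges)
  undirected.update((target, source) for source, target in edges)
  return undirected

# Paper A cites paper B; we also treat B as citing A.
print(sorted(make_undirected({('A', 'B')})))  # [('A', 'B'), ('B', 'A')]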

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense multi-hot bag-of-words representation of the text in the paper (see the toy encoding example after this list). The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.

  2. Label: A single integer representing the class ID (category) of the paper.
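
To make the multi-hot encoding concrete, here is a toy example with a hypothetical 5-word vocabulary (the real vocabulary has 1433 words):

vocabulary = ['graph', 'learning', 'loss', 'model', 'neural']
paper_words = {'neural', 'graph', 'model'}

# The value at position i is 1 if word i of the vocabulary occurs in the paper.
multi_hot = [1 if word in paper_words else 0 for word in vocabulary]
print(multi_hot)  # [1, 0, 0, 1, 1]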

Download the Cora dataset

!wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
!tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.

!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2021-01-15 02:26:25--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0.001s  

2021-01-15 02:26:26 (20.7 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2021-01-15 02:26:26.653369: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.51 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.05 minutes.

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary; all instances in the input have a dense multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: The number of units in each fully connected layer of the neural network.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation is complete. If set to None, all instances in the test set are evaluated.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects, one for training and one for testing.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 1 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 2 0 0 6 6 1 2 6 2 0 5 6 1 2 5 3 3 0 3 3 2 6 6 2 1 2 1 3 6 4 6 3 0 2 2
 1 3 3 6 6 3 2 2 1 2 2 6 6 5 0 6 2 0 2 6 6 2 2 2 5 2 3 3 0 3 3 6 3 6 3 1 2
 2 3 3 3 2 3 0 1 2 2 0 2 3 3 2 6 3 2 3 1 2 4 2 1 2 2 3 6 1 2 3 2 5 2 2 3 2
 2 1 1 3 2 1 4 0 2 3 5 2 1 2 2 0 1], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework (sequential, functional, and subclass).

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

Train the base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:595: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  [n for n in tensors.keys() if n not in ref_input_names])
17/17 [==============================] - 1s 12ms/step - loss: 1.9717 - accuracy: 0.1436
Epoch 2/100
17/17 [==============================] - 0s 12ms/step - loss: 1.8710 - accuracy: 0.2844
Epoch 3/100
17/17 [==============================] - 0s 11ms/step - loss: 1.7676 - accuracy: 0.3435
Epoch 4/100
17/17 [==============================] - 0s 12ms/step - loss: 1.6748 - accuracy: 0.3555
Epoch 5/100
17/17 [==============================] - 0s 11ms/step - loss: 1.5762 - accuracy: 0.3930
Epoch 6/100
17/17 [==============================] - 0s 12ms/step - loss: 1.4101 - accuracy: 0.5144
Epoch 7/100
17/17 [==============================] - 0s 12ms/step - loss: 1.2371 - accuracy: 0.5929
Epoch 8/100
17/17 [==============================] - 0s 11ms/step - loss: 1.1377 - accuracy: 0.6022
Epoch 9/100
17/17 [==============================] - 0s 12ms/step - loss: 0.9814 - accuracy: 0.6762
Epoch 10/100
17/17 [==============================] - 0s 11ms/step - loss: 0.8619 - accuracy: 0.7089
Epoch 11/100
17/17 [==============================] - 0s 12ms/step - loss: 0.8184 - accuracy: 0.7439
Epoch 12/100
17/17 [==============================] - 0s 12ms/step - loss: 0.7374 - accuracy: 0.7569
Epoch 13/100
17/17 [==============================] - 0s 12ms/step - loss: 0.6403 - accuracy: 0.7911
Epoch 14/100
17/17 [==============================] - 0s 12ms/step - loss: 0.6167 - accuracy: 0.8038
Epoch 15/100
17/17 [==============================] - 0s 11ms/step - loss: 0.5324 - accuracy: 0.8423
Epoch 16/100
17/17 [==============================] - 0s 11ms/step - loss: 0.4927 - accuracy: 0.8447
Epoch 17/100
17/17 [==============================] - 0s 12ms/step - loss: 0.4550 - accuracy: 0.8589
Epoch 18/100
17/17 [==============================] - 0s 12ms/step - loss: 0.4416 - accuracy: 0.8682
Epoch 19/100
17/17 [==============================] - 0s 12ms/step - loss: 0.3883 - accuracy: 0.8835
Epoch 20/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3845 - accuracy: 0.8679
Epoch 21/100
17/17 [==============================] - 0s 12ms/step - loss: 0.3481 - accuracy: 0.8922
Epoch 22/100
17/17 [==============================] - 0s 12ms/step - loss: 0.3229 - accuracy: 0.8996
Epoch 23/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2873 - accuracy: 0.9198
Epoch 24/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2848 - accuracy: 0.9158
Epoch 25/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2820 - accuracy: 0.9076
Epoch 26/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2745 - accuracy: 0.9187
Epoch 27/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2591 - accuracy: 0.9247
Epoch 28/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2313 - accuracy: 0.9377
Epoch 29/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2242 - accuracy: 0.9371
Epoch 30/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2171 - accuracy: 0.9358
Epoch 31/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2285 - accuracy: 0.9365
Epoch 32/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2079 - accuracy: 0.9358
Epoch 33/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1881 - accuracy: 0.9430
Epoch 34/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1703 - accuracy: 0.9556
Epoch 35/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1751 - accuracy: 0.9464
Epoch 36/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1843 - accuracy: 0.9495
Epoch 37/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1580 - accuracy: 0.9588
Epoch 38/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1557 - accuracy: 0.9548
Epoch 39/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1647 - accuracy: 0.9548
Epoch 40/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1494 - accuracy: 0.9584
Epoch 41/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1299 - accuracy: 0.9665
Epoch 42/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1432 - accuracy: 0.9657
Epoch 43/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1293 - accuracy: 0.9613
Epoch 44/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1050 - accuracy: 0.9759
Epoch 45/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1292 - accuracy: 0.9569
Epoch 46/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1182 - accuracy: 0.9670
Epoch 47/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1220 - accuracy: 0.9626
Epoch 48/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1210 - accuracy: 0.9598
Epoch 49/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1015 - accuracy: 0.9733
Epoch 50/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1042 - accuracy: 0.9714
Epoch 51/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1079 - accuracy: 0.9707
Epoch 52/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1176 - accuracy: 0.9615
Epoch 53/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0952 - accuracy: 0.9727
Epoch 54/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1062 - accuracy: 0.9697
Epoch 55/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0891 - accuracy: 0.9743
Epoch 56/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0906 - accuracy: 0.9764
Epoch 57/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0931 - accuracy: 0.9707
Epoch 58/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0901 - accuracy: 0.9762
Epoch 59/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0759 - accuracy: 0.9794
Epoch 60/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0954 - accuracy: 0.9700
Epoch 61/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0933 - accuracy: 0.9769
Epoch 62/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0798 - accuracy: 0.9783
Epoch 63/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0695 - accuracy: 0.9845
Epoch 64/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0670 - accuracy: 0.9822
Epoch 65/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0751 - accuracy: 0.9807
Epoch 66/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0743 - accuracy: 0.9781
Epoch 67/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0619 - accuracy: 0.9855
Epoch 68/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0683 - accuracy: 0.9820
Epoch 69/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0669 - accuracy: 0.9822
Epoch 70/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0658 - accuracy: 0.9830
Epoch 71/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0681 - accuracy: 0.9841
Epoch 72/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0857 - accuracy: 0.9760
Epoch 73/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0701 - accuracy: 0.9767
Epoch 74/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0820 - accuracy: 0.9799
Epoch 75/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0612 - accuracy: 0.9854
Epoch 76/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0662 - accuracy: 0.9804
Epoch 77/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0678 - accuracy: 0.9800
Epoch 78/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0590 - accuracy: 0.9838
Epoch 79/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0640 - accuracy: 0.9807
Epoch 80/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0595 - accuracy: 0.9846
Epoch 81/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0556 - accuracy: 0.9824
Epoch 82/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0632 - accuracy: 0.9816
Epoch 83/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0637 - accuracy: 0.9822
Epoch 84/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0524 - accuracy: 0.9858
Epoch 85/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0665 - accuracy: 0.9780
Epoch 86/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0586 - accuracy: 0.9807
Epoch 87/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0542 - accuracy: 0.9844
Epoch 88/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0506 - accuracy: 0.9847
Epoch 89/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0447 - accuracy: 0.9866
Epoch 90/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0465 - accuracy: 0.9880
Epoch 91/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0651 - accuracy: 0.9754
Epoch 92/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0584 - accuracy: 0.9825
Epoch 93/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0479 - accuracy: 0.9889
Epoch 94/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0517 - accuracy: 0.9843
Epoch 95/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0566 - accuracy: 0.9835
Epoch 96/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0478 - accuracy: 0.9844
Epoch 97/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0511 - accuracy: 0.9809
Epoch 98/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0516 - accuracy: 0.9866
Epoch 99/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0454 - accuracy: 0.9891
Epoch 100/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0530 - accuracy: 0.9822
<tensorflow.python.keras.callbacks.History at 0x7fc82c426c88>

Evaluate the base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 8ms/step - loss: 1.4110 - accuracy: 0.7866


Eval accuracy for  Base MLP model :  0.7866184711456299
Eval loss for  Base MLP model :  1.4110491275787354

Train the MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:437: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 11ms/step - loss: 1.9225 - accuracy: 0.1801 - scaled_graph_loss: 9.2100e-04
Epoch 2/100
17/17 [==============================] - 0s 11ms/step - loss: 1.8425 - accuracy: 0.2831 - scaled_graph_loss: 0.0014
Epoch 3/100
17/17 [==============================] - 0s 11ms/step - loss: 1.7455 - accuracy: 0.3194 - scaled_graph_loss: 0.0025
Epoch 4/100
17/17 [==============================] - 0s 11ms/step - loss: 1.6523 - accuracy: 0.3600 - scaled_graph_loss: 0.0042
Epoch 5/100
17/17 [==============================] - 0s 11ms/step - loss: 1.5537 - accuracy: 0.3918 - scaled_graph_loss: 0.0061
Epoch 6/100
17/17 [==============================] - 0s 12ms/step - loss: 1.3923 - accuracy: 0.4934 - scaled_graph_loss: 0.0091
Epoch 7/100
17/17 [==============================] - 0s 11ms/step - loss: 1.2615 - accuracy: 0.5619 - scaled_graph_loss: 0.0131
Epoch 8/100
17/17 [==============================] - 0s 12ms/step - loss: 1.1398 - accuracy: 0.6262 - scaled_graph_loss: 0.0167
Epoch 9/100
17/17 [==============================] - 0s 11ms/step - loss: 1.0197 - accuracy: 0.6717 - scaled_graph_loss: 0.0212
Epoch 10/100
17/17 [==============================] - 0s 11ms/step - loss: 0.9155 - accuracy: 0.7002 - scaled_graph_loss: 0.0247
Epoch 11/100
17/17 [==============================] - 0s 11ms/step - loss: 0.7946 - accuracy: 0.7688 - scaled_graph_loss: 0.0258
Epoch 12/100
17/17 [==============================] - 0s 11ms/step - loss: 0.7516 - accuracy: 0.7877 - scaled_graph_loss: 0.0272
Epoch 13/100
17/17 [==============================] - 0s 12ms/step - loss: 0.6715 - accuracy: 0.8285 - scaled_graph_loss: 0.0276
Epoch 14/100
17/17 [==============================] - 0s 11ms/step - loss: 0.6117 - accuracy: 0.8109 - scaled_graph_loss: 0.0317
Epoch 15/100
17/17 [==============================] - 0s 11ms/step - loss: 0.5729 - accuracy: 0.8365 - scaled_graph_loss: 0.0313
Epoch 16/100
17/17 [==============================] - 0s 12ms/step - loss: 0.5208 - accuracy: 0.8520 - scaled_graph_loss: 0.0327
Epoch 17/100
17/17 [==============================] - 0s 12ms/step - loss: 0.4611 - accuracy: 0.8802 - scaled_graph_loss: 0.0307
Epoch 18/100
17/17 [==============================] - 0s 11ms/step - loss: 0.4573 - accuracy: 0.8776 - scaled_graph_loss: 0.0324
Epoch 19/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3964 - accuracy: 0.9074 - scaled_graph_loss: 0.0320
Epoch 20/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3887 - accuracy: 0.9051 - scaled_graph_loss: 0.0337
Epoch 21/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3882 - accuracy: 0.8998 - scaled_graph_loss: 0.0350
Epoch 22/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3457 - accuracy: 0.9086 - scaled_graph_loss: 0.0337
Epoch 23/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3666 - accuracy: 0.9020 - scaled_graph_loss: 0.0332
Epoch 24/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3249 - accuracy: 0.9178 - scaled_graph_loss: 0.0336
Epoch 25/100
17/17 [==============================] - 0s 11ms/step - loss: 0.3070 - accuracy: 0.9131 - scaled_graph_loss: 0.0348
Epoch 26/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2703 - accuracy: 0.9342 - scaled_graph_loss: 0.0323
Epoch 27/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2743 - accuracy: 0.9369 - scaled_graph_loss: 0.0346
Epoch 28/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2609 - accuracy: 0.9313 - scaled_graph_loss: 0.0334
Epoch 29/100
17/17 [==============================] - 0s 12ms/step - loss: 0.2561 - accuracy: 0.9355 - scaled_graph_loss: 0.0334
Epoch 30/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2222 - accuracy: 0.9489 - scaled_graph_loss: 0.0318
Epoch 31/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2030 - accuracy: 0.9545 - scaled_graph_loss: 0.0324
Epoch 32/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2269 - accuracy: 0.9437 - scaled_graph_loss: 0.0329
Epoch 33/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2138 - accuracy: 0.9498 - scaled_graph_loss: 0.0351
Epoch 34/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2171 - accuracy: 0.9490 - scaled_graph_loss: 0.0347
Epoch 35/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2096 - accuracy: 0.9519 - scaled_graph_loss: 0.0344
Epoch 36/100
17/17 [==============================] - 0s 11ms/step - loss: 0.2035 - accuracy: 0.9517 - scaled_graph_loss: 0.0350
Epoch 37/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1795 - accuracy: 0.9619 - scaled_graph_loss: 0.0330
Epoch 38/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1818 - accuracy: 0.9603 - scaled_graph_loss: 0.0346
Epoch 39/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1778 - accuracy: 0.9596 - scaled_graph_loss: 0.0340
Epoch 40/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1788 - accuracy: 0.9597 - scaled_graph_loss: 0.0348
Epoch 41/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1604 - accuracy: 0.9699 - scaled_graph_loss: 0.0332
Epoch 42/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1692 - accuracy: 0.9682 - scaled_graph_loss: 0.0357
Epoch 43/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1506 - accuracy: 0.9744 - scaled_graph_loss: 0.0342
Epoch 44/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1763 - accuracy: 0.9628 - scaled_graph_loss: 0.0352
Epoch 45/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1721 - accuracy: 0.9657 - scaled_graph_loss: 0.0354
Epoch 46/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1386 - accuracy: 0.9726 - scaled_graph_loss: 0.0325
Epoch 47/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1458 - accuracy: 0.9669 - scaled_graph_loss: 0.0332
Epoch 48/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1211 - accuracy: 0.9815 - scaled_graph_loss: 0.0334
Epoch 49/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1281 - accuracy: 0.9786 - scaled_graph_loss: 0.0326
Epoch 50/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1285 - accuracy: 0.9814 - scaled_graph_loss: 0.0343
Epoch 51/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1317 - accuracy: 0.9748 - scaled_graph_loss: 0.0355
Epoch 52/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1420 - accuracy: 0.9706 - scaled_graph_loss: 0.0343
Epoch 53/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1395 - accuracy: 0.9715 - scaled_graph_loss: 0.0338
Epoch 54/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1260 - accuracy: 0.9758 - scaled_graph_loss: 0.0350
Epoch 55/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1261 - accuracy: 0.9778 - scaled_graph_loss: 0.0321
Epoch 56/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1205 - accuracy: 0.9793 - scaled_graph_loss: 0.0341
Epoch 57/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1170 - accuracy: 0.9814 - scaled_graph_loss: 0.0337
Epoch 58/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1223 - accuracy: 0.9715 - scaled_graph_loss: 0.0338
Epoch 59/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1181 - accuracy: 0.9737 - scaled_graph_loss: 0.0332
Epoch 60/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1104 - accuracy: 0.9827 - scaled_graph_loss: 0.0341
Epoch 61/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0919 - accuracy: 0.9840 - scaled_graph_loss: 0.0339
Epoch 62/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0999 - accuracy: 0.9838 - scaled_graph_loss: 0.0331
Epoch 63/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1078 - accuracy: 0.9833 - scaled_graph_loss: 0.0339
Epoch 64/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0994 - accuracy: 0.9854 - scaled_graph_loss: 0.0324
Epoch 65/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1016 - accuracy: 0.9820 - scaled_graph_loss: 0.0355
Epoch 66/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0962 - accuracy: 0.9859 - scaled_graph_loss: 0.0327
Epoch 67/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0999 - accuracy: 0.9846 - scaled_graph_loss: 0.0345
Epoch 68/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1032 - accuracy: 0.9823 - scaled_graph_loss: 0.0333
Epoch 69/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1035 - accuracy: 0.9828 - scaled_graph_loss: 0.0349
Epoch 70/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1052 - accuracy: 0.9828 - scaled_graph_loss: 0.0344
Epoch 71/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0921 - accuracy: 0.9874 - scaled_graph_loss: 0.0329
Epoch 72/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0907 - accuracy: 0.9867 - scaled_graph_loss: 0.0344
Epoch 73/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0982 - accuracy: 0.9851 - scaled_graph_loss: 0.0344
Epoch 74/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0892 - accuracy: 0.9817 - scaled_graph_loss: 0.0319
Epoch 75/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0966 - accuracy: 0.9835 - scaled_graph_loss: 0.0345
Epoch 76/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0888 - accuracy: 0.9888 - scaled_graph_loss: 0.0339
Epoch 77/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1019 - accuracy: 0.9800 - scaled_graph_loss: 0.0330
Epoch 78/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0931 - accuracy: 0.9806 - scaled_graph_loss: 0.0334
Epoch 79/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0862 - accuracy: 0.9848 - scaled_graph_loss: 0.0351
Epoch 80/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0852 - accuracy: 0.9870 - scaled_graph_loss: 0.0321
Epoch 81/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0927 - accuracy: 0.9837 - scaled_graph_loss: 0.0344
Epoch 82/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0954 - accuracy: 0.9826 - scaled_graph_loss: 0.0370
Epoch 83/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0810 - accuracy: 0.9904 - scaled_graph_loss: 0.0333
Epoch 84/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0827 - accuracy: 0.9874 - scaled_graph_loss: 0.0304
Epoch 85/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0898 - accuracy: 0.9854 - scaled_graph_loss: 0.0330
Epoch 86/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0829 - accuracy: 0.9864 - scaled_graph_loss: 0.0332
Epoch 87/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0784 - accuracy: 0.9893 - scaled_graph_loss: 0.0336
Epoch 88/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0811 - accuracy: 0.9876 - scaled_graph_loss: 0.0321
Epoch 89/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0811 - accuracy: 0.9887 - scaled_graph_loss: 0.0327
Epoch 90/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0813 - accuracy: 0.9856 - scaled_graph_loss: 0.0342
Epoch 91/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0765 - accuracy: 0.9896 - scaled_graph_loss: 0.0333
Epoch 92/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0888 - accuracy: 0.9814 - scaled_graph_loss: 0.0342
Epoch 93/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0796 - accuracy: 0.9843 - scaled_graph_loss: 0.0329
Epoch 94/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0784 - accuracy: 0.9859 - scaled_graph_loss: 0.0333
Epoch 95/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0842 - accuracy: 0.9882 - scaled_graph_loss: 0.0332
Epoch 96/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0739 - accuracy: 0.9920 - scaled_graph_loss: 0.0337
Epoch 97/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0810 - accuracy: 0.9857 - scaled_graph_loss: 0.0347
Epoch 98/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0856 - accuracy: 0.9871 - scaled_graph_loss: 0.0356
Epoch 99/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0810 - accuracy: 0.9879 - scaled_graph_loss: 0.0305
Epoch 100/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0735 - accuracy: 0.9882 - scaled_graph_loss: 0.0339
<tensorflow.python.keras.callbacks.History at 0x7fc8206484a8>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 8ms/step - loss: 1.3513 - accuracy: 0.7946


Eval accuracy for  MLP + graph regularization :  0.8010849952697754
Eval loss for  MLP + graph regularization :  1.349246859550476

The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model).

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
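
As a pointer, graph building in that setting can be done with the nsl.tools.build_graph utility. The sketch below uses hypothetical file paths and assumes sample embeddings were written as tf.train.Example records:

import neural_structured_learning as nsl

# Connect pairs of samples whose embeddings have cosine similarity >= 0.8.
nsl.tools.build_graph(['/tmp/embeddings.tfr'],
                      '/tmp/graph.tsv',
                      similarity_threshold=0.8)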

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.