Graph regularization for document classification using natural graphs


Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
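To make the objective concrete, here is a minimal sketch of the shape of such a loss: a supervised term plus a weighted penalty on the distance between a sample's output and each of its neighbors' outputs. This is a simplification for illustration, not the NSL implementation, and all names below are made up:

import tensorflow as tf

def graph_regularized_loss(model, x, y, nbr_xs, nbr_weights, multiplier=0.1):
  """Sketch of a graph-regularized objective (not the NSL implementation).

  `nbr_xs` is a list of neighbor-feature batches and `nbr_weights` holds the
  matching edge weights; an L2 distance between logits plays the role of the
  pairwise regularizer.
  """
  logits = model(x)
  supervised_loss = tf.keras.losses.sparse_categorical_crossentropy(
      y, logits, from_logits=True)
  graph_loss = 0.0
  for nbr_x, w in zip(nbr_xs, nbr_weights):
    # Penalize predictions that deviate from those of graph neighbors.
    graph_loss += w * tf.reduce_sum(tf.square(logits - model(nbr_x)), axis=-1)
  return tf.reduce_mean(supervised_loss + multiplier * graph_loss)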

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows (a compact sketch of steps 2-4 appears right after this list):

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as a regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
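
Here is that compact sketch of steps 2-4 (the full, runnable versions appear later in this tutorial; the tiny placeholder model is illustrative, and `train_dataset` is assumed to already contain the merged neighbor features from step 1):

import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: a placeholder base model.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1433,), name='words'),
    tf.keras.layers.Dense(7),
])
# Step 3: wrap the base model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=1, multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)
# Step 4: compile, train, and evaluate like any Keras model.
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# graph_reg_model.fit(train_dataset, epochs=10)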

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.8.0-rc0
Eager mode:  True
GPU is NOT AVAILABLE
2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

The Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification: the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

The graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
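
A minimal sketch of this edge mirroring is shown below; the assumed file layout (one tab-separated "cited-paper citing-paper" pair per line) matches the cora.cites file downloaded in a later section:

# Mirror every citation edge so that the graph becomes undirected.
edges = set()
with open('/tmp/cora/cora.cites') as f:
  for line in f:
    cited, citing = line.split()
    edges.add((citing, cited))
    edges.add((cited, citing))  # the mirrored (reverse) edge
print('Directed edges after mirroring:', len(edges))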

The features

Each paper in the input effectively contains 2 features (a small sketch of the 'words' representation follows this list):

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.

  2. Label: A single integer representing the class ID (category) of the paper.
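
As a small sketch of the 'words' feature described above (only the vocabulary size matches Cora; the word indices are made up):

import numpy as np

vocab_size = 1433
word_ids_in_paper = [5, 42, 1001]  # hypothetical vocabulary indices
words_feature = np.zeros(vocab_size, dtype=np.int64)
words_feature[word_ids_in_paper] = 1  # multi-hot: 1 where the word occurs
print(int(words_feature.sum()))  # 3 words present in this paper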

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.

!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.

Global variables

File paths to the train and test data, based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary; all instances in the input have a dense multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: The number of units in each fully connected layer of our neural network.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data were created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects -- one for training and one for testing.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features, based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2
 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2
 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0
 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework -- sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 words (InputLayer)          [(None, 1433)]            0         
                                                                 
 lambda (Lambda)             (None, 1433)              0         
                                                                 
 dense (Dense)               (None, 50)                71700     
                                                                 
 dropout (Dropout)           (None, 50)                0         
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 7)                 357       
                                                                 
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

Train the base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847
<keras.callbacks.History at 0x7f6f9d5eacd0>

Evaluate the base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939


Eval accuracy for  Base MLP model :  0.7938517332077026
Eval loss for  Base MLP model :  1.4192423820495605

Train the MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043
<keras.callbacks.History at 0x7f70041be650>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957


Eval accuracy for  MLP + graph regularization :  0.7956600189208984
Eval loss for  MLP + graph regularization :  0.8883611559867859

The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model), although the exact gap varies from run to run; in the run shown above it is roughly 0.2%.

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing a graph based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
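
For reference, here is a minimal sketch of that graph-building step using the NSL tools API; the file paths are hypothetical, and the embeddings are expected as tf.train.Example records in TFRecord format:

import neural_structured_learning as nsl

# Build a similarity graph from precomputed sample embeddings, then join each
# labeled sample with its neighbors to produce graph-augmented training data.
nsl.tools.build_graph(['/tmp/my_embeddings.tfr'], '/tmp/my_graph.tsv',
                      similarity_threshold=0.8)
nsl.tools.pack_nbrs('/tmp/my_train_data.tfr', '', '/tmp/my_graph.tsv',
                    '/tmp/my_nsl_train_data.tfr', add_undirected_edges=True,
                    max_nbrs=3)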

We encourage users to experiment further by varying the amount of supervision as well as by trying different neural architectures for graph regularization.