
Graph regularization for document classification using natural graphs


Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.

Setup

  1. Install TensorFlow 2.x to create an interactive development environment with eager execution.
  2. Install the Neural Structured Learning package.
!pip install --quiet tensorflow-gpu==2.0.0-rc0
!pip install --quiet neural-structured-learning

Dependencies and imports

from __future__ import absolute_import, division, print_function, unicode_literals

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
Version:  2.0.0-rc0
Eager mode:  True
GPU is NOT AVAILABLE

Cora dataset

The Cora dataset is a citation graph in which nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification: the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of the graph: if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a symmetric relationship.
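
As a minimal pure-Python sketch (function and edge names here are hypothetical; the actual conversion is done by the preprocessing script later in this tutorial), making a directed edge list bidirectional looks like this:

```python
# Sketch: convert a directed citation edge list to an undirected one by
# inserting each edge in both directions.
def make_bidirectional(edges):
  """Returns the set of edges with every edge present in both directions."""
  undirected = set()
  for src, dst in edges:
    undirected.add((src, dst))
    undirected.add((dst, src))
  return undirected

citations = [('paper_A', 'paper_B'), ('paper_B', 'paper_C')]
print(sorted(make_bidirectional(citations)))
```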

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0 or 1, indicating whether word 'i' in the vocabulary exists in the given paper.

  2. Label: A single integer representing the class ID (category) of the paper.
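
Each line of cora.content encodes these features as tab-separated fields: a paper ID, the 1433 word indicators, and a label string. A hypothetical parsing helper (for illustration only; the tutorial's actual preprocessing is done by preprocess_cora_dataset.py) might look like this:

```python
# Sketch: parse one tab-separated line of cora.content:
# <paper_id> <word_0> ... <word_N-1> <label>
def parse_cora_line(line, vocab_size=1433):
  fields = line.strip().split('\t')
  paper_id = fields[0]
  words = [int(v) for v in fields[1:1 + vocab_size]]  # multi-hot bag of words
  label = fields[1 + vocab_size]
  return paper_id, words, label

# Toy example with a vocabulary of only 5 words.
print(parse_cora_line('31336\t0\t1\t0\t0\t1\tNeural_Networks', vocab_size=5))
```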

Download the Cora dataset

!wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
!tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.content
cora/cora.cites

Convert the Cora data to the NSL format

To preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2019-09-18 12:10:11--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7327 (7.2K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]   7.16K  --.-KB/s    in 0s      

2019-09-18 12:10:11 (107 MB/s) - ‘preprocess_cora_dataset.py’ saved [7327/7327]

Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.85 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.06 minutes.

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
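
With these constants, the feature and weight keys for the i-th neighbor are formed as follows (this mirrors the key construction used in parse_example later in this notebook):

```python
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

def neighbor_keys(i):
  """Returns the feature key and weight key for the i-th neighbor."""
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
  return nbr_feature_key, nbr_weight_key

print(neighbor_keys(0))  # ('NL_nbr_0_words', 'NL_nbr_0_weight')
```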

Hyperparameters

We will use an instance of HParams to include various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary. Every instance in the input has a dense, multi-hot bag-of-words representation of this length: a value of 1 at position 'i' indicates that word 'i' in the vocabulary is present in the input, and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization.

  • num_fc_units: A list specifying the number of units in each fully connected layer of our neural network.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before evaluation is considered complete. If set to None, all instances in the test set are evaluated.
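
To make the roles of distance_type and graph_regularization_multiplier concrete, here is a pure-Python sketch of a graph regularization term under squared L2 distance. NSL computes this internally via its graph regularization wrapper, and the exact normalization may differ; this is illustrative only.

```python
def graph_reg_term(sample_emb, neighbor_embs, neighbor_weights, multiplier):
  """Weighted sum of squared-L2 distances to neighbors, scaled by multiplier."""
  total = 0.0
  for nbr_emb, weight in zip(neighbor_embs, neighbor_weights):
    sq_l2 = sum((a - b) ** 2 for a, b in zip(sample_emb, nbr_emb))
    total += weight * sq_l2
  return multiplier * total

# One neighbor at squared distance 2, weight 1.0, multiplier 0.1 -> 0.2
print(graph_reg_term([1.0, 1.0], [[0.0, 0.0]], [1.0], multiplier=0.1))
```

Note that a neighbor with weight 0.0 contributes nothing to this term, which is how dummy neighbors are discounted during training.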

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data were created by the 'preprocess_cora_dataset.py' script. We will load them into two tf.data.Dataset objects -- one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features, based on hparams.num_neighbors. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for the non-existent neighbor features.
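
The padding idea can be sketched in pure Python (hypothetical helper; in the actual pipeline, parse_example's default feature values play this role): missing neighbors get zero features and weight 0.0 so they do not affect the graph loss.

```python
def pad_neighbors(neighbors, num_neighbors, feature_len):
  """Pads a (features, weight) list with zero features and weight 0.0."""
  padded = list(neighbors)
  while len(padded) < num_neighbors:
    padded.append(([0] * feature_len, 0.0))
  return padded[:num_neighbors]

# A sample with one real neighbor, padded to two neighbors of length-3 features.
print(pad_neighbors([([1, 0, 1], 1.0)], num_neighbors=2, feature_len=3))
```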

def parse_example(example_proto):
  """Extracts relevant fields from the `example_proto`.

  Args:
    example_proto: An instance of `tf.train.Example`.

  Returns:
    A pair whose first value is a dictionary containing relevant features
    and whose second value contains the ground truth labels.
  """
  # The 'words' feature is a multi-hot, bag-of-words representation of the
  # original raw text. A default value is required for examples that don't
  # have the feature.
  feature_spec = {
      'words':
          tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                tf.int64,
                                default_value=tf.constant(
                                    0,
                                    dtype=tf.int64,
                                    shape=[HPARAMS.max_seq_length])),
      'label':
          tf.io.FixedLenFeature((), tf.int64, default_value=-1),
  }
  # We also extract corresponding neighbor features in a similar manner to
  # the features above.
  for i in range(HPARAMS.num_neighbors):
    nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
    nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
    feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
        [HPARAMS.max_seq_length],
        tf.int64,
        default_value=tf.constant(
            0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

    # We assign a default value of 0.0 for the neighbor weight so that
    # graph regularization is done on samples based on their exact number
    # of neighbors. In other words, non-existent neighbors are discounted.
    feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
        [1], tf.float32, default_value=tf.constant([0.0]))

  features = tf.io.parse_single_example(example_proto, feature_spec)

  labels = features.pop('label')
  return features, labels


def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """
  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
WARNING:tensorflow:Entity <function parse_example at 0x7f05ad12c488> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Num'

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['words', 'NL_nbr_0_weight', 'NL_nbr_0_words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 1 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 4 5 6 5 3 3 2 5 3 1 3 6 2 5 5 2 6 3 2 2 2 6 0 2 2 2 2 3 2 2 3 6 4 2 4 1
 2 6 4 6 1 6 2 3 1 1 2 3 6 2 6 2 2 6 5 0 0 6 6 4 3 6 6 2 2 0 5 6 2 4 2 2 4
 3 3 0 5 2 0 3 0 6 3 2 3 6 1 2 6 3 2 0 1 2 3 3 2 6 0 2 0 1 1 3 2 2 2 2 3 2
 3 4 1 6 1 1 1 2 6 1 2 3 3 1 1 2 3], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['words', 'NL_nbr_0_weight', 'NL_nbr_0_words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[0 1 1 2 0 2 1 2 1 3 2 6 2 3 2 3 2 3 6 6 3 2 1 3 2 1 2 2 1 2 2 2 0 1 6 2 1
 2 1 4 0 2 0 0 2 2 1 2 4 1 0 3 2 1 6 0 1 0 0 2 3 2 6 3 5 3 0 5 5 2 6 2 5 0
 6 5 0 6 3 2 4 2 5 0 6 2 4 0 3 2 2 0 3 2 6 2 2 2 6 6 3 1 6 6 6 2 2 0 0 5 3
 1 2 2 2 4 2 2 4 2 2 4 1 0 2 6 1 5], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem: a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.keras framework -- sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already multi-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
WARNING:tensorflow:Entity <function Function._initialize_uninitialized_variables.<locals>.initialize_variables at 0x7f05247192f0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Num'
17/17 [==============================] - 1s 42ms/step - loss: 1.9295 - accuracy: 0.2097
Epoch 2/100
17/17 [==============================] - 0s 16ms/step - loss: 1.8387 - accuracy: 0.3012
Epoch 3/100
17/17 [==============================] - 0s 16ms/step - loss: 1.7444 - accuracy: 0.3253
Epoch 4/100
17/17 [==============================] - 0s 15ms/step - loss: 1.6519 - accuracy: 0.3610
Epoch 5/100
17/17 [==============================] - 0s 17ms/step - loss: 1.5234 - accuracy: 0.4232
Epoch 6/100
17/17 [==============================] - 0s 16ms/step - loss: 1.4004 - accuracy: 0.4984
Epoch 7/100
17/17 [==============================] - 0s 16ms/step - loss: 1.2305 - accuracy: 0.5935
Epoch 8/100
17/17 [==============================] - 0s 16ms/step - loss: 1.1191 - accuracy: 0.6274
Epoch 9/100
17/17 [==============================] - 0s 17ms/step - loss: 0.9855 - accuracy: 0.6738
Epoch 10/100
17/17 [==============================] - 0s 17ms/step - loss: 0.9324 - accuracy: 0.7039
Epoch 11/100
17/17 [==============================] - 0s 21ms/step - loss: 0.7962 - accuracy: 0.7425
Epoch 12/100
17/17 [==============================] - 0s 20ms/step - loss: 0.7008 - accuracy: 0.7828
Epoch 13/100
17/17 [==============================] - 0s 21ms/step - loss: 0.6953 - accuracy: 0.7879
Epoch 14/100
17/17 [==============================] - 0s 20ms/step - loss: 0.6075 - accuracy: 0.8167
Epoch 15/100
17/17 [==============================] - 0s 21ms/step - loss: 0.5735 - accuracy: 0.8283
Epoch 16/100
17/17 [==============================] - 0s 19ms/step - loss: 0.5033 - accuracy: 0.8515
Epoch 17/100
17/17 [==============================] - 0s 21ms/step - loss: 0.4712 - accuracy: 0.8585
Epoch 18/100
17/17 [==============================] - 0s 19ms/step - loss: 0.4404 - accuracy: 0.8691
Epoch 19/100
17/17 [==============================] - 0s 20ms/step - loss: 0.3952 - accuracy: 0.8900
Epoch 20/100
17/17 [==============================] - 0s 21ms/step - loss: 0.3529 - accuracy: 0.8956
Epoch 21/100
17/17 [==============================] - 0s 20ms/step - loss: 0.3492 - accuracy: 0.8961
Epoch 22/100
17/17 [==============================] - 0s 22ms/step - loss: 0.3228 - accuracy: 0.9058
Epoch 23/100
17/17 [==============================] - 0s 23ms/step - loss: 0.3092 - accuracy: 0.9142
Epoch 24/100
17/17 [==============================] - 0s 19ms/step - loss: 0.2905 - accuracy: 0.9114
Epoch 25/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2714 - accuracy: 0.9239
Epoch 26/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2651 - accuracy: 0.9267
Epoch 27/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2541 - accuracy: 0.9258
Epoch 28/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2154 - accuracy: 0.9397
Epoch 29/100
17/17 [==============================] - 0s 20ms/step - loss: 0.2116 - accuracy: 0.9415
Epoch 30/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2220 - accuracy: 0.9332
Epoch 31/100
17/17 [==============================] - 0s 20ms/step - loss: 0.1968 - accuracy: 0.9434
Epoch 32/100
17/17 [==============================] - 0s 21ms/step - loss: 0.2048 - accuracy: 0.9397
Epoch 33/100
17/17 [==============================] - 0s 21ms/step - loss: 0.1813 - accuracy: 0.9513
Epoch 34/100
17/17 [==============================] - 0s 21ms/step - loss: 0.1678 - accuracy: 0.9527
Epoch 35/100
17/17 [==============================] - 0s 20ms/step - loss: 0.1647 - accuracy: 0.9531
Epoch 36/100
17/17 [==============================] - 0s 21ms/step - loss: 0.1623 - accuracy: 0.9564
Epoch 37/100
17/17 [==============================] - 0s 20ms/step - loss: 0.1554 - accuracy: 0.9541
Epoch 38/100
17/17 [==============================] - 0s 21ms/step - loss: 0.1515 - accuracy: 0.9643
Epoch 39/100
17/17 [==============================] - 0s 21ms/step - loss: 0.1322 - accuracy: 0.9647
Epoch 40/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1413 - accuracy: 0.9624
Epoch 41/100
17/17 [==============================] - 0s 16ms/step - loss: 0.1500 - accuracy: 0.9564
Epoch 42/100
17/17 [==============================] - 0s 15ms/step - loss: 0.1357 - accuracy: 0.9629
Epoch 43/100
17/17 [==============================] - 0s 16ms/step - loss: 0.1267 - accuracy: 0.9684
Epoch 44/100
17/17 [==============================] - 0s 15ms/step - loss: 0.1253 - accuracy: 0.9652
Epoch 45/100
17/17 [==============================] - 0s 16ms/step - loss: 0.1330 - accuracy: 0.9619
Epoch 46/100
17/17 [==============================] - 0s 15ms/step - loss: 0.0981 - accuracy: 0.9754
Epoch 47/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1124 - accuracy: 0.9722
Epoch 48/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1035 - accuracy: 0.9731
Epoch 49/100
17/17 [==============================] - 0s 15ms/step - loss: 0.1026 - accuracy: 0.9754
Epoch 50/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1014 - accuracy: 0.9749
Epoch 51/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0967 - accuracy: 0.9740
Epoch 52/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1048 - accuracy: 0.9735
Epoch 53/100
17/17 [==============================] - 0s 19ms/step - loss: 0.1081 - accuracy: 0.9666
Epoch 54/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1056 - accuracy: 0.9717
Epoch 55/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1021 - accuracy: 0.9712
Epoch 56/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0900 - accuracy: 0.9777
Epoch 57/100
17/17 [==============================] - 0s 21ms/step - loss: 0.0870 - accuracy: 0.9740
Epoch 58/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0857 - accuracy: 0.9791
Epoch 59/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0886 - accuracy: 0.9754
Epoch 60/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0869 - accuracy: 0.9722
Epoch 61/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0769 - accuracy: 0.9787
Epoch 62/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0661 - accuracy: 0.9819
Epoch 63/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0756 - accuracy: 0.9773
Epoch 64/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0641 - accuracy: 0.9833
Epoch 65/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0696 - accuracy: 0.9810
Epoch 66/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0716 - accuracy: 0.9810
Epoch 67/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0673 - accuracy: 0.9824
Epoch 68/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0723 - accuracy: 0.9819
Epoch 69/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0675 - accuracy: 0.9814
Epoch 70/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0790 - accuracy: 0.9782
Epoch 71/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0580 - accuracy: 0.9879
Epoch 72/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0610 - accuracy: 0.9852
Epoch 73/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0685 - accuracy: 0.9796
Epoch 74/100
17/17 [==============================] - 0s 19ms/step - loss: 0.0670 - accuracy: 0.9814
Epoch 75/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0592 - accuracy: 0.9800
Epoch 76/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0694 - accuracy: 0.9814
Epoch 77/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0612 - accuracy: 0.9819
Epoch 78/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0618 - accuracy: 0.9824
Epoch 79/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0570 - accuracy: 0.9847
Epoch 80/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0633 - accuracy: 0.9814
Epoch 81/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0527 - accuracy: 0.9875
Epoch 82/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0624 - accuracy: 0.9833
Epoch 83/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0516 - accuracy: 0.9875
Epoch 84/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0503 - accuracy: 0.9889
Epoch 85/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0533 - accuracy: 0.9865
Epoch 86/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0534 - accuracy: 0.9875
Epoch 87/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0493 - accuracy: 0.9884
Epoch 88/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0642 - accuracy: 0.9782
Epoch 89/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0440 - accuracy: 0.9875
Epoch 90/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0551 - accuracy: 0.9828
Epoch 91/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0480 - accuracy: 0.9856
Epoch 92/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0516 - accuracy: 0.9865
Epoch 93/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0535 - accuracy: 0.9828
Epoch 94/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0501 - accuracy: 0.9852
Epoch 95/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0435 - accuracy: 0.9852
Epoch 96/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0473 - accuracy: 0.9842
Epoch 97/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0490 - accuracy: 0.9865
Epoch 98/100
17/17 [==============================] - 0s 15ms/step - loss: 0.0490 - accuracy: 0.9870
Epoch 99/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0465 - accuracy: 0.9879
Epoch 100/100
17/17 [==============================] - 0s 15ms/step - loss: 0.0447 - accuracy: 0.9893

<tensorflow.python.keras.callbacks.History at 0x7f05247f37b8>

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 31ms/step - loss: 1.3591 - accuracy: 0.7902


Eval accuracy for  Base MLP model :  0.7902351
Eval loss for  Base MLP model :  1.3590900421142578

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.keras.Model subclass, whose loss includes a graph regularization term.

To assess the incremental benefit of graph regularization, we create a new base model instance. This is because base_model has already been trained, and reusing that trained model to create a graph-regularized model would not be a fair comparison against base_model.
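Before wrapping the model, it may help to see what the graph regularization term computes conceptually. The sketch below is plain Python for illustration only, not NSL's actual implementation; the function names and example values are hypothetical. It penalizes disagreement between a sample's prediction and its neighbors' predictions, weighted by edge weight and scaled by a multiplier:

```python
# Conceptual sketch of a graph regularization term (illustrative only,
# not NSL's implementation).

def l2_distance(p, q):
  """Squared L2 distance between two probability vectors."""
  return sum((a - b) ** 2 for a, b in zip(p, q))

def graph_loss(pred, neighbor_preds, neighbor_weights, multiplier=0.1):
  """Weighted sum of distances from a prediction to its neighbors'."""
  return multiplier * sum(
      w * l2_distance(pred, npred)
      for npred, w in zip(neighbor_preds, neighbor_weights))

# Example: one sample's prediction with two equally weighted neighbors.
pred = [0.7, 0.2, 0.1]
neighbors = [[0.6, 0.3, 0.1], [0.8, 0.1, 0.1]]
weights = [1.0, 1.0]
loss = graph_loss(pred, neighbors, weights)
```

Minimizing this term alongside the supervised loss pushes connected nodes toward similar predictions, which is how unlabeled neighbors contribute to training.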

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
WARNING:tensorflow:Entity <bound method GraphRegularization.call of <neural_structured_learning.keras.graph_regularization.GraphRegularization object at 0x7f0514433dd8>> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Bad argument number for Name: 3, expecting 4
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/losses/losses_impl.py:121: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/100
WARNING:tensorflow:Entity <function Function._initialize_uninitialized_variables.<locals>.initialize_variables at 0x7f05a1df0f28> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Num'

/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/framework/indexed_slices.py:424: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

17/17 [==============================] - 2s 136ms/step - loss: 1.9318 - accuracy: 0.2111 - graph_loss: 0.0073
Epoch 2/100
17/17 [==============================] - 0s 18ms/step - loss: 1.8529 - accuracy: 0.3086 - graph_loss: 0.0105
Epoch 3/100
17/17 [==============================] - 0s 17ms/step - loss: 1.7521 - accuracy: 0.3304 - graph_loss: 0.0246
Epoch 4/100
17/17 [==============================] - 0s 19ms/step - loss: 1.6685 - accuracy: 0.3397 - graph_loss: 0.0441
Epoch 5/100
17/17 [==============================] - 0s 18ms/step - loss: 1.5470 - accuracy: 0.4079 - graph_loss: 0.0649
Epoch 6/100
17/17 [==============================] - 0s 18ms/step - loss: 1.4116 - accuracy: 0.5026 - graph_loss: 0.0955
Epoch 7/100
17/17 [==============================] - 0s 17ms/step - loss: 1.2673 - accuracy: 0.5694 - graph_loss: 0.1408
Epoch 8/100
17/17 [==============================] - 0s 18ms/step - loss: 1.1319 - accuracy: 0.6260 - graph_loss: 0.1818
Epoch 9/100
17/17 [==============================] - 0s 17ms/step - loss: 1.0060 - accuracy: 0.6756 - graph_loss: 0.2203
Epoch 10/100
17/17 [==============================] - 0s 17ms/step - loss: 0.9097 - accuracy: 0.7295 - graph_loss: 0.2442
Epoch 11/100
17/17 [==============================] - 0s 17ms/step - loss: 0.8206 - accuracy: 0.7420 - graph_loss: 0.2610
Epoch 12/100
17/17 [==============================] - 0s 17ms/step - loss: 0.7341 - accuracy: 0.7754 - graph_loss: 0.2893
Epoch 13/100
17/17 [==============================] - 0s 18ms/step - loss: 0.6703 - accuracy: 0.8056 - graph_loss: 0.2837
Epoch 14/100
17/17 [==============================] - 0s 18ms/step - loss: 0.6222 - accuracy: 0.8218 - graph_loss: 0.2945
Epoch 15/100
17/17 [==============================] - 0s 17ms/step - loss: 0.5343 - accuracy: 0.8520 - graph_loss: 0.3130
Epoch 16/100
17/17 [==============================] - 0s 18ms/step - loss: 0.5301 - accuracy: 0.8478 - graph_loss: 0.3121
Epoch 17/100
17/17 [==============================] - 0s 17ms/step - loss: 0.4815 - accuracy: 0.8668 - graph_loss: 0.3149
Epoch 18/100
17/17 [==============================] - 0s 18ms/step - loss: 0.4556 - accuracy: 0.8687 - graph_loss: 0.3178
Epoch 19/100
17/17 [==============================] - 0s 18ms/step - loss: 0.4389 - accuracy: 0.8766 - graph_loss: 0.3257
Epoch 20/100
17/17 [==============================] - 0s 18ms/step - loss: 0.4021 - accuracy: 0.8933 - graph_loss: 0.3222
Epoch 21/100
17/17 [==============================] - 0s 17ms/step - loss: 0.3662 - accuracy: 0.9039 - graph_loss: 0.3287
Epoch 22/100
17/17 [==============================] - 0s 17ms/step - loss: 0.3348 - accuracy: 0.9174 - graph_loss: 0.3266
Epoch 23/100
17/17 [==============================] - 0s 18ms/step - loss: 0.3626 - accuracy: 0.8979 - graph_loss: 0.3419
Epoch 24/100
17/17 [==============================] - 0s 16ms/step - loss: 0.3080 - accuracy: 0.9183 - graph_loss: 0.3367
Epoch 25/100
17/17 [==============================] - 0s 18ms/step - loss: 0.3172 - accuracy: 0.9211 - graph_loss: 0.3347
Epoch 26/100
17/17 [==============================] - 0s 17ms/step - loss: 0.2845 - accuracy: 0.9281 - graph_loss: 0.3436
Epoch 27/100
17/17 [==============================] - 0s 18ms/step - loss: 0.2662 - accuracy: 0.9350 - graph_loss: 0.3344
Epoch 28/100
17/17 [==============================] - 0s 18ms/step - loss: 0.2618 - accuracy: 0.9374 - graph_loss: 0.3302
Epoch 29/100
17/17 [==============================] - 0s 18ms/step - loss: 0.2576 - accuracy: 0.9341 - graph_loss: 0.3360
Epoch 30/100
17/17 [==============================] - 0s 17ms/step - loss: 0.2273 - accuracy: 0.9508 - graph_loss: 0.3423
Epoch 31/100
17/17 [==============================] - 0s 19ms/step - loss: 0.2400 - accuracy: 0.9429 - graph_loss: 0.3374
Epoch 32/100
17/17 [==============================] - 0s 18ms/step - loss: 0.2134 - accuracy: 0.9564 - graph_loss: 0.3388
Epoch 33/100
17/17 [==============================] - 0s 19ms/step - loss: 0.2058 - accuracy: 0.9513 - graph_loss: 0.3468
Epoch 34/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1973 - accuracy: 0.9531 - graph_loss: 0.3429
Epoch 35/100
17/17 [==============================] - 0s 17ms/step - loss: 0.2020 - accuracy: 0.9499 - graph_loss: 0.3417
Epoch 36/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1846 - accuracy: 0.9587 - graph_loss: 0.3403
Epoch 37/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1832 - accuracy: 0.9596 - graph_loss: 0.3369
Epoch 38/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1718 - accuracy: 0.9680 - graph_loss: 0.3407
Epoch 39/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1595 - accuracy: 0.9643 - graph_loss: 0.3384
Epoch 40/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1672 - accuracy: 0.9619 - graph_loss: 0.3420
Epoch 41/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1547 - accuracy: 0.9726 - graph_loss: 0.3454
Epoch 42/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1698 - accuracy: 0.9619 - graph_loss: 0.3447
Epoch 43/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1692 - accuracy: 0.9633 - graph_loss: 0.3497
Epoch 44/100
17/17 [==============================] - 0s 16ms/step - loss: 0.1549 - accuracy: 0.9680 - graph_loss: 0.3331
Epoch 45/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1351 - accuracy: 0.9754 - graph_loss: 0.3418
Epoch 46/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1410 - accuracy: 0.9749 - graph_loss: 0.3320
Epoch 47/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1406 - accuracy: 0.9698 - graph_loss: 0.3456
Epoch 48/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1449 - accuracy: 0.9722 - graph_loss: 0.3416
Epoch 49/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1361 - accuracy: 0.9708 - graph_loss: 0.3486
Epoch 50/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1435 - accuracy: 0.9731 - graph_loss: 0.3381
Epoch 51/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1275 - accuracy: 0.9782 - graph_loss: 0.3392
Epoch 52/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1291 - accuracy: 0.9763 - graph_loss: 0.3412
Epoch 53/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1310 - accuracy: 0.9754 - graph_loss: 0.3419
Epoch 54/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1301 - accuracy: 0.9763 - graph_loss: 0.3419
Epoch 55/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1182 - accuracy: 0.9787 - graph_loss: 0.3395
Epoch 56/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1160 - accuracy: 0.9805 - graph_loss: 0.3454
Epoch 57/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1228 - accuracy: 0.9740 - graph_loss: 0.3528
Epoch 58/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1186 - accuracy: 0.9773 - graph_loss: 0.3447
Epoch 59/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1154 - accuracy: 0.9759 - graph_loss: 0.3392
Epoch 60/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1231 - accuracy: 0.9740 - graph_loss: 0.3423
Epoch 61/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1104 - accuracy: 0.9791 - graph_loss: 0.3313
Epoch 62/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1131 - accuracy: 0.9800 - graph_loss: 0.3427
Epoch 63/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1076 - accuracy: 0.9791 - graph_loss: 0.3418
Epoch 64/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1063 - accuracy: 0.9814 - graph_loss: 0.3308
Epoch 65/100
17/17 [==============================] - 0s 19ms/step - loss: 0.1021 - accuracy: 0.9833 - graph_loss: 0.3389
Epoch 66/100
17/17 [==============================] - 0s 17ms/step - loss: 0.1056 - accuracy: 0.9828 - graph_loss: 0.3427
Epoch 67/100
17/17 [==============================] - 0s 18ms/step - loss: 0.1032 - accuracy: 0.9796 - graph_loss: 0.3464
Epoch 68/100
17/17 [==============================] - 0s 19ms/step - loss: 0.1021 - accuracy: 0.9819 - graph_loss: 0.3476
Epoch 69/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0966 - accuracy: 0.9856 - graph_loss: 0.3373
Epoch 70/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0996 - accuracy: 0.9805 - graph_loss: 0.3423
Epoch 71/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0962 - accuracy: 0.9824 - graph_loss: 0.3327
Epoch 72/100
17/17 [==============================] - 0s 19ms/step - loss: 0.1008 - accuracy: 0.9828 - graph_loss: 0.3440
Epoch 73/100
17/17 [==============================] - 0s 19ms/step - loss: 0.0801 - accuracy: 0.9907 - graph_loss: 0.3403
Epoch 74/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0961 - accuracy: 0.9838 - graph_loss: 0.3356
Epoch 75/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0953 - accuracy: 0.9838 - graph_loss: 0.3437
Epoch 76/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0969 - accuracy: 0.9791 - graph_loss: 0.3370
Epoch 77/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0948 - accuracy: 0.9833 - graph_loss: 0.3434
Epoch 78/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0865 - accuracy: 0.9879 - graph_loss: 0.3265
Epoch 79/100
17/17 [==============================] - 0s 19ms/step - loss: 0.0988 - accuracy: 0.9856 - graph_loss: 0.3388
Epoch 80/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0933 - accuracy: 0.9852 - graph_loss: 0.3402
Epoch 81/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0896 - accuracy: 0.9870 - graph_loss: 0.3372
Epoch 82/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0929 - accuracy: 0.9838 - graph_loss: 0.3407
Epoch 83/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0870 - accuracy: 0.9852 - graph_loss: 0.3356
Epoch 84/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0916 - accuracy: 0.9852 - graph_loss: 0.3365
Epoch 85/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0841 - accuracy: 0.9870 - graph_loss: 0.3402
Epoch 86/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0862 - accuracy: 0.9884 - graph_loss: 0.3396
Epoch 87/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0751 - accuracy: 0.9889 - graph_loss: 0.3385
Epoch 88/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0835 - accuracy: 0.9861 - graph_loss: 0.3394
Epoch 89/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0824 - accuracy: 0.9898 - graph_loss: 0.3314
Epoch 90/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0779 - accuracy: 0.9861 - graph_loss: 0.3343
Epoch 91/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0762 - accuracy: 0.9907 - graph_loss: 0.3421
Epoch 92/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0868 - accuracy: 0.9847 - graph_loss: 0.3453
Epoch 93/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0778 - accuracy: 0.9879 - graph_loss: 0.3433
Epoch 94/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0886 - accuracy: 0.9842 - graph_loss: 0.3347
Epoch 95/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0789 - accuracy: 0.9903 - graph_loss: 0.3340
Epoch 96/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0720 - accuracy: 0.9916 - graph_loss: 0.3406
Epoch 97/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0806 - accuracy: 0.9884 - graph_loss: 0.3388
Epoch 98/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0781 - accuracy: 0.9893 - graph_loss: 0.3353
Epoch 99/100
17/17 [==============================] - 0s 18ms/step - loss: 0.0785 - accuracy: 0.9893 - graph_loss: 0.3307
Epoch 100/100
17/17 [==============================] - 0s 17ms/step - loss: 0.0770 - accuracy: 0.9893 - graph_loss: 0.3405

<tensorflow.python.keras.callbacks.History at 0x7f05143aa940>

Evaluate MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 76ms/step - loss: 1.2379 - accuracy: 0.8029 - graph_loss: 0.0000e+00


Eval accuracy for  MLP + graph regularization :  0.8028933
Eval loss for  MLP + graph regularization :  1.2379332304000854
Eval graph loss for  MLP + graph regularization :  0.0

The graph-regularized model's accuracy (about 80.3%) is roughly 1% higher than that of the base model (base_model, about 79.0%).
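The gain can be read off directly from the two evaluation results above. A quick check, using the accuracy values printed for this particular run:

```python
# Compare the two models' eval accuracies (values copied from this run's
# output above; they will vary between runs).
base_acc = 0.7902351       # Base MLP model
graph_reg_acc = 0.8028933  # MLP + graph regularization

gain = graph_reg_acc - base_acc
print(f"Absolute accuracy gain: {gain * 100:.2f} percentage points")
```

In code, the same comparison can be made programmatically from the two eval_results dictionaries rather than hard-coded numbers.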

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. An advanced tutorial covers synthesizing a graph from sample embeddings before training a graph-regularized neural network; that approach is useful when the input does not contain an explicit graph.

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.