Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
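As a rough illustration of the objective in step 3, the following pure-Python sketch computes a graph-regularized loss for a single sample. It assumes the regularizer is a squared L2 distance between the sample's logits and each neighbor's logits, weighted by the neighbor weight and scaled by a multiplier. The function names are hypothetical, and this is not NSL's implementation, which operates on batches of tensors.

```python
import math

def softmax_cross_entropy(logits, label):
    """Sparse softmax cross-entropy for one example; `label` is a class index."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(
        sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]

def squared_l2(a, b):
    """Squared Euclidean distance, summed over the vector elements."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def graph_regularized_loss(logits, label, nbr_logits, nbr_weights, multiplier):
    """Return (total_loss, graph_loss) for one sample and its neighbors."""
    supervised = softmax_cross_entropy(logits, label)
    # Each neighbor's distance is scaled by its edge weight, so a neighbor
    # with weight 0.0 contributes nothing to the graph term.
    graph_loss = sum(w * squared_l2(logits, n)
                     for n, w in zip(nbr_logits, nbr_weights))
    return supervised + multiplier * graph_loss, graph_loss

total, graph = graph_regularized_loss(
    logits=[2.0, 0.0, 0.0], label=0,
    nbr_logits=[[1.0, 0.0, 0.0]], nbr_weights=[1.0], multiplier=0.1)
print(graph)  # 1.0
```

Setting the multiplier to 0 recovers the plain supervised loss, which is what the base model in this tutorial optimizes.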

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.8.0-rc0
Eager mode:  True
GPU is NOT AVAILABLE
2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a symmetric property.
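Symmetrizing the citation edges can be sketched in a few lines of Python. This is a hypothetical helper, not code from the NSL preprocessing script; it only illustrates that every directed edge (A, B) yields a neighbor relationship in both directions.

```python
from collections import defaultdict

def make_undirected(edges):
    """Build an adjacency map where each directed edge (src, dst) is stored both ways."""
    adjacency = defaultdict(set)
    for src, dst in edges:
        adjacency[src].add(dst)
        adjacency[dst].add(src)
    return dict(adjacency)

# Hypothetical directed citation edges between papers "A", "B" and "C".
adjacency = make_undirected([("A", "B"), ("C", "B")])
print(sorted(adjacency["B"]))  # ['A', 'C']
```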

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.

  2. Label: A single integer representing the class ID (category) of the paper.
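The multi-hot encoding of the 'words' feature can be illustrated with a small, hypothetical helper that takes the vocabulary indices of the words occurring in a paper. It assumes 0-based indices into the 1433-word vocabulary; repeated words still produce a single 1, since the feature records presence rather than counts.

```python
def multi_hot(word_indices, vocab_size=1433):
    """Dense multi-hot vector: 1 at each vocabulary index present, 0 elsewhere."""
    vector = [0] * vocab_size
    for i in word_indices:
        if not 0 <= i < vocab_size:
            raise ValueError(f"word index {i} outside vocabulary of size {vocab_size}")
        vector[i] = 1  # presence only: a repeated word still maps to 1
    return vector

v = multi_hot([3, 17, 3])
print(len(v), sum(v))  # 1433 2
```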

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.

!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.
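To make step 1 of the preprocessing more concrete, here is a simplified, pure-Python sketch of how a seed node's features might be merged with its neighbors' features under the NL_nbr_ naming scheme used in this notebook (for example NL_nbr_0_words and NL_nbr_0_weight). The helper name, the rule for choosing which neighbors to keep (lowest IDs first) and the fixed weight of 1.0 are assumptions for illustration; the real script may differ in these details, and it writes tf.train.Example protos rather than dicts.

```python
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

def join_neighbors(seed_id, features, adjacency, max_nbrs):
    """Hypothetical: merge a seed's features with the 'words' of up to max_nbrs neighbors."""
    merged = dict(features[seed_id])  # keep the seed's own 'words' and 'label'
    # Assumption: keep the neighbors with the smallest IDs.
    neighbors = sorted(adjacency.get(seed_id, ()))[:max_nbrs]
    for i, nbr_id in enumerate(neighbors):
        merged['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')] = features[nbr_id]['words']
        # Citations are unweighted here, so every existing neighbor gets weight 1.0.
        merged['{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)] = [1.0]
    return merged

features = {
    'A': {'words': [1, 0], 'label': 0},
    'B': {'words': [0, 1], 'label': 1},
}
print(join_neighbors('A', features, {'A': {'B'}}, max_nbrs=5))
```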

Global variables

The file paths to the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary, and all instances in the input have a dense, multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: The number of units in each fully connected layer of the neural network; the list has one entry per hidden layer.

  • train_epochs: The number of training epochs.

  • batch_size: The batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.
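Because tf.data.Dataset.batch keeps the final partial batch by default, the number of batches per pass is the ceiling of the example count divided by batch_size. The sketch below does this arithmetic for the counts reported in the preprocessing log, assuming the test split contains every one of the 2708 nodes that was not written as a training example.

```python
import math

def num_batches(num_examples, batch_size):
    """Batches per pass over the data when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)

TOTAL_NODES = 2708     # "Total graph nodes" in the preprocessing log
TRAIN_EXAMPLES = 2155  # merged training examples written by the script
# Assumption: the test split holds all remaining nodes.
TEST_EXAMPLES = TOTAL_NODES - TRAIN_EXAMPLES

print(num_batches(TRAIN_EXAMPLES, 128))  # 17
print(num_batches(TEST_EXAMPLES, 128))   # 5
```

With batch_size = 128 this gives 17 training batches and 5 test batches, consistent with the 17/17 and 5/5 progress bars printed during training and evaluation.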

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects: one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample but also, for the training data, the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
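The dummy values come from the default_value arguments of the neighbor tf.io.FixedLenFeature entries in parse_example: an all-zero 'words' vector and a neighbor weight of 0.0. The pure-Python sketch below emulates that padding for a single example; the helper name is hypothetical, and it only illustrates the effect, not how tf.io.parse_single_example works internally.

```python
def pad_neighbors(example, num_neighbors, vocab_size,
                  prefix='NL_nbr_', weight_suffix='_weight'):
    """Hypothetical: fill missing neighbor slots with zero 'words' and weight 0.0."""
    padded = dict(example)  # do not mutate the caller's example
    for i in range(num_neighbors):
        words_key = '{}{}_{}'.format(prefix, i, 'words')
        weight_key = '{}{}{}'.format(prefix, i, weight_suffix)
        # Existing neighbor features are kept; only absent keys get defaults.
        padded.setdefault(words_key, [0] * vocab_size)
        padded.setdefault(weight_key, [0.0])
    return padded

example = {'words': [1, 0, 1], 'NL_nbr_0_words': [0, 1, 1], 'NL_nbr_0_weight': [1.0]}
print(pad_neighbors(example, num_neighbors=2, vocab_size=3)['NL_nbr_1_weight'])  # [0.0]
```

Because the padded weight is 0.0, a non-existent neighbor is discounted from the graph regularization term, as the comment in parse_example notes.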

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2
 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2
 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0
 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

In order to demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already multi-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 words (InputLayer)          [(None, 1433)]            0         
                                                                 
 lambda (Lambda)             (None, 1433)              0         
                                                                 
 dense (Dense)               (None, 50)                71700     
                                                                 
 dropout (Dropout)           (None, 50)                0         
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 7)                 357       
                                                                 
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________
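The parameter counts in this summary follow directly from the layer sizes: each Dense layer has fan_in * units weights plus one bias per unit, while the Lambda and Dropout layers have no parameters. A small check in plain Python, using the values from HParams:

```python
def dense_params(fan_in, units):
    """Weights (fan_in x units) plus one bias per unit."""
    return fan_in * units + units

def mlp_param_counts(input_dim, hidden_units, num_classes):
    """Parameter count of each Dense layer in a simple MLP, in order."""
    counts = []
    fan_in = input_dim
    for units in list(hidden_units) + [num_classes]:
        counts.append(dense_params(fan_in, units))
        fan_in = units  # the next layer consumes this layer's output
    return counts

counts = mlp_param_counts(1433, [50, 50], 7)
print(counts, sum(counts))  # [71700, 2550, 357] 74607
```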

Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847
<keras.callbacks.History at 0x7f6f9d5eacd0>

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])

eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939


Eval accuracy for  Base MLP model :  0.7938517332077026
Eval loss for  Base MLP model :  1.4192423820495605

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for HPARAMS.train_epochs (100) epochs, and reusing this trained model to create a graph-regularized model would not be a fair comparison with base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043
<keras.callbacks.History at 0x7f70041be650>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957


Eval accuracy for  MLP + graph regularization :  0.7956600189208984
Eval loss for  MLP + graph regularization :  0.8883611559867859

In the run shown above, the accuracy of the graph-regularized model (0.7957) is only slightly higher than that of the base model (base_model, 0.7939); the size of the gain can vary between runs.
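The size of the improvement can be checked directly from the two evaluation accuracies printed above. This small calculation distinguishes the absolute gain (in percentage points) from the relative gain (as a fraction of the base accuracy); the values are copied from this particular run.

```python
# Evaluation accuracies printed above for this run.
base_accuracy = 0.7938517332077026
graph_reg_accuracy = 0.7956600189208984

# Absolute difference, expressed in accuracy units (multiply by 100 for points).
absolute_gain = graph_reg_accuracy - base_accuracy
# Improvement relative to the base model's accuracy.
relative_gain = absolute_gain / base_accuracy

print(f"Absolute gain: {absolute_gain * 100:.2f} percentage points")
print(f"Relative gain: {relative_gain * 100:.2f}%")
```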

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs from sample embeddings before training a neural network with graph regularization. This approach is useful when the input does not contain an explicit graph.
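NSL provides its own graph-building tools for the synthesized-graph setting; the sketch below is not that tool but a brute-force illustration of the idea, assuming cosine similarity between embeddings and a fixed similarity threshold. The function names and the toy embeddings are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two vectors; returns 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def build_similarity_graph(embeddings, threshold):
    """Returns undirected edges (i, j, weight) with similarity >= threshold.

    Compares every pair of samples (O(n^2)), so it only suits small inputs.
    """
    ids = sorted(embeddings)
    edges = []
    for pos, i in enumerate(ids):
        for j in ids[pos + 1:]:
            sim = cosine_similarity(embeddings[i], embeddings[j])
            if sim >= threshold:
                edges.append((i, j, sim))
    return edges


# Hypothetical toy embeddings for three documents.
embeddings = {
    "doc_a": [1.0, 0.0],
    "doc_b": [0.9, 0.1],
    "doc_c": [0.0, 1.0],
}
edges = build_similarity_graph(embeddings, threshold=0.8)
print(edges)  # only doc_a and doc_b are similar enough to be connected
```

The resulting weighted edges play the same role as the citation links in Cora: each sample's neighbors and edge weights would then be joined into the training examples before graph regularization is applied.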

We encourage users to experiment further by varying the amount of supervision and by trying different neural architectures for graph regularization.