This page was translated by the Cloud Translation API.

Graph regularization for document classification using natural graphs


Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
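In generic notation (the symbols below are assumptions chosen for illustration, not names used by NSL), a graph-regularized objective adds a neighbor-distance penalty to the usual supervised loss:

$$\mathcal{L}(\theta) = \sum_{i \in \mathcal{S}_L} \ell\big(y_i, f_\theta(x_i)\big) + \alpha \sum_{i \in \mathcal{S}} \sum_{j \in \mathcal{N}(i)} w_{ij}\, d\big(h_\theta(x_i), h_\theta(x_j)\big)$$

Here $\mathcal{S}_L$ is the set of labeled samples, $\mathcal{S}$ is the set of all samples, $\ell$ is the supervised loss, $\mathcal{N}(i)$ is the set of graph neighbors of sample $i$, $w_{ij}$ is the edge weight, $d$ is a distance such as L2, $h_\theta$ is some representation computed by the model (for example its output), and $\alpha$ controls the strength of the graph term. Because the second term needs only the inputs of neighbors and not their labels, unlabeled neighbors can still contribute.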

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
 
Version:  2.2.0
Eager mode:  True
GPU is NOT AVAILABLE

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we consider citations as a proxy for similarity, which is usually a commutative property.
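The preprocessing script used later logs "Making all edges bi-directional..." when it performs this step. The plain-Python sketch below only illustrates the idea and is not the script's actual code; the paper IDs are made up, and dropping self-loops is an assumption of the sketch.

```python
from collections import defaultdict


def make_undirected(edges):
    """Build an adjacency map where every directed edge (a, b) also yields (b, a).

    Self-loops are skipped (an assumption of this sketch).
    """
    adjacency = defaultdict(set)
    for src, dst in edges:
        if src == dst:
            continue
        adjacency[src].add(dst)
        adjacency[dst].add(src)
    # Sort neighbor lists so the output is deterministic.
    return {node: sorted(nbrs) for node, nbrs in adjacency.items()}


# Hypothetical citations: paper "A" cites "B", and "B" cites "C".
citations = [("A", "B"), ("B", "C")]
print(make_undirected(citations))  # {'A': ['B'], 'B': ['A', 'C'], 'C': ['B']}
```

Using a set per node means a pair of papers that cite each other still produces a single undirected edge.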

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.

  2. Label: A single integer representing the class ID (category) of the paper.
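To make the Words feature concrete, here is a plain-Python sketch of multi-hot encoding over a tiny vocabulary. In Cora the 0/1 values are already stored in cora.content, so nothing like this has to be run; the vocabulary and document below are hypothetical.

```python
def multi_hot(doc_words, vocabulary):
    """Encode a document as a 0/1 vector over `vocabulary`.

    Position i is 1 if vocabulary[i] occurs in the document, else 0.
    Word counts are ignored: a repeated word still gives 1.
    """
    present = set(doc_words)
    return [1 if word in present else 0 for word in vocabulary]


# Toy 5-word vocabulary (Cora's real vocabulary has 1433 words).
vocab = ["graph", "neural", "kernel", "bayes", "learning"]
print(multi_hot(["neural", "graph", "graph"], vocab))  # [1, 1, 0, 0, 0]
```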

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
 
--2020-07-01 11:15:33--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2020-07-01 11:15:33 (84.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.06 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.38 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.
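The log above reports 2155 merged training examples and an out-degree histogram whose counts (386 + 468 + 452 + 309 + 540) add up to exactly 2155, with no bucket above max_nbrs=5. The sketch below illustrates, in plain Python, how a seed's features can be merged with up to max_nbrs neighbors using the NL_nbr_<i>_<name> and NL_nbr_<i>_weight keys that appear later in this tutorial, and how such a capped histogram arises. It is not the script's implementation: the helper names are hypothetical, and keeping the first max_nbrs neighbors (rather than, say, the highest-weighted ones) is an assumption.

```python
from collections import Counter

NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'


def merge_neighbors(seed_features, neighbor_list, max_nbrs):
    """Hypothetical helper: flatten a seed and up to `max_nbrs` neighbors.

    `neighbor_list` holds (features_dict, edge_weight) pairs.
    """
    merged = dict(seed_features)
    for i, (nbr_features, weight) in enumerate(neighbor_list[:max_nbrs]):
        for name, value in nbr_features.items():
            merged['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, name)] = value
        merged['{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)] = weight
    return merged


def out_degree_histogram(neighbor_counts, max_nbrs):
    """Count seeds by the number of neighbors kept (capped at `max_nbrs`)."""
    return sorted(Counter(min(n, max_nbrs) for n in neighbor_counts).items())


seed = {'words': [1, 0, 1], 'label': 2}
nbrs = [({'words': [0, 1, 1]}, 1.0)]
print(sorted(merge_neighbors(seed, nbrs, max_nbrs=5)))
# ['NL_nbr_0_weight', 'NL_nbr_0_words', 'label', 'words']
print(out_degree_histogram([1, 2, 7, 5, 9], max_nbrs=5))
# [(1, 1), (2, 1), (5, 3)]
```

Under this reading, the large count in the bucket for 5 would include every paper that has 5 or more neighbors.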

Global variables

The file paths to the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
 

Hyperparameters

We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary. All instances in the input have a dense multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: A list giving the number of units in each fully connected layer of our neural network; its length is the number of such layers.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.
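To show how distance_type and graph_regularization_multiplier interact, here is a plain-Python sketch of a weighted graph term added to a supervised loss. Assumptions: the distance is taken between the model outputs of a sample and its neighbors, and "L2" is computed here as the sum of squared differences; whether NSL's DistanceType.L2 squares the difference, and how it averages over a batch, is not checked here. The numbers are made up.

```python
def squared_l2(a, b):
    """Sum of squared element-wise differences between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_loss(sample_output, neighbor_outputs, neighbor_weights):
    """Weighted sum of distances between a sample's output and its neighbors'."""
    return sum(w * squared_l2(sample_output, nbr)
               for nbr, w in zip(neighbor_outputs, neighbor_weights))


def total_loss(supervised_loss, g_loss, multiplier):
    """Supervised loss plus the graph term scaled by the multiplier."""
    return supervised_loss + multiplier * g_loss


pred = [0.7, 0.2, 0.1]       # hypothetical softmax output for one sample
nbr_pred = [0.5, 0.4, 0.1]   # hypothetical output for its single neighbor
g = graph_loss(pred, [nbr_pred], [1.0])  # 0.04 + 0.04 + 0.0 = 0.08
print(round(total_loss(1.2, g, multiplier=0.1), 3))  # 1.208
```

A larger multiplier pushes neighboring papers toward similar predictions more strongly; with multiplier 0.0 the objective reduces to the plain supervised loss.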

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()
 

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects: one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
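The parse_example function below implements these dummy values through the default_value arguments of tf.io.FixedLenFeature: an all-zero word vector and a neighbor weight of 0.0. The plain-Python sketch here shows the same padding idea outside of TensorFlow; pad_neighbors is a hypothetical helper, not part of NSL.

```python
def pad_neighbors(neighbor_vectors, neighbor_weights, num_neighbors, dim):
    """Pad (or truncate) a sample's neighbor list to exactly `num_neighbors`.

    Missing neighbors get an all-zero feature vector and a weight of 0.0.
    """
    vectors = list(neighbor_vectors[:num_neighbors])
    weights = list(neighbor_weights[:num_neighbors])
    while len(vectors) < num_neighbors:
        vectors.append([0] * dim)
        weights.append(0.0)
    return vectors, weights


vecs, ws = pad_neighbors([[1, 0, 1]], [1.0], num_neighbors=3, dim=3)
print(ws)  # [1.0, 0.0, 0.0]
```

Since each padded neighbor carries weight 0.0, a weighted graph term of the form sum(w * distance) ignores it, which is why non-existent neighbors are "discounted" rather than pulling the sample toward a zero vector.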

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
 

Let's peek into the train dataset to inspect its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
 
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[4 3 1 2 1 6 2 5 6 2 2 6 5 0 2 2 1 6 2 2 2 2 5 4 2 0 2 1 1 2 0 5 2 2 2 0 2
 2 0 6 1 1 0 2 1 2 3 2 0 0 0 4 1 3 3 1 2 5 3 3 1 1 6 0 0 4 6 5 6 0 3 4 2 2
 2 3 3 2 4 0 2 3 2 2 3 1 2 2 1 0 6 1 2 1 6 2 1 0 4 3 2 5 2 3 1 0 3 4 3 4 1
 0 5 6 4 2 1 1 2 5 3 4 3 1 3 2 6 3], shape=(128,), dtype=int64)

Let's peek into the test dataset to inspect its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
 
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model
 

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already multi-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model
 

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already multi-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()
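The comments in these models note that Keras applies dropout only during training, and the subclass model makes this explicit by forwarding `training` to its Dropout layer. The plain-Python sketch below shows what that flag changes, using inverted dropout, which is the scheme tf.keras.layers.Dropout documents: kept units are scaled by 1 / (1 - rate) during training so the expected sum is unchanged. It is an illustration, not Keras's implementation.

```python
import random


def dropout(values, rate, training, rng=random):
    """Inverted dropout on a list of activations.

    Training: zero each value with probability `rate`, scale survivors
    by 1 / (1 - rate). Inference: return the values unchanged.
    """
    if not training or rate == 0.0:
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


acts = [0.5, 1.0, 2.0, 4.0]
print(dropout(acts, 0.5, training=False))  # [0.5, 1.0, 2.0, 4.0]
print(dropout(acts, 0.5, training=True, rng=random.Random(0)))  # each entry is 0.0 or doubled
```

Because inference uses the activations as-is, no rescaling is needed when the trained model is evaluated.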
 

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
 
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________
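The parameter counts in this summary can be checked by hand: each Dense layer has fan_in × units weights plus one bias per unit, while the Lambda and Dropout layers have none. A small sketch of that arithmetic (the helper names are illustrative):

```python
def dense_params(fan_in, units):
    """Weights (fan_in x units) plus one bias per unit."""
    return fan_in * units + units


def mlp_param_counts(input_dim, hidden_units, num_classes):
    """Per-Dense-layer parameter counts for an MLP like the one above."""
    sizes = [input_dim] + list(hidden_units) + [num_classes]
    return [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]


counts = mlp_param_counts(1433, [50, 50], 7)
print(counts, sum(counts))  # [71700, 2550, 357] 74607
```

Almost all of the parameters sit in the first layer, because the 1433-dimensional bag-of-words input is fully connected to 50 units.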

Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
 
Epoch 1/100
17/17 [==============================] - 0s 11ms/step - loss: 1.9256 - accuracy: 0.1870
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8410 - accuracy: 0.2835
Epoch 3/100
17/17 [==============================] - 0s 9ms/step - loss: 1.7479 - accuracy: 0.3374
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6384 - accuracy: 0.3884
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5086 - accuracy: 0.4390
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3606 - accuracy: 0.5016
Epoch 7/100
17/17 [==============================] - 0s 9ms/step - loss: 1.2165 - accuracy: 0.5791
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0783 - accuracy: 0.6311
Epoch 9/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9552 - accuracy: 0.6947
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.8680 - accuracy: 0.7090
Epoch 11/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7915 - accuracy: 0.7425
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7124 - accuracy: 0.7773
Epoch 13/100
17/17 [==============================] - 0s 9ms/step - loss: 0.6582 - accuracy: 0.7907
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6021 - accuracy: 0.8065
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5416 - accuracy: 0.8325
Epoch 16/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5042 - accuracy: 0.8473
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4433 - accuracy: 0.8761
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4310 - accuracy: 0.8640
Epoch 19/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3894 - accuracy: 0.8840
Epoch 20/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3676 - accuracy: 0.8891
Epoch 21/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3576 - accuracy: 0.8812
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3132 - accuracy: 0.9067
Epoch 23/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3058 - accuracy: 0.9142
Epoch 24/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2924 - accuracy: 0.9155
Epoch 25/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2769 - accuracy: 0.9197
Epoch 26/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2636 - accuracy: 0.9244
Epoch 27/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2429 - accuracy: 0.9313
Epoch 28/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2324 - accuracy: 0.9323
Epoch 29/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2285 - accuracy: 0.9346
Epoch 30/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2039 - accuracy: 0.9374
Epoch 31/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1943 - accuracy: 0.9471
Epoch 32/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1898 - accuracy: 0.9439
Epoch 33/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1879 - accuracy: 0.9425
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1828 - accuracy: 0.9443
Epoch 35/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1635 - accuracy: 0.9541
Epoch 36/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1648 - accuracy: 0.9476
Epoch 37/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1603 - accuracy: 0.9499
Epoch 38/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1428 - accuracy: 0.9624
Epoch 39/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1483 - accuracy: 0.9601
Epoch 40/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1352 - accuracy: 0.9582
Epoch 41/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1379 - accuracy: 0.9555
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1410 - accuracy: 0.9582
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1198 - accuracy: 0.9684
Epoch 44/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1148 - accuracy: 0.9731
Epoch 45/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1228 - accuracy: 0.9657
Epoch 46/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1135 - accuracy: 0.9703
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1134 - accuracy: 0.9661
Epoch 48/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1175 - accuracy: 0.9619
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1002 - accuracy: 0.9703
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1143 - accuracy: 0.9671
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0923 - accuracy: 0.9777
Epoch 52/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1068 - accuracy: 0.9731
Epoch 53/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0972 - accuracy: 0.9712
Epoch 54/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0828 - accuracy: 0.9796
Epoch 55/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1036 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0954 - accuracy: 0.9745
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0883 - accuracy: 0.9768
Epoch 58/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0859 - accuracy: 0.9777
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0856 - accuracy: 0.9759
Epoch 60/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0858 - accuracy: 0.9754
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9726
Epoch 62/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0840 - accuracy: 0.9763
Epoch 63/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0770 - accuracy: 0.9805
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0823 - accuracy: 0.9745
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0665 - accuracy: 0.9828
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0788 - accuracy: 0.9777
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0690 - accuracy: 0.9800
Epoch 68/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0683 - accuracy: 0.9805
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0615 - accuracy: 0.9838
Epoch 70/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0618 - accuracy: 0.9833
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0659 - accuracy: 0.9810
Epoch 72/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0704 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0645 - accuracy: 0.9814
Epoch 74/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0645 - accuracy: 0.9791
Epoch 75/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9791
Epoch 76/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0648 - accuracy: 0.9814
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0591 - accuracy: 0.9838
Epoch 78/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0699 - accuracy: 0.9814
Epoch 80/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0603 - accuracy: 0.9828
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0629 - accuracy: 0.9828
Epoch 82/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0596 - accuracy: 0.9828
Epoch 83/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0542 - accuracy: 0.9828
Epoch 84/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0452 - accuracy: 0.9893
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0551 - accuracy: 0.9838
Epoch 86/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0555 - accuracy: 0.9842
Epoch 87/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0514 - accuracy: 0.9824
Epoch 88/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0553 - accuracy: 0.9847
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9884
Epoch 90/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0476 - accuracy: 0.9893
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0427 - accuracy: 0.9903
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0475 - accuracy: 0.9847
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0423 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0473 - accuracy: 0.9865
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0560 - accuracy: 0.9819
Epoch 96/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0547 - accuracy: 0.9810
Epoch 97/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0576 - accuracy: 0.9814
Epoch 98/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0429 - accuracy: 0.9893
Epoch 99/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0440 - accuracy: 0.9875
Epoch 100/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0513 - accuracy: 0.9838

<tensorflow.python.keras.callbacks.History at 0x7fc47a3c78d0>
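The "17/17" in each log line is the number of batches per epoch. The preprocessing script wrote 2155 merged training examples, and Dataset.batch keeps the final partial batch by default (drop_remainder=False), so with batch_size = 128 each epoch has 16 full batches plus one batch of 107 examples:

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(2155, 128), 2155 % 128)  # 17 107
```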

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
 
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
 
5/5 [==============================] - 0s 5ms/step - loss: 1.3380 - accuracy: 0.7740


Eval accuracy for  Base MLP model :  0.7739602327346802
Eval loss for  Base MLP model :  1.3379606008529663

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model will not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
 
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
 
Epoch 1/100

/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

17/17 [==============================] - 0s 10ms/step - loss: 1.9454 - accuracy: 0.1652 - graph_loss: 0.0076
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8517 - accuracy: 0.2956 - graph_loss: 0.0117
Epoch 3/100
17/17 [==============================] - 0s 10ms/step - loss: 1.7589 - accuracy: 0.3151 - graph_loss: 0.0261
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6714 - accuracy: 0.3392 - graph_loss: 0.0476
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5607 - accuracy: 0.4037 - graph_loss: 0.0622
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.4486 - accuracy: 0.4807 - graph_loss: 0.0921
Epoch 7/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3135 - accuracy: 0.5383 - graph_loss: 0.1236
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.1902 - accuracy: 0.5912 - graph_loss: 0.1616
Epoch 9/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0647 - accuracy: 0.6575 - graph_loss: 0.1920
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9416 - accuracy: 0.7067 - graph_loss: 0.2181
Epoch 11/100
17/17 [==============================] - 0s 10ms/step - loss: 0.8601 - accuracy: 0.7378 - graph_loss: 0.2470
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7968 - accuracy: 0.7462 - graph_loss: 0.2565
Epoch 13/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6881 - accuracy: 0.7912 - graph_loss: 0.2681
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6548 - accuracy: 0.8139 - graph_loss: 0.2941
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5874 - accuracy: 0.8376 - graph_loss: 0.3010
Epoch 16/100
17/17 [==============================] - 0s 9ms/step - loss: 0.5537 - accuracy: 0.8348 - graph_loss: 0.3014
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5123 - accuracy: 0.8529 - graph_loss: 0.3097
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4771 - accuracy: 0.8640 - graph_loss: 0.3192
Epoch 19/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4294 - accuracy: 0.8826 - graph_loss: 0.3182
Epoch 20/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4109 - accuracy: 0.8854 - graph_loss: 0.3169
Epoch 21/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3901 - accuracy: 0.8965 - graph_loss: 0.3250
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3700 - accuracy: 0.8956 - graph_loss: 0.3349
Epoch 23/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3716 - accuracy: 0.8974 - graph_loss: 0.3408
Epoch 24/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3258 - accuracy: 0.9202 - graph_loss: 0.3361
Epoch 25/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3043 - accuracy: 0.9253 - graph_loss: 0.3351
Epoch 26/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2919 - accuracy: 0.9253 - graph_loss: 0.3361
Epoch 27/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3005 - accuracy: 0.9202 - graph_loss: 0.3249
Epoch 28/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2629 - accuracy: 0.9336 - graph_loss: 0.3442
Epoch 29/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2617 - accuracy: 0.9401 - graph_loss: 0.3302
Epoch 30/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2510 - accuracy: 0.9383 - graph_loss: 0.3436
Epoch 31/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2452 - accuracy: 0.9411 - graph_loss: 0.3364
Epoch 32/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2397 - accuracy: 0.9466 - graph_loss: 0.3333
Epoch 33/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2239 - accuracy: 0.9466 - graph_loss: 0.3373
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2084 - accuracy: 0.9513 - graph_loss: 0.3330
Epoch 35/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2075 - accuracy: 0.9499 - graph_loss: 0.3383
Epoch 36/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2064 - accuracy: 0.9513 - graph_loss: 0.3394
Epoch 37/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1857 - accuracy: 0.9568 - graph_loss: 0.3371
Epoch 38/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1799 - accuracy: 0.9601 - graph_loss: 0.3477
Epoch 39/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1844 - accuracy: 0.9573 - graph_loss: 0.3385
Epoch 40/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1823 - accuracy: 0.9592 - graph_loss: 0.3445
Epoch 41/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1713 - accuracy: 0.9615 - graph_loss: 0.3451
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9624 - graph_loss: 0.3398
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1692 - accuracy: 0.9671 - graph_loss: 0.3483
Epoch 44/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9647 - graph_loss: 0.3437
Epoch 45/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1485 - accuracy: 0.9703 - graph_loss: 0.3338
Epoch 46/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1467 - accuracy: 0.9717 - graph_loss: 0.3405
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1492 - accuracy: 0.9694 - graph_loss: 0.3466
Epoch 48/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1577 - accuracy: 0.9666 - graph_loss: 0.3338
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1363 - accuracy: 0.9773 - graph_loss: 0.3424
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1511 - accuracy: 0.9694 - graph_loss: 0.3402
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1366 - accuracy: 0.9759 - graph_loss: 0.3385
Epoch 52/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1254 - accuracy: 0.9777 - graph_loss: 0.3474
Epoch 53/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1289 - accuracy: 0.9740 - graph_loss: 0.3469
Epoch 54/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1410 - accuracy: 0.9689 - graph_loss: 0.3475
Epoch 55/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1356 - accuracy: 0.9703 - graph_loss: 0.3483
Epoch 56/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1283 - accuracy: 0.9773 - graph_loss: 0.3412
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1264 - accuracy: 0.9745 - graph_loss: 0.3473
Epoch 58/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1242 - accuracy: 0.9740 - graph_loss: 0.3443
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1144 - accuracy: 0.9782 - graph_loss: 0.3440
Epoch 60/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1250 - accuracy: 0.9735 - graph_loss: 0.3357
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1190 - accuracy: 0.9787 - graph_loss: 0.3400
Epoch 62/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1141 - accuracy: 0.9814 - graph_loss: 0.3419
Epoch 63/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1085 - accuracy: 0.9787 - graph_loss: 0.3395
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1148 - accuracy: 0.9768 - graph_loss: 0.3504
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1137 - accuracy: 0.9791 - graph_loss: 0.3360
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1121 - accuracy: 0.9745 - graph_loss: 0.3469
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1046 - accuracy: 0.9810 - graph_loss: 0.3476
Epoch 68/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1112 - accuracy: 0.9791 - graph_loss: 0.3431
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1075 - accuracy: 0.9787 - graph_loss: 0.3455
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0986 - accuracy: 0.9875 - graph_loss: 0.3403
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1141 - accuracy: 0.9782 - graph_loss: 0.3508
Epoch 72/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1012 - accuracy: 0.9814 - graph_loss: 0.3453
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9833 - graph_loss: 0.3430
Epoch 74/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9842 - graph_loss: 0.3447
Epoch 75/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0988 - accuracy: 0.9842 - graph_loss: 0.3430
Epoch 76/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0915 - accuracy: 0.9856 - graph_loss: 0.3475
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0960 - accuracy: 0.9833 - graph_loss: 0.3353
Epoch 78/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0916 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0979 - accuracy: 0.9800 - graph_loss: 0.3476
Epoch 80/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9782 - graph_loss: 0.3400
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9838 - graph_loss: 0.3386
Epoch 82/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9805 - graph_loss: 0.3416
Epoch 83/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0957 - accuracy: 0.9838 - graph_loss: 0.3398
Epoch 84/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9879 - graph_loss: 0.3379
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 86/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0906 - accuracy: 0.9847 - graph_loss: 0.3445
Epoch 87/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9852 - graph_loss: 0.3506
Epoch 88/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0821 - accuracy: 0.9898 - graph_loss: 0.3448
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0803 - accuracy: 0.9865 - graph_loss: 0.3370
Epoch 90/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9828 - graph_loss: 0.3428
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3505
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0882 - accuracy: 0.9847 - graph_loss: 0.3396
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0807 - accuracy: 0.9879 - graph_loss: 0.3473
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0820 - accuracy: 0.9861 - graph_loss: 0.3367
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0864 - accuracy: 0.9838 - graph_loss: 0.3353
Epoch 96/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0786 - accuracy: 0.9889 - graph_loss: 0.3392
Epoch 97/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0735 - accuracy: 0.9912 - graph_loss: 0.3443
Epoch 98/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0861 - accuracy: 0.9842 - graph_loss: 0.3381
Epoch 99/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0850 - accuracy: 0.9833 - graph_loss: 0.3376
Epoch 100/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0841 - accuracy: 0.9879 - graph_loss: 0.3510

<tensorflow.python.keras.callbacks.History at 0x7fc3d853ce10>
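The `max_neighbors` value passed to `make_graph_reg_config` in the training cell above limits how many neighbors per sample take part in the graph loss. One way to feed a fixed number of neighbors to a model is to pad each sample's neighbor list to that size and give the padded slots weight 0, so they add nothing to a weighted distance. The hypothetical `pad_neighbors` helper below sketches that idea. Keeping the first `max_neighbors` neighbors is an assumption, not necessarily how NSL or this tutorial's input pipeline chooses them.

```python
import numpy as np

def pad_neighbors(neighbor_lists, max_neighbors, dim):
    """Pack variable-length neighbor lists into fixed-size arrays (hypothetical helper).

    neighbor_lists: one entry per sample, each a list of (feature_vector, weight).
    Returns (features, weights) with shapes (batch, max_neighbors, dim) and
    (batch, max_neighbors). Neighbors beyond max_neighbors are dropped (assumption:
    the first ones are kept); empty slots are zero-filled with weight 0.
    """
    batch = len(neighbor_lists)
    features = np.zeros((batch, max_neighbors, dim), dtype=np.float32)
    weights = np.zeros((batch, max_neighbors), dtype=np.float32)
    for i, neighbors in enumerate(neighbor_lists):
        for j, (vec, w) in enumerate(neighbors[:max_neighbors]):
            features[i, j] = vec
            weights[i, j] = w
    return features, weights

# Sample 0 has three neighbors (one is dropped); sample 1 has none.
feats, w = pad_neighbors(
    [[([1.0, 2.0], 0.5), ([3.0, 4.0], 1.0), ([5.0, 6.0], 1.0)], []],
    max_neighbors=2, dim=2)
print(feats.shape, w.tolist())
```

Zero-weight padding keeps every batch the same shape while guaranteeing that samples with fewer neighbors are not penalized for the missing ones.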

Evaluate MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
 
5/5 [==============================] - 0s 6ms/step - loss: 1.2475 - accuracy: 0.8192


Eval accuracy for  MLP + graph regularization :  0.8191681504249573
Eval loss for  MLP + graph regularization :  1.2474583387374878

In this run, the accuracy of the graph-regularized model (about 81.9%) is roughly 4.5 percentage points higher than that of the base model (base_model, about 77.4%). Exact figures vary from run to run.
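The size of the gap can be computed directly from the two eval accuracies printed above. Each value comes from a single run, so the exact figures will differ between runs.

```python
# Eval accuracies printed above for the two models (single run each).
base_accuracy = 0.7739602327346802
graph_reg_accuracy = 0.8191681504249573

absolute_gain = graph_reg_accuracy - base_accuracy  # difference in accuracy
relative_gain = absolute_gain / base_accuracy       # relative to the baseline

print(f"Absolute gain: {absolute_gain * 100:.1f} percentage points")  # 4.5
print(f"Relative gain: {relative_gain * 100:.1f}%")                   # 5.8%
```

Reporting both numbers avoids ambiguity: "x% higher" can mean either an absolute difference in percentage points or a relative improvement over the baseline.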

Conclusion

We demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs from sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
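For intuition about graph synthesis, here is a minimal sketch assuming cosine similarity between embeddings and a fixed similarity threshold: two examples are connected whenever their embeddings are similar enough. The hypothetical `build_similarity_graph` helper below is a brute-force NumPy illustration of the idea, not NSL's graph-building tool, and it only suits small inputs.

```python
import numpy as np

def build_similarity_graph(embeddings, threshold):
    """Return directed weighted edges (i, j, similarity) for all pairs i != j
    whose cosine similarity is at least `threshold`.

    embeddings: (num_nodes, dim) array, one row per example.
    Brute-force O(n^2) comparison; for illustration only.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-12)  # guard against zero vectors
    sims = unit @ unit.T                          # pairwise cosine similarities
    edges = []
    n = len(embeddings)
    for i in range(n):
        for j in range(n):
            if i != j and sims[i, j] >= threshold:
                edges.append((i, j, float(sims[i, j])))
    return edges

# Rows 0 and 1 point in nearly the same direction; row 2 is orthogonal to row 0.
embeddings = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
edges = build_similarity_graph(embeddings, threshold=0.9)
print(edges)
```

The threshold controls graph density: a higher value keeps only near-duplicate pairs, while a lower value adds more, noisier edges.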

We encourage users to experiment further by varying the amount of supervision and by trying different neural architectures for graph regularization.