
Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.2.0
Eager mode:  True
GPU is NOT AVAILABLE

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative (symmetric) property.
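
To make the "undirected version" concrete, here is a minimal sketch of turning directed citation pairs into a symmetric neighbor map. The helper name `make_undirected` and the toy paper IDs are hypothetical and chosen only for illustration; this is not the code used by the NSL preprocessing script.

```python
def make_undirected(directed_edges):
    """Maps each node to the set of its neighbors, treating every directed
    edge (source, target) as if (target, source) were also present."""
    neighbors = {}
    for source, target in directed_edges:
        neighbors.setdefault(source, set()).add(target)
        neighbors.setdefault(target, set()).add(source)
    return neighbors


# Hypothetical toy citations: paper "A" cites "B", and paper "C" cites "A".
toy_citations = [("A", "B"), ("C", "A")]
toy_graph = make_undirected(toy_citations)

print(sorted(toy_graph["A"]))  # ['B', 'C']
print(sorted(toy_graph["B"]))  # ['A']
```

Although "B" never cites anything in the toy input, it still ends up with "A" as a neighbor, which is exactly the symmetry assumed in this tutorial.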

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.

  2. Label: A single integer representing the class ID (category) of the paper.
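
As an illustration of the multi-hot 'words' feature, the following sketch encodes a document against a tiny made-up vocabulary. The function `multi_hot`, the five-word vocabulary, and the document are assumptions for the example; the real Cora vocabulary has 1433 entries and is shipped already encoded.

```python
def multi_hot(document_words, vocabulary):
    """Returns a 0/1 list with one position per vocabulary word; position i
    is 1 if vocabulary[i] occurs at least once in the document."""
    present = set(document_words)
    return [1 if word in present else 0 for word in vocabulary]


# Hypothetical vocabulary and document, for illustration only.
toy_vocabulary = ["graph", "neural", "kernel", "bayes", "policy"]
toy_document = ["neural", "graph", "neural"]

toy_vector = multi_hot(toy_document, toy_vocabulary)
print(toy_vector)  # [1, 1, 0, 0, 0]
```

Note that the repeated word "neural" still yields a single 1: the representation records presence, not counts.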

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

To preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generates neighbor features using the original node features and the graph.
  2. Generates train and test data splits containing tf.train.Example instances.
  3. Persists the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2020-07-01 11:15:33--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2020-07-01 11:15:33 (84.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.06 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.38 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.
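
The numbers in the log above can be cross-checked with a little arithmetic. The sketch below only re-uses figures printed by the script; the reading that the histogram buckets count training examples by their (capped) number of neighbors, and that the remaining nodes form the test split, is an interpretation of the log rather than something stated by the script.

```python
# Values copied from the preprocessing log above.
out_degree_histogram = [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
total_graph_nodes = 2708
merged_train_examples = 2155

# The histogram buckets add up to the number of merged training examples.
examples_in_histogram = sum(count for _, count in out_degree_histogram)
print(examples_in_histogram)  # 2155

# Nodes not written as merged training examples (presumably the test split).
remaining_nodes = total_graph_nodes - merged_train_examples
print(remaining_nodes)  # 553

# The smallest bucket is 1, so under this reading every training example
# has at least one neighbor, and none has more than max_nbrs=5.
smallest_degree = min(degree for degree, _ in out_degree_histogram)
print(smallest_degree)  # 1
```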

Global variables

The file paths to the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are 7 different classes in total.

  • max_seq_length: This is the size of the vocabulary. All instances in the input have a dense, multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: A list giving the number of units in each fully connected (hidden) layer of our neural network; its length is the number of hidden layers.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout applied after each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()
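
To give an intuition for how distance_type, graph_regularization_multiplier and num_neighbors interact, here is a simplified, per-sample sketch of a graph-regularized loss. It assumes a squared L2 distance between the model's prediction for a sample and its prediction for each neighbor; the helper names `squared_l2_distance` and `graph_regularized_loss` and all numeric values are hypothetical, and NSL's actual implementation may reduce or normalize the terms differently.

```python
def squared_l2_distance(p, q):
    """Squared Euclidean distance between two equally sized vectors."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def graph_regularized_loss(supervised_loss, prediction, neighbor_predictions,
                           neighbor_weights, multiplier):
    """Returns (total_loss, graph_loss) for a single sample (a sketch)."""
    graph_loss = sum(
        weight * squared_l2_distance(prediction, neighbor_prediction)
        for neighbor_prediction, weight in zip(neighbor_predictions,
                                               neighbor_weights))
    return supervised_loss + multiplier * graph_loss, graph_loss


# Hypothetical values: one real neighbor (weight 1.0) and one padded,
# non-existent neighbor (weight 0.0) that therefore contributes nothing.
total_loss, graph_loss = graph_regularized_loss(
    supervised_loss=0.5,
    prediction=[0.7, 0.2, 0.1],
    neighbor_predictions=[[0.5, 0.4, 0.1], [0.0, 0.0, 0.0]],
    neighbor_weights=[1.0, 0.0],
    multiplier=0.1)

print(round(graph_loss, 6))  # 0.08
print(round(total_loss, 6))  # 0.508
```

A larger multiplier pushes the predictions of neighboring samples closer together at the expense of the supervised term.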

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects: one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
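
The effect of the default values in parse_example can be illustrated without TensorFlow. The sketch below mimics, on plain dictionaries, how a missing neighbor is filled with an all-zero 'words' vector and a weight of 0.0. The helper `fill_neighbor_defaults` and the three-word toy example are hypothetical; only the key naming scheme ('NL_nbr_<i>_words', 'NL_nbr_<i>_weight') is taken from the constants defined above.

```python
def fill_neighbor_defaults(example, num_neighbors, feature_length):
    """Returns a copy of `example` in which every missing neighbor feature
    is replaced by zeros and every missing neighbor weight by [0.0]."""
    filled = dict(example)
    for i in range(num_neighbors):
        words_key = 'NL_nbr_{}_words'.format(i)
        weight_key = 'NL_nbr_{}_weight'.format(i)
        filled.setdefault(words_key, [0] * feature_length)
        filled.setdefault(weight_key, [0.0])
    return filled


# Hypothetical example with a 3-word vocabulary and a single real neighbor.
toy_example = {
    'words': [1, 0, 1],
    'NL_nbr_0_words': [0, 0, 1],
    'NL_nbr_0_weight': [1.0],
}
filled = fill_neighbor_defaults(toy_example, num_neighbors=2, feature_length=3)

print(filled['NL_nbr_0_weight'])  # [1.0]
print(filled['NL_nbr_1_words'])   # [0, 0, 0]
print(filled['NL_nbr_1_weight'])  # [0.0]
```

Because the padded neighbor carries a weight of 0.0, it can be ignored by any loss term that multiplies the neighbor distance by its weight.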

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[4 3 1 2 1 6 2 5 6 2 2 6 5 0 2 2 1 6 2 2 2 2 5 4 2 0 2 1 1 2 0 5 2 2 2 0 2
 2 0 6 1 1 0 2 1 2 3 2 0 0 0 4 1 3 3 1 2 5 3 3 1 1 6 0 0 4 6 5 6 0 3 4 2 2
 2 3 3 2 4 0 2 3 2 2 3 1 2 2 1 0 6 1 2 1 6 2 1 0 4 3 2 5 2 3 1 0 3 4 3 4 1
 0 5 6 4 2 1 1 2 5 3 4 3 1 3 2 6 3], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()
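
The training flag passed to the dropout layer in the subclass model matters because dropout behaves differently in training and inference. The following is a minimal pure-Python sketch of inverted dropout (dropped units are zeroed and the surviving ones are scaled by 1 / (1 - rate)), which matches how the Keras Dropout layer is documented; the function `dropout`, the fixed seed, and the toy activations are assumptions for illustration.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout on a list of floats (a sketch, not the Keras code)."""
    if not training:
        # At inference time the layer is the identity.
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


rng = random.Random(0)  # Fixed seed so the sketch is reproducible.
activations = [1.0, 2.0, 3.0, 4.0]

inference_out = dropout(activations, rate=0.5, training=False, rng=rng)
print(inference_out)  # [1.0, 2.0, 3.0, 4.0]

# In training mode each value is either dropped (0.0) or doubled (rate=0.5).
train_out = dropout(activations, rate=0.5, training=True, rng=rng)
print(train_out)
```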

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
words (InputLayer)           [(None, 1433)]            0         
_________________________________________________________________
lambda (Lambda)              (None, 1433)              0         
_________________________________________________________________
dense (Dense)                (None, 50)                71700     
_________________________________________________________________
dropout (Dropout)            (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                2550      
_________________________________________________________________
dropout_1 (Dropout)          (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 357       
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________
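
The parameter counts in the summary above can be reproduced by hand: a dense layer with n inputs and m units has n * m weights plus m biases. The helper `dense_params` below is a hypothetical name used only for this check; the layer sizes come from max_seq_length, num_fc_units and num_classes.

```python
def dense_params(num_inputs, num_units):
    """Number of trainable parameters in a dense layer: weights plus biases."""
    return num_inputs * num_units + num_units


# Input size, the two hidden layers, and the output layer.
layer_sizes = [1433, 50, 50, 7]
per_layer = [
    dense_params(num_inputs, num_units)
    for num_inputs, num_units in zip(layer_sizes, layer_sizes[1:])
]

print(per_layer)       # [71700, 2550, 357]
print(sum(per_layer))  # 74607
```

The Lambda and Dropout layers contribute no parameters, so the total matches the 74,607 reported by the summary.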

Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
17/17 [==============================] - 0s 11ms/step - loss: 1.9256 - accuracy: 0.1870
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8410 - accuracy: 0.2835
Epoch 3/100
17/17 [==============================] - 0s 9ms/step - loss: 1.7479 - accuracy: 0.3374
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6384 - accuracy: 0.3884
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5086 - accuracy: 0.4390
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3606 - accuracy: 0.5016
Epoch 7/100
17/17 [==============================] - 0s 9ms/step - loss: 1.2165 - accuracy: 0.5791
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0783 - accuracy: 0.6311
Epoch 9/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9552 - accuracy: 0.6947
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.8680 - accuracy: 0.7090
Epoch 11/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7915 - accuracy: 0.7425
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7124 - accuracy: 0.7773
Epoch 13/100
17/17 [==============================] - 0s 9ms/step - loss: 0.6582 - accuracy: 0.7907
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6021 - accuracy: 0.8065
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5416 - accuracy: 0.8325
Epoch 16/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5042 - accuracy: 0.8473
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4433 - accuracy: 0.8761
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4310 - accuracy: 0.8640
Epoch 19/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3894 - accuracy: 0.8840
Epoch 20/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3676 - accuracy: 0.8891
Epoch 21/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3576 - accuracy: 0.8812
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3132 - accuracy: 0.9067
Epoch 23/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3058 - accuracy: 0.9142
Epoch 24/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2924 - accuracy: 0.9155
Epoch 25/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2769 - accuracy: 0.9197
Epoch 26/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2636 - accuracy: 0.9244
Epoch 27/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2429 - accuracy: 0.9313
Epoch 28/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2324 - accuracy: 0.9323
Epoch 29/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2285 - accuracy: 0.9346
Epoch 30/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2039 - accuracy: 0.9374
Epoch 31/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1943 - accuracy: 0.9471
Epoch 32/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1898 - accuracy: 0.9439
Epoch 33/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1879 - accuracy: 0.9425
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1828 - accuracy: 0.9443
Epoch 35/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1635 - accuracy: 0.9541
Epoch 36/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1648 - accuracy: 0.9476
Epoch 37/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1603 - accuracy: 0.9499
Epoch 38/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1428 - accuracy: 0.9624
Epoch 39/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1483 - accuracy: 0.9601
Epoch 40/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1352 - accuracy: 0.9582
Epoch 41/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1379 - accuracy: 0.9555
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1410 - accuracy: 0.9582
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1198 - accuracy: 0.9684
Epoch 44/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1148 - accuracy: 0.9731
Epoch 45/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1228 - accuracy: 0.9657
Epoch 46/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1135 - accuracy: 0.9703
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1134 - accuracy: 0.9661
Epoch 48/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1175 - accuracy: 0.9619
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1002 - accuracy: 0.9703
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1143 - accuracy: 0.9671
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0923 - accuracy: 0.9777
Epoch 52/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1068 - accuracy: 0.9731
Epoch 53/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0972 - accuracy: 0.9712
Epoch 54/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0828 - accuracy: 0.9796
Epoch 55/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1036 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0954 - accuracy: 0.9745
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0883 - accuracy: 0.9768
Epoch 58/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0859 - accuracy: 0.9777
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0856 - accuracy: 0.9759
Epoch 60/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0858 - accuracy: 0.9754
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9726
Epoch 62/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0840 - accuracy: 0.9763
Epoch 63/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0770 - accuracy: 0.9805
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0823 - accuracy: 0.9745
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0665 - accuracy: 0.9828
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0788 - accuracy: 0.9777
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0690 - accuracy: 0.9800
Epoch 68/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0683 - accuracy: 0.9805
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0615 - accuracy: 0.9838
Epoch 70/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0618 - accuracy: 0.9833
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0659 - accuracy: 0.9810
Epoch 72/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0704 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0645 - accuracy: 0.9814
Epoch 74/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0645 - accuracy: 0.9791
Epoch 75/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9791
Epoch 76/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0648 - accuracy: 0.9814
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0591 - accuracy: 0.9838
Epoch 78/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0699 - accuracy: 0.9814
Epoch 80/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0603 - accuracy: 0.9828
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0629 - accuracy: 0.9828
Epoch 82/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0596 - accuracy: 0.9828
Epoch 83/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0542 - accuracy: 0.9828
Epoch 84/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0452 - accuracy: 0.9893
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0551 - accuracy: 0.9838
Epoch 86/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0555 - accuracy: 0.9842
Epoch 87/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0514 - accuracy: 0.9824
Epoch 88/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0553 - accuracy: 0.9847
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9884
Epoch 90/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0476 - accuracy: 0.9893
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0427 - accuracy: 0.9903
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0475 - accuracy: 0.9847
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0423 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0473 - accuracy: 0.9865
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0560 - accuracy: 0.9819
Epoch 96/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0547 - accuracy: 0.9810
Epoch 97/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0576 - accuracy: 0.9814
Epoch 98/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0429 - accuracy: 0.9893
Epoch 99/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0440 - accuracy: 0.9875
Epoch 100/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0513 - accuracy: 0.9838

<tensorflow.python.keras.callbacks.History at 0x7fc47a3c78d0>
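
The "17/17" shown for every epoch follows from the dataset size and the batch size. The short calculation below uses the 2155 merged training examples reported by the preprocessing script and batch_size = 128; it assumes, as is the default for dataset.batch, that the final partial batch is kept rather than dropped.

```python
import math

num_train_examples = 2155  # From the preprocessing log.
batch_size = 128           # HPARAMS.batch_size.

steps_per_epoch = math.ceil(num_train_examples / batch_size)
print(steps_per_epoch)  # 17

# Size of the final, partial batch in each epoch.
last_batch_size = num_train_examples - (steps_per_epoch - 1) * batch_size
print(last_batch_size)  # 107
```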

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.3380 - accuracy: 0.7740


Eval accuracy for  Base MLP model :  0.7739602327346802
Eval loss for  Base MLP model :  1.3379606008529663
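
The dict(zip(...)) idiom used above simply pairs each metric name with its value. The sketch below replays it with the numbers printed in the output, assuming the metric names are ordered as ['loss', 'accuracy'] (consistent with the progress line), and then compares the result with the last training accuracy from the log. The variable names are hypothetical.

```python
# Assumed metric order, with the values printed by the evaluation above.
metric_names = ['loss', 'accuracy']
metric_values = [1.3379606008529663, 0.7739602327346802]

toy_eval_results = dict(zip(metric_names, metric_values))
print(toy_eval_results['accuracy'])  # 0.7739602327346802

# Training accuracy reported at epoch 100 in the training log.
final_train_accuracy = 0.9838
accuracy_gap = final_train_accuracy - toy_eval_results['accuracy']
print(round(accuracy_gap, 4))  # 0.2098
```

A gap of roughly 21 percentage points between training and test accuracy suggests that the base model overfits the training set, which is the kind of behavior a regularizer is meant to reduce.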

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a number of iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison with base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100

/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

17/17 [==============================] - 0s 10ms/step - loss: 1.9454 - accuracy: 0.1652 - graph_loss: 0.0076
Epoch 2/100
17/17 [==============================] - 0s 10ms/step - loss: 1.8517 - accuracy: 0.2956 - graph_loss: 0.0117
Epoch 3/100
17/17 [==============================] - 0s 10ms/step - loss: 1.7589 - accuracy: 0.3151 - graph_loss: 0.0261
Epoch 4/100
17/17 [==============================] - 0s 10ms/step - loss: 1.6714 - accuracy: 0.3392 - graph_loss: 0.0476
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 1.5607 - accuracy: 0.4037 - graph_loss: 0.0622
Epoch 6/100
17/17 [==============================] - 0s 10ms/step - loss: 1.4486 - accuracy: 0.4807 - graph_loss: 0.0921
Epoch 7/100
17/17 [==============================] - 0s 10ms/step - loss: 1.3135 - accuracy: 0.5383 - graph_loss: 0.1236
Epoch 8/100
17/17 [==============================] - 0s 10ms/step - loss: 1.1902 - accuracy: 0.5912 - graph_loss: 0.1616
Epoch 9/100
17/17 [==============================] - 0s 10ms/step - loss: 1.0647 - accuracy: 0.6575 - graph_loss: 0.1920
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.9416 - accuracy: 0.7067 - graph_loss: 0.2181
Epoch 11/100
17/17 [==============================] - 0s 10ms/step - loss: 0.8601 - accuracy: 0.7378 - graph_loss: 0.2470
Epoch 12/100
17/17 [==============================] - 0s 9ms/step - loss: 0.7968 - accuracy: 0.7462 - graph_loss: 0.2565
Epoch 13/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6881 - accuracy: 0.7912 - graph_loss: 0.2681
Epoch 14/100
17/17 [==============================] - 0s 10ms/step - loss: 0.6548 - accuracy: 0.8139 - graph_loss: 0.2941
Epoch 15/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5874 - accuracy: 0.8376 - graph_loss: 0.3010
Epoch 16/100
17/17 [==============================] - 0s 9ms/step - loss: 0.5537 - accuracy: 0.8348 - graph_loss: 0.3014
Epoch 17/100
17/17 [==============================] - 0s 10ms/step - loss: 0.5123 - accuracy: 0.8529 - graph_loss: 0.3097
Epoch 18/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4771 - accuracy: 0.8640 - graph_loss: 0.3192
Epoch 19/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4294 - accuracy: 0.8826 - graph_loss: 0.3182
Epoch 20/100
17/17 [==============================] - 0s 10ms/step - loss: 0.4109 - accuracy: 0.8854 - graph_loss: 0.3169
Epoch 21/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3901 - accuracy: 0.8965 - graph_loss: 0.3250
Epoch 22/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3700 - accuracy: 0.8956 - graph_loss: 0.3349
Epoch 23/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3716 - accuracy: 0.8974 - graph_loss: 0.3408
Epoch 24/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3258 - accuracy: 0.9202 - graph_loss: 0.3361
Epoch 25/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3043 - accuracy: 0.9253 - graph_loss: 0.3351
Epoch 26/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2919 - accuracy: 0.9253 - graph_loss: 0.3361
Epoch 27/100
17/17 [==============================] - 0s 10ms/step - loss: 0.3005 - accuracy: 0.9202 - graph_loss: 0.3249
Epoch 28/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2629 - accuracy: 0.9336 - graph_loss: 0.3442
Epoch 29/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2617 - accuracy: 0.9401 - graph_loss: 0.3302
Epoch 30/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2510 - accuracy: 0.9383 - graph_loss: 0.3436
Epoch 31/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2452 - accuracy: 0.9411 - graph_loss: 0.3364
Epoch 32/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2397 - accuracy: 0.9466 - graph_loss: 0.3333
Epoch 33/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2239 - accuracy: 0.9466 - graph_loss: 0.3373
Epoch 34/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2084 - accuracy: 0.9513 - graph_loss: 0.3330
Epoch 35/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2075 - accuracy: 0.9499 - graph_loss: 0.3383
Epoch 36/100
17/17 [==============================] - 0s 10ms/step - loss: 0.2064 - accuracy: 0.9513 - graph_loss: 0.3394
Epoch 37/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1857 - accuracy: 0.9568 - graph_loss: 0.3371
Epoch 38/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1799 - accuracy: 0.9601 - graph_loss: 0.3477
Epoch 39/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1844 - accuracy: 0.9573 - graph_loss: 0.3385
Epoch 40/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1823 - accuracy: 0.9592 - graph_loss: 0.3445
Epoch 41/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1713 - accuracy: 0.9615 - graph_loss: 0.3451
Epoch 42/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9624 - graph_loss: 0.3398
Epoch 43/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1692 - accuracy: 0.9671 - graph_loss: 0.3483
Epoch 44/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9647 - graph_loss: 0.3437
Epoch 45/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1485 - accuracy: 0.9703 - graph_loss: 0.3338
Epoch 46/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1467 - accuracy: 0.9717 - graph_loss: 0.3405
Epoch 47/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1492 - accuracy: 0.9694 - graph_loss: 0.3466
Epoch 48/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1577 - accuracy: 0.9666 - graph_loss: 0.3338
Epoch 49/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1363 - accuracy: 0.9773 - graph_loss: 0.3424
Epoch 50/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1511 - accuracy: 0.9694 - graph_loss: 0.3402
Epoch 51/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1366 - accuracy: 0.9759 - graph_loss: 0.3385
Epoch 52/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1254 - accuracy: 0.9777 - graph_loss: 0.3474
Epoch 53/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1289 - accuracy: 0.9740 - graph_loss: 0.3469
Epoch 54/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1410 - accuracy: 0.9689 - graph_loss: 0.3475
Epoch 55/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1356 - accuracy: 0.9703 - graph_loss: 0.3483
Epoch 56/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1283 - accuracy: 0.9773 - graph_loss: 0.3412
Epoch 57/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1264 - accuracy: 0.9745 - graph_loss: 0.3473
Epoch 58/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1242 - accuracy: 0.9740 - graph_loss: 0.3443
Epoch 59/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1144 - accuracy: 0.9782 - graph_loss: 0.3440
Epoch 60/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1250 - accuracy: 0.9735 - graph_loss: 0.3357
Epoch 61/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1190 - accuracy: 0.9787 - graph_loss: 0.3400
Epoch 62/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1141 - accuracy: 0.9814 - graph_loss: 0.3419
Epoch 63/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1085 - accuracy: 0.9787 - graph_loss: 0.3395
Epoch 64/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1148 - accuracy: 0.9768 - graph_loss: 0.3504
Epoch 65/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1137 - accuracy: 0.9791 - graph_loss: 0.3360
Epoch 66/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1121 - accuracy: 0.9745 - graph_loss: 0.3469
Epoch 67/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1046 - accuracy: 0.9810 - graph_loss: 0.3476
Epoch 68/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1112 - accuracy: 0.9791 - graph_loss: 0.3431
Epoch 69/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1075 - accuracy: 0.9787 - graph_loss: 0.3455
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0986 - accuracy: 0.9875 - graph_loss: 0.3403
Epoch 71/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1141 - accuracy: 0.9782 - graph_loss: 0.3508
Epoch 72/100
17/17 [==============================] - 0s 10ms/step - loss: 0.1012 - accuracy: 0.9814 - graph_loss: 0.3453
Epoch 73/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9833 - graph_loss: 0.3430
Epoch 74/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9842 - graph_loss: 0.3447
Epoch 75/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0988 - accuracy: 0.9842 - graph_loss: 0.3430
Epoch 76/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0915 - accuracy: 0.9856 - graph_loss: 0.3475
Epoch 77/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0960 - accuracy: 0.9833 - graph_loss: 0.3353
Epoch 78/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0916 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 79/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0979 - accuracy: 0.9800 - graph_loss: 0.3476
Epoch 80/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9782 - graph_loss: 0.3400
Epoch 81/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9838 - graph_loss: 0.3386
Epoch 82/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9805 - graph_loss: 0.3416
Epoch 83/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0957 - accuracy: 0.9838 - graph_loss: 0.3398
Epoch 84/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9879 - graph_loss: 0.3379
Epoch 85/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9838 - graph_loss: 0.3441
Epoch 86/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0906 - accuracy: 0.9847 - graph_loss: 0.3445
Epoch 87/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9852 - graph_loss: 0.3506
Epoch 88/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0821 - accuracy: 0.9898 - graph_loss: 0.3448
Epoch 89/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0803 - accuracy: 0.9865 - graph_loss: 0.3370
Epoch 90/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9828 - graph_loss: 0.3428
Epoch 91/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3505
Epoch 92/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0882 - accuracy: 0.9847 - graph_loss: 0.3396
Epoch 93/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0807 - accuracy: 0.9879 - graph_loss: 0.3473
Epoch 94/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0820 - accuracy: 0.9861 - graph_loss: 0.3367
Epoch 95/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0864 - accuracy: 0.9838 - graph_loss: 0.3353
Epoch 96/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0786 - accuracy: 0.9889 - graph_loss: 0.3392
Epoch 97/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0735 - accuracy: 0.9912 - graph_loss: 0.3443
Epoch 98/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0861 - accuracy: 0.9842 - graph_loss: 0.3381
Epoch 99/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0850 - accuracy: 0.9833 - graph_loss: 0.3376
Epoch 100/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0841 - accuracy: 0.9879 - graph_loss: 0.3510

<tensorflow.python.keras.callbacks.History at 0x7fc3d853ce10>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 6ms/step - loss: 1.2475 - accuracy: 0.8192


Eval accuracy for  MLP + graph regularization :  0.8191681504249573
Eval loss for  MLP + graph regularization :  1.2474583387374878

In this run, the graph-regularized model's accuracy (about 0.819) is roughly 4.5 percentage points higher than that of the base model, base_model (about 0.774).
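The size of the improvement can be checked directly from the two evaluation outputs printed above. The short sketch below copies those two accuracy values by hand (the variable names are illustrative) and reports both the absolute gain in percentage points and the gain relative to the base model.

```python
# Accuracy values copied from the evaluation outputs shown above.
base_accuracy = 0.7739602327346802
graph_reg_accuracy = 0.8191681504249573

# Absolute gain: difference in accuracy, expressed in percentage points.
absolute_gain = graph_reg_accuracy - base_accuracy
# Relative gain: improvement as a fraction of the base model's accuracy.
relative_gain = absolute_gain / base_accuracy

print(f'Absolute gain: {absolute_gain * 100:.2f} percentage points')
print(f'Relative gain: {relative_gain * 100:.2f}%')
```

These figures come from a single training run, so they will likely vary somewhat with random initialization and the train/test split.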

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.

We encourage users to experiment further by varying the amount of supervision and by trying different neural architectures for graph regularization.