Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the Keras sequential, functional, or subclass API.
- Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph Keras model.
A condensed sketch of these steps is shown right after this list; the rest of the tutorial builds each of them on a real dataset.
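The following preview is a minimal sketch, assuming a graph-augmented training dataset (the hypothetical train_dataset below) already exists; everything here is covered in detail in the remainder of this tutorial.
import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: any Keras model can serve as the base model.
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(7, activation='softmax'),
])

# Step 3: wrap the base model so that a graph regularization loss is added
# to its training objective.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile, train, and evaluate like any other Keras model.
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
# graph_reg_model.fit(train_dataset, epochs=5)
# graph_reg_model.evaluate(test_dataset)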
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.2.0 Eager mode: True GPU is NOT AVAILABLE
The Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example, we consider citations as a proxy for similarity, which is usually a commutative property.
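To make this concrete, the following sketch (our own addition, not part of the original recipe) loads the raw edge list and adds the reverse of every edge. It assumes the cora.cites file, downloaded in a later section, where each line holds a tab-separated pair of paper IDs; the preprocessing script used later performs this symmetrization for us.
# Build an undirected version of the citation graph: for every edge (A, B),
# also add (B, A).
edges = set()
with open('/tmp/cora/cora.cites') as f:
  for line in f:
    source, target = line.split()
    edges.add((source, target))
    edges.add((target, source))  # the reverse edge makes the graph undirected
print('Number of directed edges after symmetrization:', len(edges))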
Features
Each paper in the input effectively contains 2 features:
Words: A dense multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not. (A toy illustration of this encoding follows this list.)
Label: A single integer representing the class ID (category) of the paper.
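To make the multi-hot encoding concrete, here is a toy illustration with a hypothetical 5-word vocabulary instead of Cora's 1433-word one:
# Hypothetical vocabulary of 5 words; Cora's real vocabulary has 1433.
vocabulary = ['graph', 'neural', 'network', 'bayes', 'kernel']
# A paper containing the words 'neural' and 'network' is encoded as a 0/1
# vector whose length equals the vocabulary size.
paper_words = {'neural', 'network'}
multi_hot = [1 if word in paper_words else 0 for word in vocabulary]
print(multi_hot)  # [0, 1, 1, 0, 0]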
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing tf.train.Example instances.
- Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2020-07-01 11:15:33-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2020-07-01 11:15:33 (84.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.06 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.38 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.04 minutes.
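Before moving on, it can be instructive to peek at one merged tf.train.Example and see the neighbor features the script produced. This inspection step is our own addition and assumes the output path passed via --output_train_data above.
# Inspect the feature names of one merged training example. Expected keys
# include 'words', 'label', and per-neighbor features such as
# 'NL_nbr_0_words' and 'NL_nbr_0_weight' (for up to --max_nbrs neighbors).
raw_dataset = tf.data.TFRecordDataset(['/tmp/cora/train_merged_examples.tfr'])
for raw_record in raw_dataset.take(1):
  example = tf.train.Example.FromString(raw_record.numpy())
  print(sorted(example.features.feature.keys()))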
Global variables
The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
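As a quick illustration (our own addition), these constants combine with a neighbor index to form the feature names seen in the merged examples:
# Neighbor i's word features and edge weight are stored under keys built from
# the constants above, e.g. for neighbor 0:
print(NBR_FEATURE_PREFIX + '0' + '_words')           # NL_nbr_0_words
print(NBR_FEATURE_PREFIX + '0' + NBR_WEIGHT_SUFFIX)  # NL_nbr_0_weight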
Hyperparameters
We will use an instance of HParams to include various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
- num_classes: There are a total of 7 different classes.
- max_seq_length: This is the size of the vocabulary, and all instances in the input have a dense, multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.
- distance_type: This is the distance metric used to regularize the sample with its neighbors.
- graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.
- num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.
- num_fc_units: The number of units in each fully connected (hidden) layer of our neural network.
- train_epochs: The number of training epochs.
- batch_size: Batch size used for training and evaluation.
- dropout_rate: Controls the rate of dropout following each fully connected layer.
- eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.
class HParams(object):
  """Hyperparameters used for training."""

  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.
HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects -- one for training and one for testing.
In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)
    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [4 3 1 2 1 6 2 5 6 2 2 6 5 0 2 2 1 6 2 2 2 2 5 4 2 0 2 1 1 2 0 5 2 2 2 0 2 2 0 6 1 1 0 2 1 2 3 2 0 0 0 4 1 3 3 1 2 5 3 3 1 1 6 0 0 4 6 5 6 0 3 4 2 2 2 3 3 2 4 0 2 3 2 2 3 1 2 2 1 0 6 1 2 1 6 2 1 0 4 3 2 5 2 3 1 0 3 4 3 4 1 0 5 6 4 2 1 1 2 5 3 4 3 1 3 2 6 3], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we build a base model for this problem first. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework -- sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model
Functional base model
def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model
Subclass base model
def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)
      return outputs

  return MLP()
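As a side note, and as mentioned in the comments of the next cell, a subclass model must be built before its summary can be generated. A hedged sketch (our own addition, not needed for the rest of the tutorial) of doing so by invoking the model on a dummy batch:
# Build the subclass model by calling it once on a dummy input batch; only
# then can summary() be generated.
subclass_model = make_mlp_subclass_model(HPARAMS)
subclass_model({'words': tf.zeros([1, HPARAMS.max_seq_length], tf.int64)})
subclass_model.summary()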
Create base model(s)
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 _________________________________________________________________ lambda (Lambda) (None, 1433) 0 _________________________________________________________________ dense (Dense) (None, 50) 71700 _________________________________________________________________ dropout (Dropout) (None, 50) 0 _________________________________________________________________ dense_1 (Dense) (None, 50) 2550 _________________________________________________________________ dropout_1 (Dropout) (None, 50) 0 _________________________________________________________________ dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
Train base MLP model
# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 17/17 [==============================] - 0s 11ms/step - loss: 1.9256 - accuracy: 0.1870 Epoch 2/100 17/17 [==============================] - 0s 10ms/step - loss: 1.8410 - accuracy: 0.2835 Epoch 3/100 17/17 [==============================] - 0s 9ms/step - loss: 1.7479 - accuracy: 0.3374 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6384 - accuracy: 0.3884 Epoch 5/100 17/17 [==============================] - 0s 9ms/step - loss: 1.5086 - accuracy: 0.4390 Epoch 6/100 17/17 [==============================] - 0s 10ms/step - loss: 1.3606 - accuracy: 0.5016 Epoch 7/100 17/17 [==============================] - 0s 9ms/step - loss: 1.2165 - accuracy: 0.5791 Epoch 8/100 17/17 [==============================] - 0s 10ms/step - loss: 1.0783 - accuracy: 0.6311 Epoch 9/100 17/17 [==============================] - 0s 9ms/step - loss: 0.9552 - accuracy: 0.6947 Epoch 10/100 17/17 [==============================] - 0s 9ms/step - loss: 0.8680 - accuracy: 0.7090 Epoch 11/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7915 - accuracy: 0.7425 Epoch 12/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7124 - accuracy: 0.7773 Epoch 13/100 17/17 [==============================] - 0s 9ms/step - loss: 0.6582 - accuracy: 0.7907 Epoch 14/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6021 - accuracy: 0.8065 Epoch 15/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5416 - accuracy: 0.8325 Epoch 16/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5042 - accuracy: 0.8473 Epoch 17/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4433 - accuracy: 0.8761 Epoch 18/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4310 - accuracy: 0.8640 Epoch 19/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3894 - accuracy: 0.8840 Epoch 20/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3676 - accuracy: 0.8891 Epoch 21/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3576 - accuracy: 0.8812 Epoch 22/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3132 - accuracy: 0.9067 Epoch 23/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3058 - accuracy: 0.9142 Epoch 24/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2924 - accuracy: 0.9155 Epoch 25/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2769 - accuracy: 0.9197 Epoch 26/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2636 - accuracy: 0.9244 Epoch 27/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2429 - accuracy: 0.9313 Epoch 28/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2324 - accuracy: 0.9323 Epoch 29/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2285 - accuracy: 0.9346 Epoch 30/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2039 - accuracy: 0.9374 Epoch 31/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1943 - accuracy: 0.9471 Epoch 32/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1898 - accuracy: 0.9439 Epoch 33/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1879 - accuracy: 0.9425 Epoch 34/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1828 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1635 - accuracy: 0.9541 Epoch 36/100 17/17 [==============================] 
- 0s 9ms/step - loss: 0.1648 - accuracy: 0.9476 Epoch 37/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1603 - accuracy: 0.9499 Epoch 38/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1428 - accuracy: 0.9624 Epoch 39/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1483 - accuracy: 0.9601 Epoch 40/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1352 - accuracy: 0.9582 Epoch 41/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1379 - accuracy: 0.9555 Epoch 42/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1410 - accuracy: 0.9582 Epoch 43/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1198 - accuracy: 0.9684 Epoch 44/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1148 - accuracy: 0.9731 Epoch 45/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1228 - accuracy: 0.9657 Epoch 46/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1135 - accuracy: 0.9703 Epoch 47/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1134 - accuracy: 0.9661 Epoch 48/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1175 - accuracy: 0.9619 Epoch 49/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1002 - accuracy: 0.9703 Epoch 50/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1143 - accuracy: 0.9671 Epoch 51/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0923 - accuracy: 0.9777 Epoch 52/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1068 - accuracy: 0.9731 Epoch 53/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0972 - accuracy: 0.9712 Epoch 54/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0828 - accuracy: 0.9796 Epoch 55/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1036 - accuracy: 0.9703 Epoch 56/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0954 - accuracy: 0.9745 Epoch 57/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0883 - accuracy: 0.9768 Epoch 58/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0859 - accuracy: 0.9777 Epoch 59/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0856 - accuracy: 0.9759 Epoch 60/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0858 - accuracy: 0.9754 Epoch 61/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9726 Epoch 62/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0840 - accuracy: 0.9763 Epoch 63/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0770 - accuracy: 0.9805 Epoch 64/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0823 - accuracy: 0.9745 Epoch 65/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0665 - accuracy: 0.9828 Epoch 66/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0788 - accuracy: 0.9777 Epoch 67/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0690 - accuracy: 0.9800 Epoch 68/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0683 - accuracy: 0.9805 Epoch 69/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0615 - accuracy: 0.9838 Epoch 70/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0618 - accuracy: 0.9833 Epoch 71/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0659 - accuracy: 
0.9810 Epoch 72/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0704 - accuracy: 0.9800 Epoch 73/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0645 - accuracy: 0.9814 Epoch 74/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0645 - accuracy: 0.9791 Epoch 75/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9791 Epoch 76/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0648 - accuracy: 0.9814 Epoch 77/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0591 - accuracy: 0.9838 Epoch 78/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0606 - accuracy: 0.9861 Epoch 79/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0699 - accuracy: 0.9814 Epoch 80/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0603 - accuracy: 0.9828 Epoch 81/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0629 - accuracy: 0.9828 Epoch 82/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0596 - accuracy: 0.9828 Epoch 83/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0542 - accuracy: 0.9828 Epoch 84/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0452 - accuracy: 0.9893 Epoch 85/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0551 - accuracy: 0.9838 Epoch 86/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0555 - accuracy: 0.9842 Epoch 87/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0514 - accuracy: 0.9824 Epoch 88/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0553 - accuracy: 0.9847 Epoch 89/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9884 Epoch 90/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0476 - accuracy: 0.9893 Epoch 91/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0427 - accuracy: 0.9903 Epoch 92/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0475 - accuracy: 0.9847 Epoch 93/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0423 - accuracy: 0.9893 Epoch 94/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0473 - accuracy: 0.9865 Epoch 95/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0560 - accuracy: 0.9819 Epoch 96/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0547 - accuracy: 0.9810 Epoch 97/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0576 - accuracy: 0.9814 Epoch 98/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0429 - accuracy: 0.9893 Epoch 99/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0440 - accuracy: 0.9875 Epoch 100/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0513 - accuracy: 0.9838 <tensorflow.python.keras.callbacks.History at 0x7fc47a3c78d0>
Evaluate base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])

eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.3380 - accuracy: 0.7740 Eval accuracy for Base MLP model : 0.7739602327346802 Eval loss for Base MLP model : 1.3379606008529663
Train MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model will not be a fair comparison for base_model.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)

# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory. "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " 17/17 [==============================] - 0s 10ms/step - loss: 1.9454 - accuracy: 0.1652 - graph_loss: 0.0076 Epoch 2/100 17/17 [==============================] - 0s 10ms/step - loss: 1.8517 - accuracy: 0.2956 - graph_loss: 0.0117 Epoch 3/100 17/17 [==============================] - 0s 10ms/step - loss: 1.7589 - accuracy: 0.3151 - graph_loss: 0.0261 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6714 - accuracy: 0.3392 - graph_loss: 0.0476 Epoch 5/100 17/17 [==============================] - 0s 9ms/step - loss: 1.5607 - accuracy: 0.4037 - graph_loss: 0.0622 Epoch 6/100 17/17 [==============================] - 0s 10ms/step - loss: 1.4486 - accuracy: 0.4807 - graph_loss: 0.0921 Epoch 7/100 17/17 [==============================] - 0s 10ms/step - loss: 1.3135 - accuracy: 0.5383 - graph_loss: 0.1236 Epoch 8/100 17/17 [==============================] - 0s 10ms/step - loss: 1.1902 - accuracy: 0.5912 - graph_loss: 0.1616 Epoch 9/100 17/17 [==============================] - 0s 10ms/step - loss: 1.0647 - accuracy: 0.6575 - graph_loss: 0.1920 Epoch 10/100 17/17 [==============================] - 0s 9ms/step - loss: 0.9416 - accuracy: 0.7067 - graph_loss: 0.2181 Epoch 11/100 17/17 [==============================] - 0s 10ms/step - loss: 0.8601 - accuracy: 0.7378 - graph_loss: 0.2470 Epoch 12/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7968 - accuracy: 0.7462 - graph_loss: 0.2565 Epoch 13/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6881 - accuracy: 0.7912 - graph_loss: 0.2681 Epoch 14/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6548 - accuracy: 0.8139 - graph_loss: 0.2941 Epoch 15/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5874 - accuracy: 0.8376 - graph_loss: 0.3010 Epoch 16/100 17/17 [==============================] - 0s 9ms/step - loss: 0.5537 - accuracy: 0.8348 - graph_loss: 0.3014 Epoch 17/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5123 - accuracy: 0.8529 - graph_loss: 0.3097 Epoch 18/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4771 - accuracy: 0.8640 - graph_loss: 0.3192 Epoch 19/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4294 - accuracy: 0.8826 - graph_loss: 0.3182 Epoch 20/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4109 - accuracy: 0.8854 - graph_loss: 0.3169 Epoch 21/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3901 - accuracy: 0.8965 - graph_loss: 0.3250 Epoch 22/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3700 - accuracy: 0.8956 - graph_loss: 0.3349 Epoch 23/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3716 - accuracy: 0.8974 - graph_loss: 0.3408 Epoch 24/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3258 - accuracy: 0.9202 - graph_loss: 0.3361 Epoch 25/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3043 - accuracy: 0.9253 - graph_loss: 0.3351 Epoch 26/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2919 - accuracy: 0.9253 - graph_loss: 0.3361 Epoch 27/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3005 - accuracy: 0.9202 
- graph_loss: 0.3249 Epoch 28/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2629 - accuracy: 0.9336 - graph_loss: 0.3442 Epoch 29/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2617 - accuracy: 0.9401 - graph_loss: 0.3302 Epoch 30/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2510 - accuracy: 0.9383 - graph_loss: 0.3436 Epoch 31/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2452 - accuracy: 0.9411 - graph_loss: 0.3364 Epoch 32/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2397 - accuracy: 0.9466 - graph_loss: 0.3333 Epoch 33/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2239 - accuracy: 0.9466 - graph_loss: 0.3373 Epoch 34/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2084 - accuracy: 0.9513 - graph_loss: 0.3330 Epoch 35/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2075 - accuracy: 0.9499 - graph_loss: 0.3383 Epoch 36/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2064 - accuracy: 0.9513 - graph_loss: 0.3394 Epoch 37/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1857 - accuracy: 0.9568 - graph_loss: 0.3371 Epoch 38/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1799 - accuracy: 0.9601 - graph_loss: 0.3477 Epoch 39/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1844 - accuracy: 0.9573 - graph_loss: 0.3385 Epoch 40/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1823 - accuracy: 0.9592 - graph_loss: 0.3445 Epoch 41/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1713 - accuracy: 0.9615 - graph_loss: 0.3451 Epoch 42/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9624 - graph_loss: 0.3398 Epoch 43/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1692 - accuracy: 0.9671 - graph_loss: 0.3483 Epoch 44/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9647 - graph_loss: 0.3437 Epoch 45/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1485 - accuracy: 0.9703 - graph_loss: 0.3338 Epoch 46/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1467 - accuracy: 0.9717 - graph_loss: 0.3405 Epoch 47/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1492 - accuracy: 0.9694 - graph_loss: 0.3466 Epoch 48/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1577 - accuracy: 0.9666 - graph_loss: 0.3338 Epoch 49/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1363 - accuracy: 0.9773 - graph_loss: 0.3424 Epoch 50/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1511 - accuracy: 0.9694 - graph_loss: 0.3402 Epoch 51/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1366 - accuracy: 0.9759 - graph_loss: 0.3385 Epoch 52/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1254 - accuracy: 0.9777 - graph_loss: 0.3474 Epoch 53/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1289 - accuracy: 0.9740 - graph_loss: 0.3469 Epoch 54/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1410 - accuracy: 0.9689 - graph_loss: 0.3475 Epoch 55/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1356 - accuracy: 0.9703 - graph_loss: 0.3483 Epoch 56/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1283 - accuracy: 0.9773 - graph_loss: 0.3412 
Epoch 57/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1264 - accuracy: 0.9745 - graph_loss: 0.3473 Epoch 58/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1242 - accuracy: 0.9740 - graph_loss: 0.3443 Epoch 59/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1144 - accuracy: 0.9782 - graph_loss: 0.3440 Epoch 60/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1250 - accuracy: 0.9735 - graph_loss: 0.3357 Epoch 61/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1190 - accuracy: 0.9787 - graph_loss: 0.3400 Epoch 62/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1141 - accuracy: 0.9814 - graph_loss: 0.3419 Epoch 63/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1085 - accuracy: 0.9787 - graph_loss: 0.3395 Epoch 64/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1148 - accuracy: 0.9768 - graph_loss: 0.3504 Epoch 65/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1137 - accuracy: 0.9791 - graph_loss: 0.3360 Epoch 66/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1121 - accuracy: 0.9745 - graph_loss: 0.3469 Epoch 67/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1046 - accuracy: 0.9810 - graph_loss: 0.3476 Epoch 68/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1112 - accuracy: 0.9791 - graph_loss: 0.3431 Epoch 69/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1075 - accuracy: 0.9787 - graph_loss: 0.3455 Epoch 70/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0986 - accuracy: 0.9875 - graph_loss: 0.3403 Epoch 71/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1141 - accuracy: 0.9782 - graph_loss: 0.3508 Epoch 72/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1012 - accuracy: 0.9814 - graph_loss: 0.3453 Epoch 73/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9833 - graph_loss: 0.3430 Epoch 74/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9842 - graph_loss: 0.3447 Epoch 75/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0988 - accuracy: 0.9842 - graph_loss: 0.3430 Epoch 76/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0915 - accuracy: 0.9856 - graph_loss: 0.3475 Epoch 77/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0960 - accuracy: 0.9833 - graph_loss: 0.3353 Epoch 78/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0916 - accuracy: 0.9838 - graph_loss: 0.3441 Epoch 79/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0979 - accuracy: 0.9800 - graph_loss: 0.3476 Epoch 80/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9782 - graph_loss: 0.3400 Epoch 81/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9838 - graph_loss: 0.3386 Epoch 82/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9805 - graph_loss: 0.3416 Epoch 83/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0957 - accuracy: 0.9838 - graph_loss: 0.3398 Epoch 84/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9879 - graph_loss: 0.3379 Epoch 85/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9838 - graph_loss: 0.3441 Epoch 86/100 17/17 
[==============================] - 0s 10ms/step - loss: 0.0906 - accuracy: 0.9847 - graph_loss: 0.3445 Epoch 87/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9852 - graph_loss: 0.3506 Epoch 88/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0821 - accuracy: 0.9898 - graph_loss: 0.3448 Epoch 89/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0803 - accuracy: 0.9865 - graph_loss: 0.3370 Epoch 90/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9828 - graph_loss: 0.3428 Epoch 91/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3505 Epoch 92/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0882 - accuracy: 0.9847 - graph_loss: 0.3396 Epoch 93/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0807 - accuracy: 0.9879 - graph_loss: 0.3473 Epoch 94/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0820 - accuracy: 0.9861 - graph_loss: 0.3367 Epoch 95/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0864 - accuracy: 0.9838 - graph_loss: 0.3353 Epoch 96/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0786 - accuracy: 0.9889 - graph_loss: 0.3392 Epoch 97/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0735 - accuracy: 0.9912 - graph_loss: 0.3443 Epoch 98/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0861 - accuracy: 0.9842 - graph_loss: 0.3381 Epoch 99/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0850 - accuracy: 0.9833 - graph_loss: 0.3376 Epoch 100/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0841 - accuracy: 0.9879 - graph_loss: 0.3510 <tensorflow.python.keras.callbacks.History at 0x7fc3d853ce10>
Evaluate MLP model with graph regularization
eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 6ms/step - loss: 1.2475 - accuracy: 0.8192 Eval accuracy for MLP + graph regularization : 0.8191681504249573 Eval loss for MLP + graph regularization : 1.2474583387374878
The graph-regularized model's accuracy is about 2-3% higher than that of the base model (base_model).
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.
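As a minimal sketch of one such follow-up experiment (assuming the cells above have been run, so that make_dataset, make_mlp_functional_model, HPARAMS, and print_metrics are all defined), one could strengthen the graph regularization term and use more neighbors, then retrain:
# Hypothetical follow-up: a larger graph regularization weight and more
# neighbors per sample. num_neighbors must stay <= the --max_nbrs value (5)
# used when running preprocess_cora_dataset.py.
HPARAMS.graph_regularization_multiplier = 0.3
HPARAMS.num_neighbors = 2

# Rebuild the train dataset so parsing picks up the additional neighbor
# features implied by the new HPARAMS.num_neighbors value.
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)

new_base_model = make_mlp_functional_model(HPARAMS)
new_graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
new_graph_reg_model = nsl.keras.GraphRegularization(new_base_model,
                                                    new_graph_reg_config)
new_graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
new_graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=0)
eval_results = dict(
    zip(new_graph_reg_model.metrics_names,
        new_graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + stronger graph regularization', eval_results)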