![]() | ![]() | ![]() |
סקירה כללית
רגולציה של גרפים היא טכניקה ספציפית תחת הפרדיגמה הרחבה יותר של למידת גרפים עצביים ( Bui et al., 2018 ). הרעיון המרכזי הוא להכשיר מודלים של רשת עצבית עם מטרה מוסדרת לגרף, הרותמת נתונים מסומנים ובלתי מתויגים.
במדריך זה נחקור את השימוש בהסדרת גרפים לסיווג מסמכים המהווים גרף טבעי (אורגני).
המתכון הכללי ליצירת מודל מוסדר גרפי תוך שימוש במסגרת הלמידה המובנית העצבית (NSL) הוא כדלקמן:
- צור נתוני אימון מגרף הקלט ותכונות הדוגמה. צמתים בגרף תואמים לדוגמאות וקצוות בתרשים תואמים לדמיון בין זוגות דגימות. נתוני האימון המתקבלים יכילו תכונות שכנות בנוסף לתכונות הצומת המקוריות.
- צור רשת עצבית כמודל בסיס באמצעות ממשק ה- API הרציף, הפונקציונלי או המשנה של
Keras
. - עטוף את מודל הבסיס
GraphRegularization
העטיפהGraphRegularization
, המסופקת על ידי מסגרת NSL, כדי ליצור מודלKeras
גרפי חדש. מודל חדש זה יכלול הפסד רגולציה של גרף כמונח הסדרה במטרת האימון שלו. - התאמן והעריך את מודל
Keras
.
להכין
התקן את חבילת הלמידה המובנית העצבית.
pip install --quiet neural-structured-learning
תלות ויבוא
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.2.0 Eager mode: True GPU is NOT AVAILABLE
מערך Cora
מערך הנתונים של Cora הוא גרף ציטוטים בו צמתים מייצגים ניירות למידת מכונה וקצוות מייצגים ציטוטים בין זוגות ניירות. המשימה הכרוכה בה היא סיווג מסמכים כאשר המטרה היא לסווג כל מאמר לאחת משבע קטגוריות. במילים אחרות, זו בעיית סיווג רב-מעמדית עם 7 כיתות.
גרָף
הגרף המקורי מכוון. עם זאת, לצורך דוגמה זו אנו רואים את הגרסה הלא מכוונת של גרף זה. לכן, אם נייר A מצטט את הנייר B, אנו מחשיבים גם את הנייר B שציטט את A. למרות שזה לא בהכרח נכון, בדוגמה זו אנו רואים ציטוטים כמשל לדמיון, שהוא בדרך כלל תכונה קומוטטיבית.
תכונות
כל נייר בקלט מכיל למעשה 2 תכונות:
מילים : ייצוג שקית- מילים צפוף ורב-חם של הטקסט בעיתון. אוצר המילים של מערך הנתונים של קורה מכיל 1433 מילים ייחודיות. אז, אורך תכונה זו הוא 1433, והערך במיקום 'i' הוא 0/1 המציין אם המילה 'i' באוצר המילים קיימת בעיתון הנתון או לא.
תווית : מספר שלם יחיד המייצג את מזהה הכיתה (קטגוריה) של הנייר.
הורד את מערך הנתונים של קורה
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
המר את נתוני הקורה לפורמט NSL
על מנת לעבד את מערך הנתונים של קורה ולהמיר אותו לפורמט הנדרש על ידי למידה מובנית עצבית, נפעיל את סקריפט 'preprocess_cora_dataset.py' , הכלול במאגר github של NSL. סקריפט זה מבצע את הפעולות הבאות:
- צור תכונות שכנות באמצעות תכונות הצומת המקוריות והגרף.
- צור פיצולי נתוני רכבת ובדיקה המכילים
tf.train.Example
.tf.train.Example
לדוגמא. - המשך בנתוני הרכבת וכתוצאה מכך בפורמט
TFRecord
.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2020-07-01 11:15:33-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2020-07-01 11:15:33 (84.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.06 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.38 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.04 minutes.
משתנים גלובליים
נתיבי הקבצים לרכבת ולנתוני הבדיקה מבוססים על ערכי דגל שורת הפקודה המשמשים להפעלת התסריט 'preprocess_cora_dataset.py' לעיל.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
היפרפרמטרים
אנו נשתמש במופע של HParams
כדי לכלול hyperparameters וקבועים שונים המשמשים לאימון והערכה. אנו מתארים בקצרה כל אחד מהם להלן:
מספר_קלאסים : ישנם בסך הכל 7 כיתות שונות
max_seq_length : זהו גודל אוצר המילים ולכל המקרים בקלט יש ייצוג רב-חם ושקית-מילים צפוף. במילים אחרות, ערך 1 למילה מציין שהמילה קיימת בקלט וערך 0 מציין שהיא לא.
distance_type : זהו מדד המרחק המשמש להסדרת המדגם עם שכניו.
graph_regularization_multiplier : זה שולט במשקל היחסי של מונח ויסות הגרף בפונקציית האובדן הכללי.
מספר_שכנים : מספר השכנים המשמשים להסדרת גרפים. ערך זה צריך להיות קטן או שווה
max_nbrs
שורת הפקודהmax_nbrs
המשמש לעיל בעת הפעלתpreprocess_cora_dataset.py
.num_fc_units : מספר השכבות המחוברות לחלוטין ברשת העצבית שלנו.
train_epochs : מספר תקופות האימון.
batch_size : גודל האצווה המשמש לאימון והערכה.
dropout_rate : שולט בקצב הנשירה בעקבות כל שכבה מחוברת לחלוטין
eval_steps : מספר האצוותים לעיבוד לפני השלמת ההערכה. אם מוגדר כ-
None
, כל המופעים בערכת הבדיקה מוערכים.
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
טען נתוני רכבת ומבחנים
כמתואר קודם במחברת זו, הכשרת הקלט ונתוני הבדיקה נוצרו על ידי 'preprocess_cora_dataset.py' . נטען אותם לשני אובייקטיםtf.data.Dataset
- אחד לרכבת ואחד לבדיקה.
בשכבת הקלט של המודל שלנו, נשלוף לא רק את התכונות 'מילים' hparams.num_neighbors
'מכל מדגם, אלא גם תכונות שכנות תואמות בהתבסס על הערך hparams.num_neighbors
. hparams.num_neighbors
עם פחות שכנים מ- hparams.num_neighbors
יוקצו ערכי דמה עבור אותן תכונות שכנות שאינן קיימות.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
בואו נציץ במערך הרכבות בכדי לבדוק את תוכנו.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [4 3 1 2 1 6 2 5 6 2 2 6 5 0 2 2 1 6 2 2 2 2 5 4 2 0 2 1 1 2 0 5 2 2 2 0 2 2 0 6 1 1 0 2 1 2 3 2 0 0 0 4 1 3 3 1 2 5 3 3 1 1 6 0 0 4 6 5 6 0 3 4 2 2 2 3 3 2 4 0 2 3 2 2 3 1 2 2 1 0 6 1 2 1 6 2 1 0 4 3 2 5 2 3 1 0 3 4 3 4 1 0 5 6 4 2 1 1 2 5 3 4 3 1 3 2 6 3], shape=(128,), dtype=int64)
בואו נציץ במערך הבדיקה כדי לבדוק את תוכנו.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
הגדרת מודל
על מנת להדגים את השימוש בסדירות גרפים, אנו בונים תחילה מודל בסיס לבעיה זו. נשתמש ברשת עצבית פשוטה להזנה עם 2 שכבות נסתרות ונשירה בין לבין. אנו ממחישים את יצירת מודל הבסיס תוך שימוש בכל סוגי הדגמים הנתמכים על ידי מסגרת tf.Keras
- רצף, פונקציונלי tf.Keras
-מחלקה.
מודל בסיס רציף
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
return model
מודל בסיס פונקציונלי
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(
hparams.num_classes, activation='softmax')(
cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
דגם בסיס מחלקה משנה
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(
hparams.num_classes, activation='softmax')
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
צור מודלים בסיסיים
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 _________________________________________________________________ lambda (Lambda) (None, 1433) 0 _________________________________________________________________ dense (Dense) (None, 50) 71700 _________________________________________________________________ dropout (Dropout) (None, 50) 0 _________________________________________________________________ dense_1 (Dense) (None, 50) 2550 _________________________________________________________________ dropout_1 (Dropout) (None, 50) 0 _________________________________________________________________ dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
דגם MLP בסיס רכבת
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 17/17 [==============================] - 0s 11ms/step - loss: 1.9256 - accuracy: 0.1870 Epoch 2/100 17/17 [==============================] - 0s 10ms/step - loss: 1.8410 - accuracy: 0.2835 Epoch 3/100 17/17 [==============================] - 0s 9ms/step - loss: 1.7479 - accuracy: 0.3374 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6384 - accuracy: 0.3884 Epoch 5/100 17/17 [==============================] - 0s 9ms/step - loss: 1.5086 - accuracy: 0.4390 Epoch 6/100 17/17 [==============================] - 0s 10ms/step - loss: 1.3606 - accuracy: 0.5016 Epoch 7/100 17/17 [==============================] - 0s 9ms/step - loss: 1.2165 - accuracy: 0.5791 Epoch 8/100 17/17 [==============================] - 0s 10ms/step - loss: 1.0783 - accuracy: 0.6311 Epoch 9/100 17/17 [==============================] - 0s 9ms/step - loss: 0.9552 - accuracy: 0.6947 Epoch 10/100 17/17 [==============================] - 0s 9ms/step - loss: 0.8680 - accuracy: 0.7090 Epoch 11/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7915 - accuracy: 0.7425 Epoch 12/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7124 - accuracy: 0.7773 Epoch 13/100 17/17 [==============================] - 0s 9ms/step - loss: 0.6582 - accuracy: 0.7907 Epoch 14/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6021 - accuracy: 0.8065 Epoch 15/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5416 - accuracy: 0.8325 Epoch 16/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5042 - accuracy: 0.8473 Epoch 17/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4433 - accuracy: 0.8761 Epoch 18/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4310 - accuracy: 0.8640 Epoch 19/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3894 - accuracy: 0.8840 Epoch 20/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3676 - accuracy: 0.8891 Epoch 21/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3576 - accuracy: 0.8812 Epoch 22/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3132 - accuracy: 0.9067 Epoch 23/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3058 - accuracy: 0.9142 Epoch 24/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2924 - accuracy: 0.9155 Epoch 25/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2769 - accuracy: 0.9197 Epoch 26/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2636 - accuracy: 0.9244 Epoch 27/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2429 - accuracy: 0.9313 Epoch 28/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2324 - accuracy: 0.9323 Epoch 29/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2285 - accuracy: 0.9346 Epoch 30/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2039 - accuracy: 0.9374 Epoch 31/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1943 - accuracy: 0.9471 Epoch 32/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1898 - accuracy: 0.9439 Epoch 33/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1879 - accuracy: 0.9425 Epoch 34/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1828 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1635 - accuracy: 0.9541 Epoch 36/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1648 - accuracy: 0.9476 Epoch 37/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1603 - accuracy: 0.9499 Epoch 38/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1428 - accuracy: 0.9624 Epoch 39/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1483 - accuracy: 0.9601 Epoch 40/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1352 - accuracy: 0.9582 Epoch 41/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1379 - accuracy: 0.9555 Epoch 42/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1410 - accuracy: 0.9582 Epoch 43/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1198 - accuracy: 0.9684 Epoch 44/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1148 - accuracy: 0.9731 Epoch 45/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1228 - accuracy: 0.9657 Epoch 46/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1135 - accuracy: 0.9703 Epoch 47/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1134 - accuracy: 0.9661 Epoch 48/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1175 - accuracy: 0.9619 Epoch 49/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1002 - accuracy: 0.9703 Epoch 50/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1143 - accuracy: 0.9671 Epoch 51/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0923 - accuracy: 0.9777 Epoch 52/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1068 - accuracy: 0.9731 Epoch 53/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0972 - accuracy: 0.9712 Epoch 54/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0828 - accuracy: 0.9796 Epoch 55/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1036 - accuracy: 0.9703 Epoch 56/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0954 - accuracy: 0.9745 Epoch 57/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0883 - accuracy: 0.9768 Epoch 58/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0859 - accuracy: 0.9777 Epoch 59/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0856 - accuracy: 0.9759 Epoch 60/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0858 - accuracy: 0.9754 Epoch 61/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9726 Epoch 62/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0840 - accuracy: 0.9763 Epoch 63/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0770 - accuracy: 0.9805 Epoch 64/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0823 - accuracy: 0.9745 Epoch 65/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0665 - accuracy: 0.9828 Epoch 66/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0788 - accuracy: 0.9777 Epoch 67/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0690 - accuracy: 0.9800 Epoch 68/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0683 - accuracy: 0.9805 Epoch 69/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0615 - accuracy: 0.9838 Epoch 70/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0618 - accuracy: 0.9833 Epoch 71/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0659 - accuracy: 0.9810 Epoch 72/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0704 - accuracy: 0.9800 Epoch 73/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0645 - accuracy: 0.9814 Epoch 74/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0645 - accuracy: 0.9791 Epoch 75/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9791 Epoch 76/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0648 - accuracy: 0.9814 Epoch 77/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0591 - accuracy: 0.9838 Epoch 78/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0606 - accuracy: 0.9861 Epoch 79/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0699 - accuracy: 0.9814 Epoch 80/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0603 - accuracy: 0.9828 Epoch 81/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0629 - accuracy: 0.9828 Epoch 82/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0596 - accuracy: 0.9828 Epoch 83/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0542 - accuracy: 0.9828 Epoch 84/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0452 - accuracy: 0.9893 Epoch 85/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0551 - accuracy: 0.9838 Epoch 86/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0555 - accuracy: 0.9842 Epoch 87/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0514 - accuracy: 0.9824 Epoch 88/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0553 - accuracy: 0.9847 Epoch 89/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9884 Epoch 90/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0476 - accuracy: 0.9893 Epoch 91/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0427 - accuracy: 0.9903 Epoch 92/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0475 - accuracy: 0.9847 Epoch 93/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0423 - accuracy: 0.9893 Epoch 94/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0473 - accuracy: 0.9865 Epoch 95/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0560 - accuracy: 0.9819 Epoch 96/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0547 - accuracy: 0.9810 Epoch 97/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0576 - accuracy: 0.9814 Epoch 98/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0429 - accuracy: 0.9893 Epoch 99/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0440 - accuracy: 0.9875 Epoch 100/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0513 - accuracy: 0.9838 <tensorflow.python.keras.callbacks.History at 0x7fc47a3c78d0>
הערך את מודל ה- MLP הבסיסי
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.3380 - accuracy: 0.7740 Eval accuracy for Base MLP model : 0.7739602327346802 Eval loss for Base MLP model : 1.3379606008529663
הרכבת מודל MLP עם ויסות גרף
שילוב tf.Keras.Model
של tf.Keras.Model
לטווח ההפסד של tf.Keras.Model
קיים רק מספר שורות קוד. המודל הבסיסי עטוף ליצירת מודל משנה חדש מסוג tf.Keras
, שההפסד שלו כולל tf.Keras
גרפים.
כדי להעריך את התועלת המצטברת של רגולציה של גרפים, ניצור מופע מודל בסיס חדש. הסיבה לכך היא ש- base_model
כבר הוכשרה למספר איטרציות, ושימוש חוזר במודל מאומן זה ליצירת מודל base_model
גרף לא יהווה השוואה הוגנת עבור base_model
.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory. "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " 17/17 [==============================] - 0s 10ms/step - loss: 1.9454 - accuracy: 0.1652 - graph_loss: 0.0076 Epoch 2/100 17/17 [==============================] - 0s 10ms/step - loss: 1.8517 - accuracy: 0.2956 - graph_loss: 0.0117 Epoch 3/100 17/17 [==============================] - 0s 10ms/step - loss: 1.7589 - accuracy: 0.3151 - graph_loss: 0.0261 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6714 - accuracy: 0.3392 - graph_loss: 0.0476 Epoch 5/100 17/17 [==============================] - 0s 9ms/step - loss: 1.5607 - accuracy: 0.4037 - graph_loss: 0.0622 Epoch 6/100 17/17 [==============================] - 0s 10ms/step - loss: 1.4486 - accuracy: 0.4807 - graph_loss: 0.0921 Epoch 7/100 17/17 [==============================] - 0s 10ms/step - loss: 1.3135 - accuracy: 0.5383 - graph_loss: 0.1236 Epoch 8/100 17/17 [==============================] - 0s 10ms/step - loss: 1.1902 - accuracy: 0.5912 - graph_loss: 0.1616 Epoch 9/100 17/17 [==============================] - 0s 10ms/step - loss: 1.0647 - accuracy: 0.6575 - graph_loss: 0.1920 Epoch 10/100 17/17 [==============================] - 0s 9ms/step - loss: 0.9416 - accuracy: 0.7067 - graph_loss: 0.2181 Epoch 11/100 17/17 [==============================] - 0s 10ms/step - loss: 0.8601 - accuracy: 0.7378 - graph_loss: 0.2470 Epoch 12/100 17/17 [==============================] - 0s 9ms/step - loss: 0.7968 - accuracy: 0.7462 - graph_loss: 0.2565 Epoch 13/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6881 - accuracy: 0.7912 - graph_loss: 0.2681 Epoch 14/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6548 - accuracy: 0.8139 - graph_loss: 0.2941 Epoch 15/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5874 - accuracy: 0.8376 - graph_loss: 0.3010 Epoch 16/100 17/17 [==============================] - 0s 9ms/step - loss: 0.5537 - accuracy: 0.8348 - graph_loss: 0.3014 Epoch 17/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5123 - accuracy: 0.8529 - graph_loss: 0.3097 Epoch 18/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4771 - accuracy: 0.8640 - graph_loss: 0.3192 Epoch 19/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4294 - accuracy: 0.8826 - graph_loss: 0.3182 Epoch 20/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4109 - accuracy: 0.8854 - graph_loss: 0.3169 Epoch 21/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3901 - accuracy: 0.8965 - graph_loss: 0.3250 Epoch 22/100 17/17 [==============================] - 0s 9ms/step - loss: 0.3700 - accuracy: 0.8956 - graph_loss: 0.3349 Epoch 23/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3716 - accuracy: 0.8974 - graph_loss: 0.3408 Epoch 24/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3258 - accuracy: 0.9202 - graph_loss: 0.3361 Epoch 25/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3043 - accuracy: 0.9253 - graph_loss: 0.3351 Epoch 26/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2919 - accuracy: 0.9253 - graph_loss: 0.3361 Epoch 27/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3005 - accuracy: 0.9202 - graph_loss: 0.3249 Epoch 28/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2629 - accuracy: 0.9336 - graph_loss: 0.3442 Epoch 29/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2617 - accuracy: 0.9401 - graph_loss: 0.3302 Epoch 30/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2510 - accuracy: 0.9383 - graph_loss: 0.3436 Epoch 31/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2452 - accuracy: 0.9411 - graph_loss: 0.3364 Epoch 32/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2397 - accuracy: 0.9466 - graph_loss: 0.3333 Epoch 33/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2239 - accuracy: 0.9466 - graph_loss: 0.3373 Epoch 34/100 17/17 [==============================] - 0s 9ms/step - loss: 0.2084 - accuracy: 0.9513 - graph_loss: 0.3330 Epoch 35/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2075 - accuracy: 0.9499 - graph_loss: 0.3383 Epoch 36/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2064 - accuracy: 0.9513 - graph_loss: 0.3394 Epoch 37/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1857 - accuracy: 0.9568 - graph_loss: 0.3371 Epoch 38/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1799 - accuracy: 0.9601 - graph_loss: 0.3477 Epoch 39/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1844 - accuracy: 0.9573 - graph_loss: 0.3385 Epoch 40/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1823 - accuracy: 0.9592 - graph_loss: 0.3445 Epoch 41/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1713 - accuracy: 0.9615 - graph_loss: 0.3451 Epoch 42/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9624 - graph_loss: 0.3398 Epoch 43/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1692 - accuracy: 0.9671 - graph_loss: 0.3483 Epoch 44/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9647 - graph_loss: 0.3437 Epoch 45/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1485 - accuracy: 0.9703 - graph_loss: 0.3338 Epoch 46/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1467 - accuracy: 0.9717 - graph_loss: 0.3405 Epoch 47/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1492 - accuracy: 0.9694 - graph_loss: 0.3466 Epoch 48/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1577 - accuracy: 0.9666 - graph_loss: 0.3338 Epoch 49/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1363 - accuracy: 0.9773 - graph_loss: 0.3424 Epoch 50/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1511 - accuracy: 0.9694 - graph_loss: 0.3402 Epoch 51/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1366 - accuracy: 0.9759 - graph_loss: 0.3385 Epoch 52/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1254 - accuracy: 0.9777 - graph_loss: 0.3474 Epoch 53/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1289 - accuracy: 0.9740 - graph_loss: 0.3469 Epoch 54/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1410 - accuracy: 0.9689 - graph_loss: 0.3475 Epoch 55/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1356 - accuracy: 0.9703 - graph_loss: 0.3483 Epoch 56/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1283 - accuracy: 0.9773 - graph_loss: 0.3412 Epoch 57/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1264 - accuracy: 0.9745 - graph_loss: 0.3473 Epoch 58/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1242 - accuracy: 0.9740 - graph_loss: 0.3443 Epoch 59/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1144 - accuracy: 0.9782 - graph_loss: 0.3440 Epoch 60/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1250 - accuracy: 0.9735 - graph_loss: 0.3357 Epoch 61/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1190 - accuracy: 0.9787 - graph_loss: 0.3400 Epoch 62/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1141 - accuracy: 0.9814 - graph_loss: 0.3419 Epoch 63/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1085 - accuracy: 0.9787 - graph_loss: 0.3395 Epoch 64/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1148 - accuracy: 0.9768 - graph_loss: 0.3504 Epoch 65/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1137 - accuracy: 0.9791 - graph_loss: 0.3360 Epoch 66/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1121 - accuracy: 0.9745 - graph_loss: 0.3469 Epoch 67/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1046 - accuracy: 0.9810 - graph_loss: 0.3476 Epoch 68/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1112 - accuracy: 0.9791 - graph_loss: 0.3431 Epoch 69/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1075 - accuracy: 0.9787 - graph_loss: 0.3455 Epoch 70/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0986 - accuracy: 0.9875 - graph_loss: 0.3403 Epoch 71/100 17/17 [==============================] - 0s 9ms/step - loss: 0.1141 - accuracy: 0.9782 - graph_loss: 0.3508 Epoch 72/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1012 - accuracy: 0.9814 - graph_loss: 0.3453 Epoch 73/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9833 - graph_loss: 0.3430 Epoch 74/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9842 - graph_loss: 0.3447 Epoch 75/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0988 - accuracy: 0.9842 - graph_loss: 0.3430 Epoch 76/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0915 - accuracy: 0.9856 - graph_loss: 0.3475 Epoch 77/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0960 - accuracy: 0.9833 - graph_loss: 0.3353 Epoch 78/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0916 - accuracy: 0.9838 - graph_loss: 0.3441 Epoch 79/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0979 - accuracy: 0.9800 - graph_loss: 0.3476 Epoch 80/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9782 - graph_loss: 0.3400 Epoch 81/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9838 - graph_loss: 0.3386 Epoch 82/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0994 - accuracy: 0.9805 - graph_loss: 0.3416 Epoch 83/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0957 - accuracy: 0.9838 - graph_loss: 0.3398 Epoch 84/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9879 - graph_loss: 0.3379 Epoch 85/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9838 - graph_loss: 0.3441 Epoch 86/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0906 - accuracy: 0.9847 - graph_loss: 0.3445 Epoch 87/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0891 - accuracy: 0.9852 - graph_loss: 0.3506 Epoch 88/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0821 - accuracy: 0.9898 - graph_loss: 0.3448 Epoch 89/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0803 - accuracy: 0.9865 - graph_loss: 0.3370 Epoch 90/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0896 - accuracy: 0.9828 - graph_loss: 0.3428 Epoch 91/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0887 - accuracy: 0.9852 - graph_loss: 0.3505 Epoch 92/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0882 - accuracy: 0.9847 - graph_loss: 0.3396 Epoch 93/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0807 - accuracy: 0.9879 - graph_loss: 0.3473 Epoch 94/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0820 - accuracy: 0.9861 - graph_loss: 0.3367 Epoch 95/100 17/17 [==============================] - 0s 9ms/step - loss: 0.0864 - accuracy: 0.9838 - graph_loss: 0.3353 Epoch 96/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0786 - accuracy: 0.9889 - graph_loss: 0.3392 Epoch 97/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0735 - accuracy: 0.9912 - graph_loss: 0.3443 Epoch 98/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0861 - accuracy: 0.9842 - graph_loss: 0.3381 Epoch 99/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0850 - accuracy: 0.9833 - graph_loss: 0.3376 Epoch 100/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0841 - accuracy: 0.9879 - graph_loss: 0.3510 <tensorflow.python.keras.callbacks.History at 0x7fc3d853ce10>
הערך את מודל ה- MLP באמצעות ויסות גרפי
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 6ms/step - loss: 1.2475 - accuracy: 0.8192 Eval accuracy for MLP + graph regularization : 0.8191681504249573 Eval loss for MLP + graph regularization : 1.2474583387374878
דיוק המודל base_model
גבוה בכ- 2-3% מזה של מודל הבסיס ( base_model
).
סיכום
הדגמנו את השימוש בהסדרת גרפים לסיווג מסמכים בגרף ציטוטים טבעי (Cora) באמצעות מסגרת למידה מובנית עצבית (NSL). ההדרכה המתקדמת שלנו כוללת סינתזת גרפים על סמך הטמאות לדוגמה לפני אימון רשת עצבית עם ויסות גרפים. גישה זו שימושית אם הקלט אינו מכיל גרף מפורש.
אנו ממליצים למשתמשים להתנסות בהמשך על ידי שינוי בכמות הפיקוח וכן לנסות ארכיטקטורות עצביות שונות להסדרת גרפים.