ויסות גרפים לסיווג מסמכים באמצעות גרפים טבעיים

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

סקירה כללית

הסדרת הגרף היא טכניקה ספציפית תחת הפרדיגמה הרחבה של עצבי גרף למידה ( בוי et al., 2018 ). הרעיון המרכזי הוא להכשיר מודלים של רשתות עצביות עם מטרה מסודרת בגרף, תוך שימוש בנתונים מסומנים וגם ללא תווית.

במדריך זה, נחקור את השימוש בהסדרת גרפים כדי לסווג מסמכים היוצרים גרף טבעי (אורגני).

המתכון הכללי ליצירת מודל מוסדר גרף באמצעות המסגרת של למידה מובנית עצבית (NSL) הוא כדלקמן:

  1. צור נתוני אימון מגרף הקלט ותכונות לדוגמה. צמתים בגרף מתאימים לדגימות והקצוות בגרף מתאימים לדמיון בין זוגות דגימות. נתוני האימון שיתקבלו יכילו תכונות שכנות בנוסף לתכונות הצומת המקוריות.
  2. צור רשת עצבית כמודל בסיס באמצעות Keras הרציף, הפונקציונלי, או API תת.
  3. עוטף את דגם הבסיס עם GraphRegularization מעמד המעטפת, אשר מסופק על ידי מסגרת NSL, כדי ליצור גרף חדש Keras מודל. מודל חדש זה יכלול אובדן הסדרת גרף כמונח ההסדרה ביעד ההכשרה שלו.
  4. רכבת ולהעריך את הגרף Keras מודל.

להכין

התקן את חבילת הלמידה המובנית העצבית.

pip install --quiet neural-structured-learning

תלות ויבוא

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.8.0-rc0
Eager mode:  True
GPU is NOT AVAILABLE
2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

מערך נתונים של Cora

במערך קורה הוא גרף הציטוט שבו הצמתים מייצגים ניירות למידת מכונה וקצוות מייצגים ציטוטים בין זוגות של ניירות. המשימה הכרוכה היא סיווג מסמכים כאשר המטרה היא לסווג כל מאמר לאחת מ-7 קטגוריות. במילים אחרות, מדובר בבעיית סיווג רב מחלקות עם 7 מחלקות.

גרָף

הגרף המקורי מכוון. עם זאת, לצורך דוגמה זו, אנו רואים את הגרסה הבלתי מכוונת של גרף זה. לכן, אם מאמר א' מצטט את מאמר ב', אנו מחשיבים את המאמר ב' כמצוטט את א'. למרות שזה לא בהכרח נכון, בדוגמה זו, אנו רואים ציטוטים בתור פרוקסי לדמיון, שהוא בדרך כלל תכונה קומוטטיבית.

מאפיינים

כל נייר בקלט מכיל למעשה 2 תכונות:

  1. מילים: צפוף, רב-חם שקית-של-מילים ייצוג של הטקסט בעיתון. אוצר המילים עבור מערך הנתונים של Cora מכיל 1433 מילים ייחודיות. אז, אורכה של תכונה זו הוא 1433, והערך במיקום 'i' הוא 0/1 המציין אם המילה 'i' באוצר המילים קיימת במאמר הנתון או לא.

  2. לייבל: שלם בודד המייצג את הזהות בכיתה (קטגוריה) של נייר.

הורד את מערך הנתונים של Cora

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

המר את נתוני Cora לפורמט NSL

כדי preprocess במערך קורה ולהמיר אותו לפורמט הנדרש על ידי למידה Structured עצבית, נוכל להריץ את הסקריפט "preprocess_cora_dataset.py", אשר נכלל במאגר NSL GitHub. הסקריפט הזה עושה את הפעולות הבאות:

  1. צור תכונות שכנות באמצעות תכונות הצומת המקוריות והגרף.
  2. צור פיצולי נתוני הרכבת מבחן המכילים tf.train.Example מקרים.
  3. נמשך הרכבת שהתקבלה נתון בדיקה של TFRecord הפורמט.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.

משתנים גלובליים

שבילי קובץ נתון הרכבת ולבדוק מבוסס על ערכי סימון שורת פקוד המשמשים להפעיל את התסריט "preprocess_cora_dataset.py" לעיל.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

היפרפרמטרים

נשתמש מופע של HParams לכלול hyperparameters וקבועים שונים שימשו לאימונים והערכה. אנו מתארים בקצרה כל אחד מהם להלן:

  • num_classes: ישנם סוגים שונים 7 סה"כ

  • max_seq_length: זהו הגודל של אוצר המילים ואת כל המופעים בקלט יש ייצוג רב-חם, שקית-של-מילים צפופה. במילים אחרות, ערך 1 למילה מציין שהמילה קיימת בקלט וערך 0 מציין שלא.

  • DISTANCE_TYPE: זהו המרחק מטרי המשמש להסדרת המדגם עם שכנותיה.

  • graph_regularization_multiplier: זו קובעת את המשקל היחסי של המונח להסדרת גרף פונקצית ההפסד הכולל.

  • num_neighbors: מספר שכנים משמשים להסדרת גרף. יש ערך זה להיות קטן או שווה ל max_nbrs שורת הפקודה טיעון בשימוש מעל בעת הפעלת preprocess_cora_dataset.py .

  • num_fc_units: מספר השכבות מחוברות באופן מלא ברשת העצבית שלנו.

  • train_epochs: מספר תקופות האימונים.

  • גודל יצווה שמש לאימונים והערכה: batch_size.

  • dropout_rate: בקרת שיעור הנשירה בעקבות כל שכבה מחוברת באופן מלא

  • eval_steps: מספר אצוות כדי תהליך לפני רואה סיום בדיקה. אם מוגדר None , כל המופעים בערכת בדיקה מוערכים.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

טען נתוני רכבת ובדיקה

כפי שתואר קודם לכן במחברת זו, נתון הכשרת מבחן קלט נוצרו על ידי "preprocess_cora_dataset.py". אנו נטען אותם לשתי tf.data.Dataset אובייקטים - אחד עבור רכבת ואחד לבדיקה.

בשכבת הקלט של המודל שלנו, נוכל לחלץ ולא רק את "מילות" ואת "התווית" תכונות מכל מדגם, אלא גם שכן מקביל תכונות המבוסס על hparams.num_neighbors הערך. מקרים עם שכנים פחות מ hparams.num_neighbors יוקצו דמה ערכים עבור אפס שכן תכונות אלה.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

בואו נציץ לתוך מערך הנתונים של הרכבת כדי להסתכל על תוכנו.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2
 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2
 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0
 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)

בואו להציץ לתוך מערך הנתונים של הבדיקה כדי להסתכל על תוכנו.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

הגדרת דגם

על מנת להדגים את השימוש בהסדרת גרפים, אנו בונים תחילה מודל בסיס לבעיה זו. נשתמש ברשת עצבית פשוטה להזנה קדימה עם 2 שכבות נסתרות ונשירה ביניהן. אנחנו מדגימים את יצירתו של מודל הבסיס באמצעות כל סוגי מודל נתמך על ידי tf.Keras מסגרת - רציפים, פונקציונלי, ואת תת.

מודל בסיס רציף

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

דגם בסיס פונקציונלי

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

מודל בסיס תת-מעמדי

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

צור מודל/ים בסיסיים

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 words (InputLayer)          [(None, 1433)]            0         
                                                                 
 lambda (Lambda)             (None, 1433)              0         
                                                                 
 dense (Dense)               (None, 50)                71700     
                                                                 
 dropout (Dropout)           (None, 50)                0         
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 7)                 357       
                                                                 
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

דגם MLP בסיס רכבת

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847
<keras.callbacks.History at 0x7f6f9d5eacd0>

הערכת מודל MLP בסיסי

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939


Eval accuracy for  Base MLP model :  0.7938517332077026
Eval loss for  Base MLP model :  1.4192423820495605

הרכבת מודל MLP עם הסדרת גרפים

שילוב להסדרת גרף לתוך מונח האובדן קיים tf.Keras.Model דורש רק כמה שורות של קוד. מודל הבסיס עטוף ליצור חדש tf.Keras מודל תת, אשר אובדנו כולל להסדרת גרף.

כדי להעריך את היתרון המצטבר של הסדרת גרפים, ניצור מופע מודל בסיס חדש. הסיבה לכך היא base_model כבר מאומנת במשך כמה חזרות, ולשימוש חוזר מודל מודרך זה כדי ליצור מודל גרף-סדיר לא תהיה השוואה הוגנת עבור base_model .

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043
<keras.callbacks.History at 0x7f70041be650>

הערכת מודל MLP עם הסדרת גרפים

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957


Eval accuracy for  MLP + graph regularization :  0.7956600189208984
Eval loss for  MLP + graph regularization :  0.8883611559867859

הדיוק של המודל סדיר-הגרף הוא כ 2-3% גבוהים יותר מזה של הדגם הבסיסי ( base_model ).

סיכום

הדגמנו את השימוש בהסדרת גרפים לסיווג מסמכים על גרף ציטוט טבעי (Cora) באמצעות המסגרת של למידה מובנית עצבית (NSL). שלנו הדרכה מתקדמת כרוכה סינתזה גרפים המבוססים על שיבוצים מדגם לפני אימון רשת עצבית עם הסדרת הגרף. גישה זו שימושית אם הקלט אינו מכיל גרף מפורש.

אנו ממליצים למשתמשים להתנסות נוספת על ידי שינוי כמות הפיקוח וכן ניסיון של ארכיטקטורות עצביות שונות להסדרת גרפים.