کاوش مفصل گردنده TF-Hub CORD-19

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده در GitHub دانلود دفترچه یادداشت به مدل TF Hub مراجعه کنید

ماژول جاسازی متن CORD-19 Swivel از TF-Hub (https: //tfhub.dev/tensorflow/cord-19/swivel-128d/1) برای حمایت از پژوهشگران در نوشتن متن در ارتباط با متن کووید -19. این موارد تعبیه شده در مورد عناوین ، نویسندگان ، چکیده مقاله ، متن اصلی و عناوین مرجع مقالات در مجموعه داده های CORD-19 آموزش داده شد .

در این کولاب ما:

  • کلمات مشابه را در فضای تعبیه تجزیه و تحلیل کنید
  • با استفاده از تعبیه های CORD-19 ، یک طبقه بندی را روی مجموعه داده های SciCite آموزش دهید

برپایی

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity('ERROR')

import tensorflow_datasets as tfds
import tensorflow_hub as hub

try:
  from google.colab import data_table
  def display_df(df):
    return data_table.DataTable(df, include_index=False)
except ModuleNotFoundError:
  # If google-colab is not available, just display the raw DataFrame
  def display_df(df):
    return df

تجزیه و تحلیل تعبیه شده ها

بیایید شروع کنیم با تجزیه و تحلیل تعبیه شده با محاسبه و رسم ماتریس همبستگی بین اصطلاحات مختلف. اگر تعبیه شده یادگیری موفقیت آمیز معنای کلمات مختلف را فراگرفته باشد ، بردارهای تعبیه شده کلمات معنایی مشابه باید نزدیک یکدیگر باشند. بیایید نگاهی به برخی اصطلاحات مربوط به COVID-19 بیندازیم.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)


with tf.Graph().as_default():
  # Load the module
  query_input = tf.placeholder(tf.string)
  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')
  embeddings = module(query_input)

  with tf.train.MonitoredTrainingSession() as sess:

    # Generate embeddings for some terms
    queries = [
        # Related viruses
        "coronavirus", "SARS", "MERS",
        # Regions
        "Italy", "Spain", "Europe",
        # Symptoms
        "cough", "fever", "throat"
    ]

    features = sess.run(embeddings, feed_dict={query_input: queries})
    plot_correlation(queries, features)

png

می توانیم ببینیم که تعبیه شده با موفقیت معنای اصطلاحات مختلف را جلب کرد. هر کلمه شبیه کلمات دیگر خوشه آن است (یعنی "ویروس کرونا" بسیار با "SARS" و "MERS" ارتباط دارد) ، در حالی که آنها از نظر خوشه های دیگر متفاوت هستند (یعنی شباهت "SARS" با "اسپانیا" نزدیک به 0).

حال بیایید ببینیم که چگونه می توانیم از این تعبیه ها برای حل یک کار خاص استفاده کنیم.

SciCite: طبقه بندی هدف استناد

این بخش نشان می دهد که چگونه می توان از تعبیه شده برای کارهای پایین دستی مانند طبقه بندی متن استفاده کرد. ما برای طبقه بندی اهداف استناد در مقالات دانشگاهی از مجموعه داده های SciCite از مجموعه داده های TensorFlow استفاده خواهیم کرد . با توجه به جمله ای با استناد به مقاله علمی ، طبقه بندی کنید که آیا هدف اصلی استناد ، اطلاعات زمینه ای ، استفاده از روش ها یا مقایسه نتایج است.

مجموعه داده را از TFDS تنظیم کنید

Downloading and preparing dataset scicite/1.0.0 (download: 22.12 MiB, generated: Unknown size, total: 22.12 MiB) to /home/kbuilder/tensorflow_datasets/scicite/1.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-validation.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-test.tfrecord
Dataset scicite downloaded and prepared to /home/kbuilder/tensorflow_datasets/scicite/1.0.0. Subsequent calls will reuse this data.

بیایید نگاهی به چند نمونه از برچسب های مجموعه آموزشی بیندازیم

آموزش طبقه بندی قصد citaton

ما یک طبقه بندی کننده را با استفاده از Estimator در مجموعه داده های SciCite آموزش خواهیم داد . بیایید input_fns را تنظیم کنیم تا مجموعه داده را در مدل بخوانیم

def preprocessed_input_fn(for_eval):
  data = THE_DATASET.get_data(for_eval=for_eval)
  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)
  return data


def input_fn_train(params):
  data = preprocessed_input_fn(for_eval=False)
  data = data.repeat(None)
  data = data.shuffle(1024)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_eval(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.repeat(1)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_predict(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.batch(batch_size=params['batch_size'])
  return data

بیایید مدلی بسازیم که از جاسازی های CORD-19 با یک لایه طبقه بندی در بالا استفاده می کند.

def model_fn(features, labels, mode, params):
  # Embed the text
  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])
  embeddings = embed(features['feature'])

  # Add a linear layer on top
  logits = tf.layers.dense(
      embeddings, units=THE_DATASET.num_classes(), activation=None)
  predictions = tf.argmax(input=logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'predictions': predictions,
            'features': features['feature'],
            'labels': features['label']
        })

  # Set up a multi-class classification head
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  elif mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    precision = tf.metrics.precision(labels=labels, predictions=predictions)
    recall = tf.metrics.recall(labels=labels, predictions=predictions)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
        })

ابرپارامترها

مدل را آموزش دهید و ارزیابی کنید

بیایید مدل را آموزش دهیم و ارزیابی کنیم تا عملکرد روی وظیفه SciCite را ببینیم

estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))
metrics = []

for step in range(0, STEPS, EVAL_EVERY):
  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)
  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))
  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))
  metrics.append(step_metrics)
Global step 0: loss 0.796, accuracy 0.670
Global step 200: loss 0.701, accuracy 0.732
Global step 400: loss 0.682, accuracy 0.719
Global step 600: loss 0.650, accuracy 0.747
Global step 800: loss 0.620, accuracy 0.762
Global step 1000: loss 0.609, accuracy 0.762
Global step 1200: loss 0.605, accuracy 0.762
Global step 1400: loss 0.585, accuracy 0.783
Global step 1600: loss 0.586, accuracy 0.768
Global step 1800: loss 0.577, accuracy 0.774
Global step 2000: loss 0.584, accuracy 0.765
Global step 2200: loss 0.565, accuracy 0.778
Global step 2400: loss 0.570, accuracy 0.776
Global step 2600: loss 0.556, accuracy 0.789
Global step 2800: loss 0.563, accuracy 0.778
Global step 3000: loss 0.557, accuracy 0.784
Global step 3200: loss 0.566, accuracy 0.774
Global step 3400: loss 0.552, accuracy 0.782
Global step 3600: loss 0.551, accuracy 0.785
Global step 3800: loss 0.547, accuracy 0.788
Global step 4000: loss 0.549, accuracy 0.784
Global step 4200: loss 0.548, accuracy 0.785
Global step 4400: loss 0.553, accuracy 0.783
Global step 4600: loss 0.543, accuracy 0.786
Global step 4800: loss 0.548, accuracy 0.783
Global step 5000: loss 0.547, accuracy 0.785
Global step 5200: loss 0.539, accuracy 0.791
Global step 5400: loss 0.546, accuracy 0.782
Global step 5600: loss 0.548, accuracy 0.781
Global step 5800: loss 0.540, accuracy 0.791
Global step 6000: loss 0.542, accuracy 0.790
Global step 6200: loss 0.539, accuracy 0.792
Global step 6400: loss 0.545, accuracy 0.788
Global step 6600: loss 0.552, accuracy 0.781
Global step 6800: loss 0.549, accuracy 0.783
Global step 7000: loss 0.540, accuracy 0.788
Global step 7200: loss 0.543, accuracy 0.782
Global step 7400: loss 0.541, accuracy 0.787
Global step 7600: loss 0.532, accuracy 0.790
Global step 7800: loss 0.537, accuracy 0.792
global_steps = [x['global_step'] for x in metrics]
fig, axes = plt.subplots(ncols=2, figsize=(20,8))

for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],
                                            ['loss']]):
  for metric_name in metric_names:
    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)
  axes[axes_index].legend()
  axes[axes_index].set_xlabel("Global Step")

png

می توانیم ببینیم که ضرر به سرعت کاهش می یابد در حالی که به ویژه دقت به سرعت افزایش می یابد. بیایید چند نمونه برای بررسی نحوه ارتباط پیش بینی با برچسب های واقعی ترسیم کنیم:

predictions = estimator.predict(functools.partial(input_fn_predict, params))
first_10_predictions = list(itertools.islice(predictions, 10))

display_df(
  pd.DataFrame({
      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],
      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],
      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]
  }))

می توانیم ببینیم که برای این نمونه تصادفی ، مدل در اکثر مواقع برچسب صحیح را پیش بینی می کند ، نشان می دهد که می تواند جملات علمی را به خوبی جاسازی کند.

دیگه چیه؟

اکنون که کمی بیشتر با جاسازی های CORD-19 Swivel از TF-Hub آشنا شدید ، ما شما را به شرکت در مسابقه CORD-19 Kaggle جهت کمک به کسب بینش علمی از متون دانشگاهی مرتبط با COVID-19 تشویق می کنیم.