Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module on TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural-language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this Colab we will:

  • Analyze semantically similar words in the embedding space
  • Use the CORD-19 embeddings to train a classifier on the SciCite dataset

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start by analyzing the embeddings: we compute and plot a correlation matrix between different terms. If the embeddings have successfully learned to capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's look at some COVID-19-related terms.
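As a tiny self-contained illustration of this similarity measure (inner products scaled by the largest entry), consider three hand-made 3-dimensional vectors; the vectors here are invented for illustration only:

```python
import numpy as np

# Toy "embeddings" (invented for illustration only).
features = np.array([
    [1.0, 0.0, 0.0],   # term A
    [0.9, 0.1, 0.0],   # term B, nearly parallel to A
    [0.0, 0.0, 1.0],   # term C, orthogonal to A and B
])

# Same measure as plot_correlation below: pairwise inner products,
# divided by the maximum entry so the diagonal peaks at 1.
corr = np.inner(features, features)
corr /= np.max(corr)

print(np.round(corr, 2))
# Similar terms (A, B) get a high score; unrelated terms (A, C) get ~0.
```

The real module produces 128-dimensional vectors, but the arithmetic is identical.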

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

(figure: heatmap of the similarity matrix between the query terms)

We can see that the embeddings successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from terms in other clusters (i.e. the similarity between "SARS" and "Spain" is close to 0).

Now let's see how we can use these embeddings to solve a specific task.

SciCite: citation intent classification

This section shows how to use the embeddings for a downstream task such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, classify the citation's primary intent: background information, use of a method, or comparison of results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set
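The inspection cell itself is hidden; a minimal self-contained sketch of the same iteration pattern, using an in-memory `tf.data.Dataset` (the two sentences and labels below are invented stand-ins — the real notebook would loop over `train_data.take(...)` instead):

```python
import tensorflow as tf

# Invented stand-in examples; with as_supervised=True, SciCite yields
# (text, label) pairs in exactly this shape.
texts = tf.constant([b'This method was used in previous work [3].',
                     b'Our results are consistent with [12].'])
labels = tf.constant([0, 2])

dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

# Same loop works on train_data from the cell above.
for text, label in dataset.take(2):
  print('Label:', int(label.numpy()))
  print('Text :', text.numpy().decode('utf8'))
```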

Training a citation intent classifier

We'll use Keras to train a classifier on the SciCite dataset. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 128)               17301632  
                                                                 
 dense (Dense)               (None, 3)                 387       
                                                                 
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 3s 5ms/step - loss: 1.0181 - accuracy: 0.5166 - val_loss: 0.8177 - val_accuracy: 0.6550
Epoch 2/35
257/257 [==============================] - 1s 4ms/step - loss: 0.7353 - accuracy: 0.7049 - val_loss: 0.7010 - val_accuracy: 0.7096
Epoch 3/35
257/257 [==============================] - 1s 4ms/step - loss: 0.6503 - accuracy: 0.7447 - val_loss: 0.6482 - val_accuracy: 0.7336
Epoch 4/35
257/257 [==============================] - 2s 4ms/step - loss: 0.6073 - accuracy: 0.7632 - val_loss: 0.6179 - val_accuracy: 0.7434
Epoch 5/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5819 - accuracy: 0.7742 - val_loss: 0.5979 - val_accuracy: 0.7544
Epoch 6/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5655 - accuracy: 0.7796 - val_loss: 0.5873 - val_accuracy: 0.7620
Epoch 7/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5546 - accuracy: 0.7818 - val_loss: 0.5784 - val_accuracy: 0.7653
Epoch 8/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5461 - accuracy: 0.7873 - val_loss: 0.5722 - val_accuracy: 0.7653
Epoch 9/35
257/257 [==============================] - 2s 4ms/step - loss: 0.5393 - accuracy: 0.7881 - val_loss: 0.5714 - val_accuracy: 0.7642
Epoch 10/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5346 - accuracy: 0.7907 - val_loss: 0.5644 - val_accuracy: 0.7729
Epoch 11/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5306 - accuracy: 0.7897 - val_loss: 0.5616 - val_accuracy: 0.7751
Epoch 12/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5263 - accuracy: 0.7918 - val_loss: 0.5580 - val_accuracy: 0.7784
Epoch 13/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5238 - accuracy: 0.7935 - val_loss: 0.5563 - val_accuracy: 0.7806
Epoch 14/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5210 - accuracy: 0.7934 - val_loss: 0.5561 - val_accuracy: 0.7784
Epoch 15/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5182 - accuracy: 0.7955 - val_loss: 0.5576 - val_accuracy: 0.7762
Epoch 16/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5169 - accuracy: 0.7930 - val_loss: 0.5511 - val_accuracy: 0.7806
Epoch 17/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5145 - accuracy: 0.7956 - val_loss: 0.5515 - val_accuracy: 0.7784
Epoch 18/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5130 - accuracy: 0.7956 - val_loss: 0.5511 - val_accuracy: 0.7740
Epoch 19/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5115 - accuracy: 0.7969 - val_loss: 0.5495 - val_accuracy: 0.7795
Epoch 20/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5112 - accuracy: 0.7955 - val_loss: 0.5504 - val_accuracy: 0.7860
Epoch 21/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5090 - accuracy: 0.7980 - val_loss: 0.5485 - val_accuracy: 0.7773
Epoch 22/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5083 - accuracy: 0.7989 - val_loss: 0.5496 - val_accuracy: 0.7817
Epoch 23/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5066 - accuracy: 0.7969 - val_loss: 0.5476 - val_accuracy: 0.7860
Epoch 24/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5067 - accuracy: 0.7977 - val_loss: 0.5461 - val_accuracy: 0.7806
Epoch 25/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5055 - accuracy: 0.7983 - val_loss: 0.5472 - val_accuracy: 0.7882
Epoch 26/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5048 - accuracy: 0.7970 - val_loss: 0.5492 - val_accuracy: 0.7817
Epoch 27/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5043 - accuracy: 0.8002 - val_loss: 0.5504 - val_accuracy: 0.7806
Epoch 28/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5034 - accuracy: 0.7988 - val_loss: 0.5464 - val_accuracy: 0.7871
Epoch 29/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5029 - accuracy: 0.7995 - val_loss: 0.5462 - val_accuracy: 0.7871
Epoch 30/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5021 - accuracy: 0.8019 - val_loss: 0.5489 - val_accuracy: 0.7860
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5018 - accuracy: 0.8013 - val_loss: 0.5471 - val_accuracy: 0.7860
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5010 - accuracy: 0.8019 - val_loss: 0.5455 - val_accuracy: 0.7915
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5007 - accuracy: 0.8021 - val_loss: 0.5469 - val_accuracy: 0.7904
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5001 - accuracy: 0.8013 - val_loss: 0.5479 - val_accuracy: 0.7871
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5004 - accuracy: 0.8028 - val_loss: 0.5455 - val_accuracy: 0.7882
from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

(figure: training curves for model accuracy and loss)

Evaluate the model

Let's see how the model performs. It will return two values: the loss (a number representing the error, lower is better) and accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5379 - accuracy: 0.7848 - 291ms/epoch - 73ms/step
loss: 0.538
accuracy: 0.785

We can see that the loss quickly decreases while the accuracy rapidly increases. Let's plot some examples to check how the predictions relate to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
1/1 [==============================] - 0s 150ms/step
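The cell above relies on `label2str`, `TEXT_FEATURE_NAME`, and `LABEL_NAME`, which are defined in a hidden cell. A plausible reconstruction is below; the feature names and the class-name order are assumptions — in practice the names should be read from `builder.info.features[LABEL_NAME].names`:

```python
# Assumed TFDS feature names for SciCite (hidden-cell definitions;
# treat these as assumptions).
TEXT_FEATURE_NAME = 'string'
LABEL_NAME = 'label'

# Class-name order is an assumption; prefer reading it from
# builder.info.features[LABEL_NAME].names.
CLASS_NAMES = ['method', 'background', 'result']

def label2str(numeric_label):
  """Map an integer label from the dataset to a readable class name."""
  return CLASS_NAMES[int(numeric_label)]
```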

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.

What's next

Now that you have learned more about the CORD-19 Swivel embeddings on TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition and contribute to gaining scientific insights from COVID-19-related academic texts.