Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural-language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this Colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

First, let's analyze the embeddings by computing and plotting a correlation matrix between different terms. If the embeddings have learned to capture the meaning of different words successfully, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19-related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

[Figure: heatmap of pairwise similarities between the query term embeddings]

We can see that the embeddings have successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e. 'coronavirus' correlates highly with 'SARS' and 'MERS'), while it differs from terms in other clusters (i.e. the similarity between 'SARS' and 'Spain' is close to 0).
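To make the cluster structure concrete, here is a small spot check (not in the original notebook) that prints similarities for a few of the pairs discussed above; it uses cosine similarity, which differs slightly from the max-normalized inner product in plot_correlation.

# Spot-check a few pairs with cosine similarity (a sketch, not original code).
def cosine_similarity(a, b):
  return np.inner(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

vecs = dict(zip(queries, np.asarray(embeddings)))
for w1, w2 in [('coronavirus', 'SARS'), ('SARS', 'MERS'), ('SARS', 'Spain')]:
  print('similarity(%s, %s) = %.3f' % (w1, w2, cosine_similarity(vecs[w1], vecs[w2])))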

Now let's see how to use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how to use the embeddings for a downstream task such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, classify whether the main intent of the citation is background information, use of methods, or comparison of results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set.
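The inspection cell was hidden in the original notebook. Here is a minimal reconstruction that renders a few supervised (text, label) pairs as a table, reading the feature names and label vocabulary from the TFDS builder metadata; NUM_EXAMPLES is an assumption, and the helpers TEXT_FEATURE_NAME, LABEL_NAME, and label2str defined here are reused further below.

NUM_EXAMPLES = 10  # assumed sample size for this sketch

# Feature names and label vocabulary come from the TFDS builder metadata.
TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

# Take a small batch and show the sentences alongside their label names.
data = next(iter(train_data.batch(NUM_EXAMPLES)))
pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})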

Training a citation intent classifier

We'll use Keras to train a classifier on the SciCite dataset. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters
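The hyperparameter and model-definition cells were hidden in the original notebook. The following sketch is consistent with the summary printed below: a frozen hub.KerasLayer that maps raw sentences to 128-dimensional CORD-19 embeddings, topped with a 3-way softmax classification layer; the optimizer and loss shown here are assumptions.

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'
TRAINABLE_MODULE = False  # keep the 17,301,632 embedding parameters frozen

# The hub layer takes raw string sentences and outputs 128-d embeddings.
hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[],
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential([
    hub_layer,
    tf.keras.layers.Dense(3, activation='softmax'),  # 128*3 + 3 = 387 params
])
model.summary()
model.compile(optimizer='adam',  # assumed optimizer/loss choices
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])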

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 128)               17301632  
                                                                 
 dense (Dense)               (None, 3)                 387       
                                                                 
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 2s 5ms/step - loss: 0.9287 - accuracy: 0.5848 - val_loss: 0.7893 - val_accuracy: 0.6769
Epoch 2/35
257/257 [==============================] - 2s 5ms/step - loss: 0.7066 - accuracy: 0.7103 - val_loss: 0.6824 - val_accuracy: 0.7260
Epoch 3/35
257/257 [==============================] - 2s 5ms/step - loss: 0.6291 - accuracy: 0.7514 - val_loss: 0.6319 - val_accuracy: 0.7478
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5917 - accuracy: 0.7670 - val_loss: 0.6090 - val_accuracy: 0.7544
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5711 - accuracy: 0.7783 - val_loss: 0.5943 - val_accuracy: 0.7587
Epoch 6/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5574 - accuracy: 0.7823 - val_loss: 0.5815 - val_accuracy: 0.7653
Epoch 7/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5476 - accuracy: 0.7846 - val_loss: 0.5744 - val_accuracy: 0.7729
Epoch 8/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5404 - accuracy: 0.7870 - val_loss: 0.5696 - val_accuracy: 0.7707
Epoch 9/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5358 - accuracy: 0.7896 - val_loss: 0.5677 - val_accuracy: 0.7697
Epoch 10/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5312 - accuracy: 0.7917 - val_loss: 0.5635 - val_accuracy: 0.7718
Epoch 11/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5268 - accuracy: 0.7919 - val_loss: 0.5617 - val_accuracy: 0.7718
Epoch 12/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5241 - accuracy: 0.7928 - val_loss: 0.5611 - val_accuracy: 0.7697
Epoch 13/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5217 - accuracy: 0.7939 - val_loss: 0.5575 - val_accuracy: 0.7718
Epoch 14/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5191 - accuracy: 0.7953 - val_loss: 0.5553 - val_accuracy: 0.7762
Epoch 15/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5169 - accuracy: 0.7957 - val_loss: 0.5532 - val_accuracy: 0.7806
Epoch 16/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5153 - accuracy: 0.7952 - val_loss: 0.5545 - val_accuracy: 0.7697
Epoch 17/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5138 - accuracy: 0.7962 - val_loss: 0.5518 - val_accuracy: 0.7828
Epoch 18/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5125 - accuracy: 0.7955 - val_loss: 0.5525 - val_accuracy: 0.7784
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5110 - accuracy: 0.7975 - val_loss: 0.5516 - val_accuracy: 0.7762
Epoch 20/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5097 - accuracy: 0.7969 - val_loss: 0.5514 - val_accuracy: 0.7740
Epoch 21/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5090 - accuracy: 0.7969 - val_loss: 0.5491 - val_accuracy: 0.7762
Epoch 22/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5076 - accuracy: 0.7991 - val_loss: 0.5513 - val_accuracy: 0.7773
Epoch 23/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5067 - accuracy: 0.7985 - val_loss: 0.5488 - val_accuracy: 0.7806
Epoch 24/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5067 - accuracy: 0.7981 - val_loss: 0.5475 - val_accuracy: 0.7806
Epoch 25/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5048 - accuracy: 0.7995 - val_loss: 0.5472 - val_accuracy: 0.7795
Epoch 26/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5040 - accuracy: 0.8010 - val_loss: 0.5480 - val_accuracy: 0.7817
Epoch 27/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5036 - accuracy: 0.8006 - val_loss: 0.5454 - val_accuracy: 0.7838
Epoch 28/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5033 - accuracy: 0.8018 - val_loss: 0.5474 - val_accuracy: 0.7849
Epoch 29/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5025 - accuracy: 0.7996 - val_loss: 0.5459 - val_accuracy: 0.7838
Epoch 30/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5018 - accuracy: 0.8013 - val_loss: 0.5468 - val_accuracy: 0.7871
Epoch 31/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5016 - accuracy: 0.7997 - val_loss: 0.5475 - val_accuracy: 0.7860
Epoch 32/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5009 - accuracy: 0.8005 - val_loss: 0.5461 - val_accuracy: 0.7893
Epoch 33/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5006 - accuracy: 0.8017 - val_loss: 0.5459 - val_accuracy: 0.7882
Epoch 34/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5000 - accuracy: 0.8021 - val_loss: 0.5459 - val_accuracy: 0.7882
Epoch 35/35
257/257 [==============================] - 1s 5ms/step - loss: 0.4998 - accuracy: 0.8031 - val_loss: 0.5460 - val_accuracy: 0.7893

from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

[Figure: training curves showing model accuracy (top) and loss (bottom) for the train and validation sets]

Evaluate the model

Let's see how the model performs. It will return two values: the loss (a number representing the error, where lower is better) and the accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5401 - accuracy: 0.7843 - 339ms/epoch - 85ms/step
loss: 0.540
accuracy: 0.784

We can see that the loss quickly decreases while the accuracy rapidly increases. Let's plot some examples to check how the predictions relate to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
1/1 [==============================] - 0s 125ms/step

We can see that for this random sample the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.

What's next

Now that you've learned a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19-related academic texts.