The CORD-19 Swivel text embedding module on TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural-language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body text, and reference titles of articles in the CORD-19 dataset.
In this Colab we will:
- Analyze semantically similar words in the embedding space
- Train a classifier on the SciCite dataset using the CORD-19 embeddings
Setup
```python
import functools
import itertools

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange
```
Analyze the embeddings
Let's start by analyzing the embeddings: we compute and plot a correlation matrix between different terms. If the embeddings have successfully learned to capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's look at some terms related to COVID-19.
```python
# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)
```
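The heatmap uses a (max-normalized) inner product as the similarity measure. A closely related check, sketched below under the assumption that the `queries` and `embeddings` variables from the cell above are still in scope, is to rank which of the query terms sit closest to a given word by cosine similarity. Note that `most_similar` is a hypothetical helper written for this sketch, not part of the module's API:

```python
def most_similar(query, labels, features, top_k=3):
  # Normalize the rows so that inner products become cosine similarities.
  normed = features / np.linalg.norm(features, axis=1, keepdims=True)
  idx = list(labels).index(query)
  scores = np.inner(normed[idx], normed)
  # Sort by descending similarity, skipping the query term itself.
  order = [i for i in np.argsort(-scores) if i != idx]
  return [(labels[i], float(scores[i])) for i in order[:top_k]]

print(most_similar('SARS', queries, embeddings.numpy()))
```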
We can see that the embeddings successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e., "coronavirus" correlates highly with "SARS" and "MERS"), while differing from terms of other clusters (i.e., the similarity between "SARS" and "Spain" is close to 0).
Now let's see how we can use these embeddings to solve a specific task.
SciCite: Citation Intent Classification
This section shows how to use the embeddings for a downstream task, such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers: given a sentence containing a citation from an academic paper, classify whether the main intent of the citation is providing background information, describing the use of a method, or comparing results.
```python
builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)
```
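Before looking at examples, it can help to check how large each split is. This is a small optional check, not a cell from the original Colab:

```python
# Print the number of examples in each split of the dataset.
for split_name, split_info in builder.info.splits.items():
  print(f'{split_name}: {split_info.num_examples} examples')
```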
Let's take a look at a few labeled examples from the training set:
```python
NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))

pd.DataFrame({
  TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
  LABEL_NAME: [label2str(x) for x in data[1]]
})
```
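For reference, the full label vocabulary can be read straight off `builder.info`. This one-liner is an optional addition, not a cell from the original Colab:

```python
# Print the class names that the numeric labels map to (the same
# `names` list that label2str indexes into).
print(builder.info.features[LABEL_NAME].names)
```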
Train a citation intent classifier
We'll use Keras to train a classifier on the SciCite dataset. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.
```python
# Hyperparameters
EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'
TRAINABLE_MODULE = False

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[],
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
```
```
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 keras_layer (KerasLayer)    (None, 128)               17301632
 dense (Dense)               (None, 3)                 387
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
```
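The parameter counts are easy to verify: the dense layer contributes 128 weights per class × 3 classes + 3 biases = 387 trainable parameters, while the 17,301,632 non-trainable parameters are the embedding table of the hub module, frozen because `TRAINABLE_MODULE = False`. And since the hub layer takes raw strings as input, a quick shape check (an optional sanity check, not a cell from the original notebook) confirms that it maps each string to a single 128-dimensional vector:

```python
# Feed one raw string through the hub layer; the expected shape is (1, 128).
sample_vec = hub_layer(tf.constant(['a sentence about coronavirus']))
print(sample_vec.shape)
```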
Train and evaluate the model
Let's train and evaluate the model to see how it performs on the SciCite task.
```python
EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
```
```
Epoch 1/35:  257/257 - 2s 5ms/step - loss: 0.9287 - accuracy: 0.5848 - val_loss: 0.7893 - val_accuracy: 0.6769
Epoch 2/35:  257/257 - 2s 5ms/step - loss: 0.7066 - accuracy: 0.7103 - val_loss: 0.6824 - val_accuracy: 0.7260
Epoch 3/35:  257/257 - 2s 5ms/step - loss: 0.6291 - accuracy: 0.7514 - val_loss: 0.6319 - val_accuracy: 0.7478
Epoch 4/35:  257/257 - 1s 5ms/step - loss: 0.5917 - accuracy: 0.7670 - val_loss: 0.6090 - val_accuracy: 0.7544
Epoch 5/35:  257/257 - 1s 5ms/step - loss: 0.5711 - accuracy: 0.7783 - val_loss: 0.5943 - val_accuracy: 0.7587
Epoch 6/35:  257/257 - 1s 5ms/step - loss: 0.5574 - accuracy: 0.7823 - val_loss: 0.5815 - val_accuracy: 0.7653
Epoch 7/35:  257/257 - 1s 5ms/step - loss: 0.5476 - accuracy: 0.7846 - val_loss: 0.5744 - val_accuracy: 0.7729
Epoch 8/35:  257/257 - 1s 5ms/step - loss: 0.5404 - accuracy: 0.7870 - val_loss: 0.5696 - val_accuracy: 0.7707
Epoch 9/35:  257/257 - 1s 5ms/step - loss: 0.5358 - accuracy: 0.7896 - val_loss: 0.5677 - val_accuracy: 0.7697
Epoch 10/35: 257/257 - 1s 5ms/step - loss: 0.5312 - accuracy: 0.7917 - val_loss: 0.5635 - val_accuracy: 0.7718
Epoch 11/35: 257/257 - 1s 5ms/step - loss: 0.5268 - accuracy: 0.7919 - val_loss: 0.5617 - val_accuracy: 0.7718
Epoch 12/35: 257/257 - 2s 5ms/step - loss: 0.5241 - accuracy: 0.7928 - val_loss: 0.5611 - val_accuracy: 0.7697
Epoch 13/35: 257/257 - 2s 5ms/step - loss: 0.5217 - accuracy: 0.7939 - val_loss: 0.5575 - val_accuracy: 0.7718
Epoch 14/35: 257/257 - 1s 5ms/step - loss: 0.5191 - accuracy: 0.7953 - val_loss: 0.5553 - val_accuracy: 0.7762
Epoch 15/35: 257/257 - 1s 5ms/step - loss: 0.5169 - accuracy: 0.7957 - val_loss: 0.5532 - val_accuracy: 0.7806
Epoch 16/35: 257/257 - 1s 5ms/step - loss: 0.5153 - accuracy: 0.7952 - val_loss: 0.5545 - val_accuracy: 0.7697
Epoch 17/35: 257/257 - 2s 5ms/step - loss: 0.5138 - accuracy: 0.7962 - val_loss: 0.5518 - val_accuracy: 0.7828
Epoch 18/35: 257/257 - 2s 5ms/step - loss: 0.5125 - accuracy: 0.7955 - val_loss: 0.5525 - val_accuracy: 0.7784
Epoch 19/35: 257/257 - 1s 5ms/step - loss: 0.5110 - accuracy: 0.7975 - val_loss: 0.5516 - val_accuracy: 0.7762
Epoch 20/35: 257/257 - 2s 5ms/step - loss: 0.5097 - accuracy: 0.7969 - val_loss: 0.5514 - val_accuracy: 0.7740
Epoch 21/35: 257/257 - 1s 5ms/step - loss: 0.5090 - accuracy: 0.7969 - val_loss: 0.5491 - val_accuracy: 0.7762
Epoch 22/35: 257/257 - 2s 5ms/step - loss: 0.5076 - accuracy: 0.7991 - val_loss: 0.5513 - val_accuracy: 0.7773
Epoch 23/35: 257/257 - 1s 5ms/step - loss: 0.5067 - accuracy: 0.7985 - val_loss: 0.5488 - val_accuracy: 0.7806
Epoch 24/35: 257/257 - 1s 5ms/step - loss: 0.5067 - accuracy: 0.7981 - val_loss: 0.5475 - val_accuracy: 0.7806
Epoch 25/35: 257/257 - 1s 5ms/step - loss: 0.5048 - accuracy: 0.7995 - val_loss: 0.5472 - val_accuracy: 0.7795
Epoch 26/35: 257/257 - 1s 5ms/step - loss: 0.5040 - accuracy: 0.8010 - val_loss: 0.5480 - val_accuracy: 0.7817
Epoch 27/35: 257/257 - 2s 5ms/step - loss: 0.5036 - accuracy: 0.8006 - val_loss: 0.5454 - val_accuracy: 0.7838
Epoch 28/35: 257/257 - 2s 5ms/step - loss: 0.5033 - accuracy: 0.8018 - val_loss: 0.5474 - val_accuracy: 0.7849
Epoch 29/35: 257/257 - 1s 5ms/step - loss: 0.5025 - accuracy: 0.7996 - val_loss: 0.5459 - val_accuracy: 0.7838
Epoch 30/35: 257/257 - 1s 4ms/step - loss: 0.5018 - accuracy: 0.8013 - val_loss: 0.5468 - val_accuracy: 0.7871
Epoch 31/35: 257/257 - 2s 5ms/step - loss: 0.5016 - accuracy: 0.7997 - val_loss: 0.5475 - val_accuracy: 0.7860
Epoch 32/35: 257/257 - 2s 5ms/step - loss: 0.5009 - accuracy: 0.8005 - val_loss: 0.5461 - val_accuracy: 0.7893
Epoch 33/35: 257/257 - 1s 5ms/step - loss: 0.5006 - accuracy: 0.8017 - val_loss: 0.5459 - val_accuracy: 0.7882
Epoch 34/35: 257/257 - 1s 5ms/step - loss: 0.5000 - accuracy: 0.8021 - val_loss: 0.5459 - val_accuracy: 0.7882
Epoch 35/35: 257/257 - 1s 5ms/step - loss: 0.4998 - accuracy: 0.8031 - val_loss: 0.5460 - val_accuracy: 0.7893
```
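The validation accuracy plateaus around 0.78 after roughly 20 epochs. As an optional tweak, not part of the original notebook, an early-stopping callback could halt training once the validation loss stops improving:

```python
# Hypothetical variant (not in the original Colab): pass this callback via
# model.fit(..., callbacks=[early_stop]) to stop training once val_loss
# stops improving, restoring the best weights seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)
```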
```python
from matplotlib import pyplot as plt

def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)
```
Evaluate the model
Let's see how the model performs. Two values will be returned: the loss (a number representing the error; lower values are better) and the accuracy.
```python
results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
```
```
4/4 - 0s - loss: 0.5401 - accuracy: 0.7843 - 339ms/epoch - 85ms/step
loss: 0.540
accuracy: 0.784
```
We can see that the loss quickly decreases while the accuracy quickly increases. Let's plot some examples to check how the predictions relate to the true labels:
```python
prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]

pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
```
```
1/1 [==============================] - 0s 125ms/step
```
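Beyond spot-checking 20 examples, one could aggregate predictions over the whole test split. The sketch below is a hedged extension, not part of the original notebook; it assumes `model` and `test_data` are as defined above:

```python
# Build a confusion matrix over the full test set.
y_true, y_pred = [], []
for texts, labels in test_data.batch(512):
  logits = model.predict(texts, verbose=0)
  y_pred.extend(np.argmax(logits, axis=-1))
  y_true.extend(labels.numpy())
print(tf.math.confusion_matrix(y_true, y_pred).numpy())
```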
We can see that for this random sample, the model predicts the correct label most of the time, which indicates that it can embed scientific sentences quite well.
What's next
Now that you've learned more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19-related academic texts.
- Participate in the CORD-19 Kaggle Challenge
- Learn more about the COVID-19 Open Research Dataset (CORD-19)
- See the documentation and learn more about the TF-Hub embeddings at https://tfhub.dev/tensorflow/cord-19/swivel-128d/3
- Explore the CORD-19 embedding space with the TensorFlow Embedding Projector