The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.
In this colab we will:
- Analyze semantically similar words in the embedding space
- Train a classifier on the SciCite dataset using the CORD-19 embeddings
Setup
import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tqdm import trange
Analyze the embeddings
Let's start by analyzing the embedding: we calculate and plot a correlation matrix between different terms. If the embedding successfully learned to capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.
# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)
# Generate embeddings for some terms
queries = [
    # Related viruses
    'coronavirus', 'SARS', 'MERS',
    # Regions
    'Italy', 'Spain', 'Europe',
    # Symptoms
    'cough', 'fever', 'throat'
]
module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)
plot_correlation(queries, embeddings)
We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (e.g. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from terms in other clusters (e.g. the similarity between "SARS" and "Spain" is close to 0).
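As a quick numerical sanity check, we can also print similarity scores for individual term pairs. This is a minimal sketch on top of the module loaded above; the helper print_similarity is just for illustration and uses cosine similarity rather than the plot's normalized inner product.
# A minimal sketch: compare cosine similarities for a few term pairs.
# `print_similarity` is an illustrative helper, not part of the original notebook.
def print_similarity(term_a, term_b):
  vec_a, vec_b = module([term_a, term_b]).numpy()
  cosine = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
  print('%s vs. %s: %.3f' % (term_a, term_b, cosine))

print_similarity('coronavirus', 'SARS')   # expected to be relatively high
print_similarity('SARS', 'Spain')         # expected to be close to 0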
Now let's see how we can use these embeddings to solve a specific task.
SciCite: Citation Intent Classification
This section shows how one can use the embedding for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, the task is to classify whether the main intent of the citation is providing background information, describing the use of methods, or comparing results.
builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)
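Before looking at individual examples, it can be useful to know how large each split is. Here is a minimal sketch that reads the split sizes from the dataset metadata in builder.info:
# A minimal sketch: print the number of examples in each split.
for split_name in ('train', 'validation', 'test'):
  print(split_name, 'examples:', builder.info.splits[split_name].num_examples)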
Let's take a look at a few labeled examples from the training set.
NUM_EXAMPLES = 10
TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]
def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]
data = next(iter(train_data.batch(NUM_EXAMPLES)))
pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})
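Since accuracy is easier to interpret when we know the class balance, here is an optional sketch that counts how often each label occurs in the training split (it iterates over the data once, so it takes a moment):
# A minimal sketch: count label frequencies in the training split.
import collections

label_counts = collections.Counter(
    int(label.numpy()) for _, label in train_data)
for numeric_label, count in sorted(label_counts.items()):
  print(label2str(numeric_label), count)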
Training a citation intent classifier
We'll train a classifier on the SciCite dataset using Keras. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.
Hyperparameters
EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'
TRAINABLE_MODULE = False
hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[],
                           dtype=tf.string, trainable=TRAINABLE_MODULE)
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632
_________________________________________________________________
dense (Dense)                (None, 3)                 387
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
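Note that with TRAINABLE_MODULE = False the embedding layer is frozen, so only the 387 parameters of the dense layer are trained. If you want to experiment with fine-tuning the embeddings as well, the following is a minimal sketch of that variant (not part of the original setup); it assumes a smaller learning rate to avoid distorting the pretrained representation:
# A sketch of a fine-tuning variant (not used in this tutorial): make the
# embedding weights trainable and lower the learning rate.
finetune_hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[],
                                    dtype=tf.string, trainable=True)
finetune_model = tf.keras.Sequential([
    finetune_hub_layer,
    tf.keras.layers.Dense(3)
])
finetune_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])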
Train and evaluate the model
Let's train and evaluate the model to see how it performs on the SciCite task.
EPOCHS = 35
BATCH_SIZE = 32
history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.8978 - accuracy: 0.6190 - val_loss: 0.7777 - val_accuracy: 0.6736
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6920 - accuracy: 0.7241 - val_loss: 0.6764 - val_accuracy: 0.7227
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6215 - accuracy: 0.7588 - val_loss: 0.6296 - val_accuracy: 0.7434
...
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8025 - val_loss: 0.5445 - val_accuracy: 0.7893
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4995 - accuracy: 0.8010 - val_loss: 0.5453 - val_accuracy: 0.7828
from matplotlib import pyplot as plt

def display_training_curves(training, validation, title, subplot):
  if subplot % 10 == 1:  # set up the subplots on the first call
    plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model ' + title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)
Evaluate the model
Finally, let's see how the model performs. Two values will be returned: the loss (a number that represents our error; lower values are better) and the accuracy.
results = model.evaluate(test_data.batch(512), verbose=2)
for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5357 - accuracy: 0.7918
loss: 0.536
accuracy: 0.792
We can see that the loss decreases quickly while, in particular, the accuracy increases rapidly. Let's plot some examples to check how the predictions relate to the true labels:
prediction_dataset = next(iter(test_data.batch(20)))
prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]
predicted_classes = np.argmax(model.predict(np.array(prediction_texts)), axis=-1)
predictions = [label2str(x) for x in predicted_classes]
pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.
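To go beyond eyeballing a handful of examples, one could also compute a confusion matrix over the whole test split. The following is a minimal sketch (the variable names are illustrative):
# A minimal sketch: confusion matrix over the full test split.
test_texts, test_labels = [], []
for text, label in test_data:
  test_texts.append(text.numpy().decode('utf8'))
  test_labels.append(int(label.numpy()))

predicted = np.argmax(model.predict(np.array(test_texts)), axis=-1)
confusion = tf.math.confusion_matrix(test_labels, predicted, num_classes=3)
print(confusion.numpy())  # rows: true labels, columns: predicted labels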
What's next?
Now that you've gotten to know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to help gain scientific insights from COVID-19-related academic texts.
- Participate in the CORD-19 Kaggle Challenge
- Learn more about the COVID-19 Open Research Dataset (CORD-19)
- See documentation and more about the TF-Hub embeddings at https://tfhub.dev/tensorflow/cord-19/swivel-128d/3
- Explore the CORD-19 embedding space with the TensorFlow Embedding Projector