Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start by calculating and plotting a correlation matrix between different terms. If the embedding successfully captured the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

[Figure: heatmap of the correlation matrix for the query terms]

We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (e.g. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from terms in other clusters (e.g. the similarity between "SARS" and "Spain" is close to 0).
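
To back up this observation numerically, we can inspect a few individual pairs directly. The snippet below is a small sketch that is not part of the original notebook; it reuses the `module` loaded above and computes cosine similarities (whereas the heatmap uses inner products normalized by their maximum).

def cosine_similarity(a, b):
  # Embed both terms and compute the cosine of the angle between the vectors.
  emb = module([a, b]).numpy()
  return np.inner(emb[0], emb[1]) / (np.linalg.norm(emb[0]) * np.linalg.norm(emb[1]))

for pair in [('coronavirus', 'SARS'), ('coronavirus', 'MERS'), ('SARS', 'Spain')]:
  print('%s / %s: %.3f' % (pair[0], pair[1], cosine_similarity(*pair)))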

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how one can use the embedding for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, the task is to classify the main intent of the citation as background information, method use, or result comparison.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set.

NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})
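
As a quick sanity check (an addition, not in the original notebook), we can also count how often each citation intent appears in the training split, reusing the `label2str` helper defined above:

import collections

# Count the citation-intent labels across the whole training split.
label_counts = collections.Counter(
    label2str(int(label)) for _, label in tfds.as_numpy(train_data))
print(label_counts)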

Training a citation intent classifier

We'll train a classifier on the SciCite dataset using Keras. Let's build a model which uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3' 
TRAINABLE_MODULE = False 

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________

Train and evaluate the model

Let's train and evaluate the model to see the performance on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 3s 8ms/step - loss: 0.9001 - accuracy: 0.5791 - val_loss: 0.7704 - val_accuracy: 0.6714
Epoch 2/35
257/257 [==============================] - 2s 6ms/step - loss: 0.6927 - accuracy: 0.7199 - val_loss: 0.6645 - val_accuracy: 0.7489
Epoch 3/35
257/257 [==============================] - 3s 7ms/step - loss: 0.6213 - accuracy: 0.7625 - val_loss: 0.6254 - val_accuracy: 0.7555
Epoch 4/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5876 - accuracy: 0.7761 - val_loss: 0.6025 - val_accuracy: 0.7642
Epoch 5/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5677 - accuracy: 0.7826 - val_loss: 0.5875 - val_accuracy: 0.7642
Epoch 6/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5547 - accuracy: 0.7851 - val_loss: 0.5776 - val_accuracy: 0.7740
Epoch 7/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5457 - accuracy: 0.7867 - val_loss: 0.5709 - val_accuracy: 0.7817
Epoch 8/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5393 - accuracy: 0.7906 - val_loss: 0.5669 - val_accuracy: 0.7784
Epoch 9/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5341 - accuracy: 0.7914 - val_loss: 0.5649 - val_accuracy: 0.7828
Epoch 10/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5293 - accuracy: 0.7911 - val_loss: 0.5597 - val_accuracy: 0.7860
Epoch 11/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5259 - accuracy: 0.7922 - val_loss: 0.5569 - val_accuracy: 0.7828
Epoch 12/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5230 - accuracy: 0.7948 - val_loss: 0.5559 - val_accuracy: 0.7806
Epoch 13/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5200 - accuracy: 0.7940 - val_loss: 0.5539 - val_accuracy: 0.7795
Epoch 14/35
257/257 [==============================] - 2s 7ms/step - loss: 0.5179 - accuracy: 0.7957 - val_loss: 0.5521 - val_accuracy: 0.7806
Epoch 15/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5160 - accuracy: 0.7952 - val_loss: 0.5530 - val_accuracy: 0.7838
Epoch 16/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5146 - accuracy: 0.7961 - val_loss: 0.5525 - val_accuracy: 0.7740
Epoch 17/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5128 - accuracy: 0.7975 - val_loss: 0.5508 - val_accuracy: 0.7828
Epoch 18/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5116 - accuracy: 0.7974 - val_loss: 0.5504 - val_accuracy: 0.7773
Epoch 19/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5101 - accuracy: 0.7983 - val_loss: 0.5502 - val_accuracy: 0.7784
Epoch 20/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5095 - accuracy: 0.7972 - val_loss: 0.5477 - val_accuracy: 0.7838
Epoch 21/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5079 - accuracy: 0.7988 - val_loss: 0.5476 - val_accuracy: 0.7817
Epoch 22/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5069 - accuracy: 0.8003 - val_loss: 0.5476 - val_accuracy: 0.7828
Epoch 23/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5060 - accuracy: 0.7990 - val_loss: 0.5495 - val_accuracy: 0.7806
Epoch 24/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5052 - accuracy: 0.7999 - val_loss: 0.5479 - val_accuracy: 0.7828
Epoch 25/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5047 - accuracy: 0.7986 - val_loss: 0.5473 - val_accuracy: 0.7828
Epoch 26/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5035 - accuracy: 0.8006 - val_loss: 0.5464 - val_accuracy: 0.7849
Epoch 27/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5032 - accuracy: 0.7990 - val_loss: 0.5461 - val_accuracy: 0.7849
Epoch 28/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5029 - accuracy: 0.8013 - val_loss: 0.5460 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5019 - accuracy: 0.8003 - val_loss: 0.5436 - val_accuracy: 0.7882
Epoch 30/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5022 - accuracy: 0.8006 - val_loss: 0.5447 - val_accuracy: 0.7882
Epoch 31/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5009 - accuracy: 0.8010 - val_loss: 0.5475 - val_accuracy: 0.7893
Epoch 32/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5005 - accuracy: 0.8035 - val_loss: 0.5477 - val_accuracy: 0.7860
Epoch 33/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5002 - accuracy: 0.8021 - val_loss: 0.5433 - val_accuracy: 0.7893
Epoch 34/35
257/257 [==============================] - 2s 5ms/step - loss: 0.4998 - accuracy: 0.7996 - val_loss: 0.5473 - val_accuracy: 0.7882
Epoch 35/35
257/257 [==============================] - 2s 5ms/step - loss: 0.4993 - accuracy: 0.8012 - val_loss: 0.5449 - val_accuracy: 0.7904
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

[Figure: training and validation accuracy and loss curves]

Evaluate the model

Let's see how the model performs on the test set. Two values are returned: the loss (a number representing our error; lower is better) and accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5342 - accuracy: 0.7929
loss: 0.534
accuracy: 0.793

We can see that the loss quickly decreases while the accuracy rapidly increases. Let's plot some examples to check how the predictions relate to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences reasonably well.
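
To go beyond this small sample, one can also look at per-class accuracy on the full test split. The following sketch is an addition to the original notebook; it reuses `model`, `builder` and `LABEL_NAME` from above:

# Collect true labels and argmax predictions over the whole test split.
y_true, y_pred = [], []
for texts, labels in test_data.batch(512):
  y_true.extend(labels.numpy())
  y_pred.extend(np.argmax(model.predict(texts), axis=-1))
y_true, y_pred = np.array(y_true), np.array(y_pred)

# Report accuracy separately for each citation intent.
for i, name in enumerate(builder.info.features[LABEL_NAME].names):
  mask = y_true == i
  print('%s: %.3f (%d examples)' % (name, (y_pred[mask] == i).mean(), mask.sum()))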

What's next?

Now that you've gotten to know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19 related academic texts.