Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/1) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings


import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v1 as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

try:
  from google.colab import data_table
  def display_df(df):
    return data_table.DataTable(df, include_index=False)
except ModuleNotFoundError:
  # If google-colab is not available, just display the raw DataFrame
  def display_df(df):
    return df

Analyze the embeddings

Let's start off by analyzing the embedding by calculating and plotting a correlation matrix between different terms. If the embedding learned to successfully capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

with tf.Graph().as_default():
  # Load the module
  query_input = tf.placeholder(tf.string)
  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')
  embeddings = module(query_input)

  with tf.train.MonitoredTrainingSession() as sess:

    # Generate embeddings for some terms
    queries = [
        # Related viruses
        "coronavirus", "SARS", "MERS",
        # Regions
        "Italy", "Spain", "Europe",
        # Symptoms
        "cough", "fever", "throat"
    ]

    features = sess.run(embeddings, feed_dict={query_input: queries})
    plot_correlation(queries, features)


We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words of its cluster (e.g. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from the terms of other clusters (e.g. the similarity between "SARS" and "Spain" is close to 0).
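To make the similarity measure concrete, here is a minimal pure-Python sketch of the same computation as plot_correlation: pairwise inner products divided by the largest entry. The three-dimensional vectors and their labels are hypothetical toy values, not real CORD-19 embeddings, and normalized_inner_products is a helper name introduced only for this illustration.

```python
def normalized_inner_products(vectors):
    """Pairwise inner products scaled by the largest entry, as in plot_correlation."""
    sims = [[sum(a * b for a, b in zip(u, v)) for v in vectors] for u in vectors]
    largest = max(max(row) for row in sims)
    return [[s / largest for s in row] for row in sims]

# Hypothetical 3-d vectors standing in for 128-d embeddings.
toy_labels = ["virus_a", "virus_b", "region_a"]
toy_vectors = [
    [1.0, 0.9, 0.0],
    [0.9, 1.0, 0.1],
    [0.0, 0.1, 1.0],
]

sims = normalized_inner_products(toy_vectors)
for label, row in zip(toy_labels, sims):
    print(label, [round(s, 2) for s in row])
```

Because the scores are scaled by one global maximum rather than by each vector's norm, this is not cosine similarity: vectors with a larger norm tend to receive higher scores.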

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how one can use the embedding for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence with a citation from an academic paper, the task is to classify whether the main intent of the citation is background information, use of methods, or comparison of results.

class Dataset:
  """Build a dataset from a TFDS dataset."""
  def __init__(self, tfds_name, feature_name, label_name):
    self.dataset_builder = tfds.builder(tfds_name)
    # Download and prepare the data so that as_dataset() can read it
    self.dataset_builder.download_and_prepare()
    self.feature_name = feature_name
    self.label_name = label_name

  def get_data(self, for_eval):
    splits = self.dataset_builder.info.splits
    if tfds.Split.TEST in splits:
      split = tfds.Split.TEST if for_eval else tfds.Split.TRAIN
    else:
      # No test split: hold out the last 20% of the train split for evaluation
      SPLIT_PERCENT = 80
      split = "train[{}%:]".format(SPLIT_PERCENT) if for_eval else "train[:{}%]".format(SPLIT_PERCENT)
    return self.dataset_builder.as_dataset(split=split)

  def num_classes(self):
    return self.dataset_builder.info.features[self.label_name].num_classes

  def class_names(self):
    return self.dataset_builder.info.features[self.label_name].names

  def preprocess_fn(self, data):
    return data[self.feature_name], data[self.label_name]

  def example_fn(self, data):
    feature, label = self.preprocess_fn(data)
    return {'feature': feature, 'label': label}, label

def get_example_data(dataset, num_examples, **data_kw):
  """Show example data"""
  with tf.Session() as sess:
    batched_ds = dataset.get_data(**data_kw).take(num_examples).map(dataset.preprocess_fn).batch(num_examples)
    it = tf.data.make_one_shot_iterator(batched_ds).get_next()
    data = sess.run(it)
  return data
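The split selection in get_data can be sketched without TFDS. The helper below is a simplified stand-in: split_spec and has_test_split are hypothetical names, and the plain strings "train" and "test" stand in for tfds.Split.TRAIN and tfds.Split.TEST. It shows which split is requested for training and for evaluation, including the 80/20 fallback used when a dataset has no test split.

```python
def split_spec(for_eval, has_test_split, split_percent=80):
    """Return the split that would be requested (simplified stand-in for get_data)."""
    if has_test_split:
        return "test" if for_eval else "train"
    if for_eval:
        # Evaluation reads the last (100 - split_percent)% of the train split.
        return "train[{}%:]".format(split_percent)
    # Training reads the first split_percent% of the train split.
    return "train[:{}%]".format(split_percent)

print(split_spec(for_eval=False, has_test_split=False))  # train[:80%]
print(split_spec(for_eval=True, has_test_split=False))   # train[80%:]
print(split_spec(for_eval=True, has_test_split=True))    # test
```

SciCite ships with its own test split, so the percentage fallback only matters if the class is reused with a dataset that lacks one.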

TFDS_NAME = 'scicite'
TEXT_FEATURE_NAME = 'string'
LABEL_NAME = 'label'
THE_DATASET = Dataset(TFDS_NAME, TEXT_FEATURE_NAME, LABEL_NAME)
Downloading and preparing dataset scicite/1.0.0 (download: 22.12 MiB, generated: Unknown size, total: 22.12 MiB) to /home/kbuilder/tensorflow_datasets/scicite/1.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incomplete5C35GW/scicite-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incomplete5C35GW/scicite-validation.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incomplete5C35GW/scicite-test.tfrecord
Dataset scicite downloaded and prepared to /home/kbuilder/tensorflow_datasets/scicite/1.0.0. Subsequent calls will reuse this data.

NUM_EXAMPLES = 20  # assumed value; the source does not show it
data = get_example_data(THE_DATASET, NUM_EXAMPLES, for_eval=False)
display_df(pd.DataFrame({
        TEXT_FEATURE_NAME: [ex.decode('utf8') for ex in data[0]],
        LABEL_NAME: [THE_DATASET.class_names()[x] for x in data[1]]
    }))

Training a citation intent classifier

We'll train a classifier on the SciCite dataset using an Estimator. Let's set up the input_fns that read the dataset into the model.

def preprocessed_input_fn(for_eval):
  data = THE_DATASET.get_data(for_eval=for_eval)
  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)
  return data

def input_fn_train(params):
  data = preprocessed_input_fn(for_eval=False)
  data = data.repeat(None)
  data = data.shuffle(1024)
  data = data.batch(batch_size=params['batch_size'])
  return data

def input_fn_eval(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.repeat(1)
  data = data.batch(batch_size=params['batch_size'])
  return data

def input_fn_predict(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.batch(batch_size=params['batch_size'])
  return data
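The difference between the training and evaluation pipelines can be illustrated with a pure-Python analogy. This is not tf.data: batch is a hypothetical helper, the seven integers stand in for dataset examples, and shuffling is left out to keep the output deterministic. The sketch only shows what repeating and batching do to the stream of examples.

```python
import itertools

def batch(items, batch_size):
    """Group an iterable into lists of up to batch_size items, keeping a short final batch."""
    iterator = iter(items)
    while True:
        chunk = list(itertools.islice(iterator, batch_size))
        if not chunk:
            return
        yield chunk

examples = list(range(7))  # hypothetical stand-in for seven dataset examples

# Evaluation: a single pass over the data, so the last batch may be smaller.
eval_batches = list(batch(examples, batch_size=3))
print(eval_batches)  # [[0, 1, 2], [3, 4, 5], [6]]

# Training: the data repeats indefinitely, so the caller decides when to stop
# (here after 4 batches). Batches can straddle the boundary between passes.
train_batches = list(itertools.islice(batch(itertools.cycle(examples), batch_size=3), 4))
print(train_batches)  # [[0, 1, 2], [3, 4, 5], [6, 0, 1], [2, 3, 4]]
```

This is why the training loop passes an explicit number of steps to the Estimator, while evaluation simply runs until the data is exhausted.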

Let's build a model that uses the CORD-19 embeddings with a classification layer on top.
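Before the full model_fn, it may help to see what the classification head computes for a single batch. The sketch below is a pure-Python illustration under stated assumptions: the logits and labels are hypothetical, and the two helper functions are written for this example rather than taken from TensorFlow. It computes the per-example sparse softmax cross-entropy, averages it into a scalar loss, and takes the argmax of the logits as the predicted class.

```python
import math

def sparse_softmax_cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return log_sum_exp - logits[label]

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical logits for two examples over three classes.
batch_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
batch_labels = [0, 2]

losses = [sparse_softmax_cross_entropy(l, y) for l, y in zip(batch_logits, batch_labels)]
mean_loss = sum(losses) / len(losses)
predicted_classes = [argmax(l) for l in batch_logits]
print(predicted_classes, round(mean_loss, 3))
```

The second example has identical logits for all three classes, so its loss equals log(3), the loss of a uniform guess over three classes.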

def model_fn(features, labels, mode, params):
  # Embed the text
  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])
  embeddings = embed(features['feature'])

  # Add a linear layer on top
  logits = tf.layers.dense(
      embeddings, units=THE_DATASET.num_classes(), activation=None)
  predictions = tf.argmax(input=logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'predictions': predictions,
            'features': features['feature'],
            'labels': features['label']
        })

  # Set up a multi-class classification head
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  elif mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    precision = tf.metrics.precision(labels=labels, predictions=predictions)
    recall = tf.metrics.recall(labels=labels, predictions=predictions)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
        })

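EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/1'
TRAINABLE_MODULE = False  # assumed; value not shown in the source
STEPS = 8000
EVAL_EVERY = 200
BATCH_SIZE = 10  # assumed; value not shown in the source
LEARNING_RATE = 0.01  # assumed; value not shown in the source

params = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'module_name': EMBEDDING,
    'trainable_module': TRAINABLE_MODULE
}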

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))
metrics = []

for step in range(0, STEPS, EVAL_EVERY):
  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)
  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))
  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))
  # Keep the metrics of every evaluation so they can be plotted later
  metrics.append(step_metrics)
Global step 0: loss 0.854, accuracy 0.621
Global step 200: loss 0.740, accuracy 0.691
Global step 400: loss 0.687, accuracy 0.733
Global step 600: loss 0.649, accuracy 0.748
Global step 800: loss 0.648, accuracy 0.733
Global step 1000: loss 0.617, accuracy 0.772
Global step 1200: loss 0.610, accuracy 0.757
Global step 1400: loss 0.591, accuracy 0.779
Global step 1600: loss 0.589, accuracy 0.777
Global step 1800: loss 0.580, accuracy 0.777
Global step 2000: loss 0.589, accuracy 0.764
Global step 2200: loss 0.567, accuracy 0.785
Global step 2400: loss 0.572, accuracy 0.774
Global step 2600: loss 0.571, accuracy 0.774
Global step 2800: loss 0.566, accuracy 0.772
Global step 3000: loss 0.568, accuracy 0.770
Global step 3200: loss 0.562, accuracy 0.777
Global step 3400: loss 0.559, accuracy 0.779
Global step 3600: loss 0.554, accuracy 0.777
Global step 3800: loss 0.547, accuracy 0.786
Global step 4000: loss 0.559, accuracy 0.774
Global step 4200: loss 0.550, accuracy 0.783
Global step 4400: loss 0.563, accuracy 0.771
Global step 4600: loss 0.547, accuracy 0.782
Global step 4800: loss 0.542, accuracy 0.786
Global step 5000: loss 0.547, accuracy 0.777
Global step 5200: loss 0.545, accuracy 0.789
Global step 5400: loss 0.545, accuracy 0.785
Global step 5600: loss 0.543, accuracy 0.782
Global step 5800: loss 0.545, accuracy 0.782
Global step 6000: loss 0.539, accuracy 0.784
Global step 6200: loss 0.542, accuracy 0.790
Global step 6400: loss 0.542, accuracy 0.784
Global step 6600: loss 0.548, accuracy 0.781
Global step 6800: loss 0.543, accuracy 0.785
Global step 7000: loss 0.541, accuracy 0.785
Global step 7200: loss 0.541, accuracy 0.778
Global step 7400: loss 0.535, accuracy 0.789
Global step 7600: loss 0.556, accuracy 0.774
Global step 7800: loss 0.544, accuracy 0.783
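The log can also be summarized numerically. The sketch below copies five rows from the output above into a list (the list name and the choice of rows are made for this illustration) and reports the row with the highest accuracy and the change in loss between the first and last evaluation.

```python
# (step, loss, accuracy) rows copied from the training log above.
logged = [
    (0, 0.854, 0.621),
    (2000, 0.589, 0.764),
    (4000, 0.559, 0.774),
    (6200, 0.542, 0.790),
    (7800, 0.544, 0.783),
]

# Row with the highest accuracy among the copied rows.
best_step, _, best_accuracy = max(logged, key=lambda row: row[2])

# Loss at the first and last logged evaluation.
first_loss, last_loss = logged[0][1], logged[-1][1]

print("best accuracy {:.3f} at step {}".format(best_accuracy, best_step))
print("loss change {:.3f}".format(last_loss - first_loss))
```

Most of the improvement happens in the first couple of thousand steps; after that, both loss and accuracy fluctuate within a narrow band.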

global_steps = [x['global_step'] for x in metrics]
fig, axes = plt.subplots(ncols=2, figsize=(20,8))

for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],
                                            ['loss']]):
  for metric_name in metric_names:
    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)
  axes[axes_index].legend()
  axes[axes_index].set_xlabel("Global Step")


We can see that the loss decreases quickly and that the accuracy, in particular, increases rapidly. Let's plot some examples to check how the prediction relates to the true labels:

predictions = estimator.predict(functools.partial(input_fn_predict, params))
first_10_predictions = list(itertools.islice(predictions, 10))

display_df(
  pd.DataFrame({
      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],
      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],
      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]
  }))

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences fairly well.
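To put a number on "most of the time" for a small sample, one can compare the true and predicted labels directly. The following is a minimal sketch with hypothetical labels for five sentences; the values are not taken from the model's output, and fraction_correct is a helper introduced only for this example.

```python
def fraction_correct(true_labels, predicted_labels):
    """Share of positions where the predicted label equals the true label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label lists must have the same length")
    matches = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return matches / len(true_labels)

# Hypothetical labels for five sentences; not real model output.
true_labels = ["background", "method", "result", "background", "method"]
predicted_labels = ["background", "method", "background", "background", "method"]

print(fraction_correct(true_labels, predicted_labels))  # 0.8
```

On a sample of ten predictions such a figure is only a rough check; the evaluation accuracy reported during training is the more reliable estimate.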

What's next?

Now that you've gotten to know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19 related academic texts.