Exploring the TF-Hub CORD-19 Swivel Embeddings

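The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/1) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.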

In this Colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity('ERROR')

import tensorflow_datasets as tfds
import tensorflow_hub as hub

try:
  from google.colab import data_table
  def display_df(df):
    return data_table.DataTable(df, include_index=False)
except ModuleNotFoundError:
  # If google-colab is not available, just display the raw DataFrame
  def display_df(df):
    return df

Analyze the embeddings

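First, we'll analyze the embeddings by computing and plotting a correlation matrix between different terms. If the embeddings have learned to capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's look at some COVID-19-related terms.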

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)


with tf.Graph().as_default():
  # Load the module
  query_input = tf.placeholder(tf.string)
  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')
  embeddings = module(query_input)

  with tf.train.MonitoredTrainingSession() as sess:

    # Generate embeddings for some terms
    queries = [
        # Related viruses
        "coronavirus", "SARS", "MERS",
        # Regions
        "Italy", "Spain", "Europe",
        # Symptoms
        "cough", "fever", "throat"
    ]

    features = sess.run(embeddings, feed_dict={query_input: queries})
    plot_correlation(queries, features)

[Figure: heatmap of normalized inner-product similarities between the query terms]

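We can see that the embeddings capture the meaning of the different terms. Each word is similar to the other words in its cluster (for example, "coronavirus" correlates highly with "SARS" and "MERS"), while differing from terms in other clusters (for example, the similarity between "SARS" and "Spain" is close to 0).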
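To make the similarity measure concrete, the sketch below applies the same normalized inner product that plot_correlation uses to a handful of made-up 3-dimensional vectors. The vectors and term names are illustrative assumptions only; they are not real CORD-19 embeddings.

```python
import numpy as np

# Hypothetical toy "embeddings": two terms per cluster, NOT real CORD-19 vectors.
labels = ["virus_a", "virus_b", "region_a", "region_b"]
features = np.array([
    [1.0, 0.1, 0.0],
    [0.9, 0.2, 0.0],
    [0.0, 0.1, 1.0],
    [0.1, 0.0, 0.9],
])

# Same measure as plot_correlation: pairwise inner products scaled by the maximum.
corr = np.inner(features, features)
corr = corr / np.max(corr)

# Print one row of the similarity matrix per term.
for label, row in zip(labels, corr):
    print(label, np.round(row, 2))
```

Entries within a cluster come out close to 1 and entries across clusters close to 0, which is the block pattern the heatmap shows. Note that the inner product is not length-normalized, so vector norms also influence the values; cosine similarity would be the variant to use if that matters.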

Now let's see how these embeddings can be used to solve a specific task.

SciCite: Citation Intent Classification

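This section shows how to use the embeddings for a downstream task such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence with a citation from an academic paper, classify whether the main intent of the citation is background information, use of methods, or comparison of results.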

Set up the dataset from TFDS
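The code for this step is not shown here, yet later cells rely on THE_DATASET, TEXT_FEATURE_NAME, and LABEL_NAME. The following is a minimal sketch of a wrapper that would satisfy those calls (get_data, example_fn, num_classes, class_names). The class layout, the split choice, and the feature names 'string' and 'label' are assumptions, not the notebook's confirmed code.

```python
class Dataset:
    """Thin wrapper exposing what the input_fns and model_fn expect (hypothetical)."""

    def __init__(self, builder, feature_name, label_name):
        # `builder` is assumed to behave like a tfds DatasetBuilder.
        self.builder = builder
        self.feature_name = feature_name
        self.label_name = label_name

    def get_data(self, for_eval):
        # Assumption: evaluate on the 'test' split and train on 'train'.
        return self.builder.as_dataset(split='test' if for_eval else 'train')

    def num_classes(self):
        return self.builder.info.features[self.label_name].num_classes

    def class_names(self):
        return self.builder.info.features[self.label_name].names

    def example_fn(self, data):
        # Return (features, labels); the label is also kept inside features
        # so that it is still available in PREDICT mode.
        feature, label = data[self.feature_name], data[self.label_name]
        return {'feature': feature, 'label': label}, label


# Assumed TFDS feature names for SciCite; check the dataset card before relying on them.
TEXT_FEATURE_NAME = 'string'
LABEL_NAME = 'label'


def build_scicite_dataset():
    """Builds the wrapper from TFDS (downloads the data on first use)."""
    import tensorflow_datasets as tfds  # imported lazily so the class stays importable
    builder = tfds.builder('scicite')
    builder.download_and_prepare()
    return Dataset(builder, TEXT_FEATURE_NAME, LABEL_NAME)
```

With this sketch, THE_DATASET = build_scicite_dataset() would provide the object the remaining cells refer to.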

Let's take a look at a few labeled examples from the training set

Train a citation intent classifier

We'll train a classifier on the SciCite dataset using an Estimator. Let's set up the input_fns to read the dataset into the model.

def preprocessed_input_fn(for_eval):
  data = THE_DATASET.get_data(for_eval=for_eval)
  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)
  return data


def input_fn_train(params):
  data = preprocessed_input_fn(for_eval=False)
  data = data.repeat(None)
  data = data.shuffle(1024)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_eval(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.repeat(1)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_predict(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.batch(batch_size=params['batch_size'])
  return data

Let's build a model that uses the CORD-19 embeddings with a classification layer on top.

def model_fn(features, labels, mode, params):
  # Embed the text
  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])
  embeddings = embed(features['feature'])

  # Add a linear layer on top
  logits = tf.layers.dense(
      embeddings, units=THE_DATASET.num_classes(), activation=None)
  predictions = tf.argmax(input=logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'predictions': predictions,
            'features': features['feature'],
            'labels': features['label']
        })

  # Set up a multi-class classification head
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

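  elif mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    # Note: tf.metrics.precision/recall cast their inputs to bool, so with more
    # than two classes they treat class 0 as negative and every other class as positive.
    precision = tf.metrics.precision(labels=labels, predictions=predictions)
    recall = tf.metrics.recall(labels=labels, predictions=predictions)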

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
        })
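To show what the classification head computes, here is a NumPy-only sketch of a linear layer followed by arg-max prediction and mean sparse softmax cross-entropy, mirroring the structure of model_fn. The batch size, the three-class setup, the random embeddings, and the random weights are placeholder assumptions; no TF-Hub module is loaded.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed shapes: a batch of 4 sentences, 128-d embeddings, 3 intent classes.
batch_size, embed_dim, num_classes = 4, 128, 3
embeddings = rng.normal(size=(batch_size, embed_dim))
labels = np.array([0, 2, 1, 0])

# Linear layer on top of the embeddings (weights are random placeholders here).
weights = rng.normal(scale=0.01, size=(embed_dim, num_classes))
bias = np.zeros(num_classes)
logits = embeddings @ weights + bias

# Predicted class is the arg-max logit, as in model_fn.
predictions = np.argmax(logits, axis=1)

# Sparse softmax cross-entropy: -log softmax(logits)[label], averaged over the batch.
shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(batch_size), labels].mean()

print(predictions, round(float(loss), 3))
```

With untrained weights this small, the loss sits near log(3), the value for a uniform guess over three classes; training the linear layer is what pushes it down.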

Hyperparameters
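The values behind this heading are not shown, but the training loop needs params, STEPS, and EVAL_EVERY. The sketch below is one plausible configuration. The module handle comes from the text above, and STEPS and EVAL_EVERY are inferred from the training log (evaluations at steps 0, 200, ..., 7800); the batch size, learning rate, and trainable flag are assumptions.

```python
# Module handle taken from the text above.
EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/1'

# Inferred from the training log: evaluations at steps 0, 200, ..., 7800.
STEPS = 8000
EVAL_EVERY = 200

# Assumed values; not shown in this document.
TRAINABLE_MODULE = False
BATCH_SIZE = 10
LEARNING_RATE = 0.01

# Keys match the lookups in model_fn and the input_fns.
params = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'module_name': EMBEDDING,
    'trainable_module': TRAINABLE_MODULE,
}
```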

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))
metrics = []

for step in range(0, STEPS, EVAL_EVERY):
  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)
  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))
  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))
  metrics.append(step_metrics)

Global step 0: loss 0.757, accuracy 0.693
Global step 200: loss 0.692, accuracy 0.728
Global step 400: loss 0.648, accuracy 0.755
Global step 600: loss 0.630, accuracy 0.760
Global step 800: loss 0.615, accuracy 0.761
Global step 1000: loss 0.602, accuracy 0.772
Global step 1200: loss 0.584, accuracy 0.775
Global step 1400: loss 0.579, accuracy 0.776
Global step 1600: loss 0.572, accuracy 0.777
Global step 1800: loss 0.572, accuracy 0.777
Global step 2000: loss 0.562, accuracy 0.786
Global step 2200: loss 0.561, accuracy 0.783
Global step 2400: loss 0.562, accuracy 0.781
Global step 2600: loss 0.561, accuracy 0.777
Global step 2800: loss 0.555, accuracy 0.785
Global step 3000: loss 0.554, accuracy 0.788
Global step 3200: loss 0.552, accuracy 0.791
Global step 3400: loss 0.555, accuracy 0.778
Global step 3600: loss 0.547, accuracy 0.786
Global step 3800: loss 0.549, accuracy 0.780
Global step 4000: loss 0.540, accuracy 0.792
Global step 4200: loss 0.543, accuracy 0.788
Global step 4400: loss 0.551, accuracy 0.784
Global step 4600: loss 0.543, accuracy 0.794
Global step 4800: loss 0.556, accuracy 0.771
Global step 5000: loss 0.553, accuracy 0.779
Global step 5200: loss 0.548, accuracy 0.780
Global step 5400: loss 0.553, accuracy 0.774
Global step 5600: loss 0.541, accuracy 0.788
Global step 5800: loss 0.539, accuracy 0.790
Global step 6000: loss 0.532, accuracy 0.799
Global step 6200: loss 0.535, accuracy 0.791
Global step 6400: loss 0.535, accuracy 0.790
Global step 6600: loss 0.539, accuracy 0.791
Global step 6800: loss 0.530, accuracy 0.798
Global step 7000: loss 0.532, accuracy 0.797
Global step 7200: loss 0.536, accuracy 0.782
Global step 7400: loss 0.531, accuracy 0.798
Global step 7600: loss 0.530, accuracy 0.795
Global step 7800: loss 0.533, accuracy 0.793

global_steps = [x['global_step'] for x in metrics]
fig, axes = plt.subplots(ncols=2, figsize=(20,8))

for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],
                                            ['loss']]):
  for metric_name in metric_names:
    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)
  axes[axes_index].legend()
  axes[axes_index].set_xlabel("Global Step")

[Figure: two panels plotted against global step; left: accuracy, precision, and recall; right: loss]

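We can see that the loss decreases quickly while the accuracy rises quickly. Let's display some examples to check how the predictions relate to the true labels: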

predictions = estimator.predict(functools.partial(input_fn_predict, params))
first_10_predictions = list(itertools.islice(predictions, 10))

display_df(
  pd.DataFrame({
      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],
      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],
      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]
  }))

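We can see that for this random sample, the model predicts the correct label most of the time, which suggests that it embeds scientific sentences reasonably well.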
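A quick way to put a number on "most of the time" is to tally how often the predicted class matches the label in the sampled records. The sketch below uses invented stand-in dicts with the same 'labels' and 'predictions' keys that the predict output carries; the class indices are hypothetical and do not come from the actual run.

```python
# Hypothetical stand-ins for the dicts yielded by estimator.predict();
# the class indices below are invented for illustration.
sample_predictions = [
    {'labels': 1, 'predictions': 1},
    {'labels': 0, 'predictions': 0},
    {'labels': 2, 'predictions': 1},
    {'labels': 1, 'predictions': 1},
]

# Count the records whose predicted class equals the true label.
matches = sum(p['labels'] == p['predictions'] for p in sample_predictions)
agreement = matches / len(sample_predictions)
print('{}/{} correct ({:.0%})'.format(matches, len(sample_predictions), agreement))
```

Applied to first_10_predictions, the same tally would give the agreement rate for the ten displayed rows; a sample that small is only a sanity check, and the evaluation accuracy above remains the more reliable figure.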

What's next?

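Now that you've learned a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition and contribute to gaining deeper scientific insights from COVID-19-related academic texts.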