Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this Colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start analyzing the embeddings by computing and plotting a correlation matrix between different terms. If the embeddings learned to successfully capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19-related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

(Figure: heatmap of pairwise similarities between the query terms.)

We can see that the embeddings successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from terms in other clusters (i.e. the similarity between "SARS" and "Spain" is close to 0).
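The similarity measure used above can be illustrated on toy vectors. The three vectors below are made up for illustration only (they are not real CORD-19 embeddings); the point is simply that cosine similarity is high for vectors pointing in similar directions and near 0 for near-orthogonal ones:

```python
import numpy as np

def cosine(u, v):
  # Inner product of the vectors, normalized by their lengths.
  return float(np.inner(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Made-up 3-d stand-ins for embeddings of in-cluster vs. out-of-cluster terms.
sars = np.array([0.9, 0.1, 0.0])
mers = np.array([0.8, 0.2, 0.1])
spain = np.array([0.0, 0.1, 0.9])

print(cosine(sars, mers))   # high: same "cluster"
print(cosine(sars, spain))  # near 0: different clusters
```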

Now let's see how we can use these embeddings to solve a specific task.

SciCite: citation intent classification

This section shows how to use the embeddings for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, classify whether the main intent of the citation is background information, use of methods, or comparing results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set

NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})

Training a citation intent classifier

We'll train a classifier on the SciCite dataset using Keras. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3' 
TRAINABLE_MODULE = False 

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
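The dense head in the summary above is just a linear map from the 128-d embedding to 3 class logits, which is why it has 128 × 3 + 3 = 387 trainable parameters. A minimal numpy sketch of that layer (with random stand-in weights, not the trained ones):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(128, 3))  # kernel: 384 weights
b = np.zeros(3)                # bias: 3 parameters

# Stand-in for one 128-d sentence embedding produced by the hub layer.
embedding = rng.normal(size=(1, 128))

logits = embedding @ W + b     # one logit per class, shape (1, 3)

print(W.size + b.size)  # 387, matching the model summary
print(logits.shape)     # (1, 3)
```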

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 3s 8ms/step - loss: 0.9001 - accuracy: 0.5791 - val_loss: 0.7704 - val_accuracy: 0.6714
Epoch 2/35
257/257 [==============================] - 2s 6ms/step - loss: 0.6927 - accuracy: 0.7199 - val_loss: 0.6645 - val_accuracy: 0.7489
Epoch 3/35
257/257 [==============================] - 3s 7ms/step - loss: 0.6213 - accuracy: 0.7625 - val_loss: 0.6254 - val_accuracy: 0.7555
Epoch 4/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5876 - accuracy: 0.7761 - val_loss: 0.6025 - val_accuracy: 0.7642
Epoch 5/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5677 - accuracy: 0.7826 - val_loss: 0.5875 - val_accuracy: 0.7642
Epoch 6/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5547 - accuracy: 0.7851 - val_loss: 0.5776 - val_accuracy: 0.7740
Epoch 7/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5457 - accuracy: 0.7867 - val_loss: 0.5709 - val_accuracy: 0.7817
Epoch 8/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5393 - accuracy: 0.7906 - val_loss: 0.5669 - val_accuracy: 0.7784
Epoch 9/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5341 - accuracy: 0.7914 - val_loss: 0.5649 - val_accuracy: 0.7828
Epoch 10/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5293 - accuracy: 0.7911 - val_loss: 0.5597 - val_accuracy: 0.7860
Epoch 11/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5259 - accuracy: 0.7922 - val_loss: 0.5569 - val_accuracy: 0.7828
Epoch 12/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5230 - accuracy: 0.7948 - val_loss: 0.5559 - val_accuracy: 0.7806
Epoch 13/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5200 - accuracy: 0.7940 - val_loss: 0.5539 - val_accuracy: 0.7795
Epoch 14/35
257/257 [==============================] - 2s 7ms/step - loss: 0.5179 - accuracy: 0.7957 - val_loss: 0.5521 - val_accuracy: 0.7806
Epoch 15/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5160 - accuracy: 0.7952 - val_loss: 0.5530 - val_accuracy: 0.7838
Epoch 16/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5146 - accuracy: 0.7961 - val_loss: 0.5525 - val_accuracy: 0.7740
Epoch 17/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5128 - accuracy: 0.7975 - val_loss: 0.5508 - val_accuracy: 0.7828
Epoch 18/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5116 - accuracy: 0.7974 - val_loss: 0.5504 - val_accuracy: 0.7773
Epoch 19/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5101 - accuracy: 0.7983 - val_loss: 0.5502 - val_accuracy: 0.7784
Epoch 20/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5095 - accuracy: 0.7972 - val_loss: 0.5477 - val_accuracy: 0.7838
Epoch 21/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5079 - accuracy: 0.7988 - val_loss: 0.5476 - val_accuracy: 0.7817
Epoch 22/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5069 - accuracy: 0.8003 - val_loss: 0.5476 - val_accuracy: 0.7828
Epoch 23/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5060 - accuracy: 0.7990 - val_loss: 0.5495 - val_accuracy: 0.7806
Epoch 24/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5052 - accuracy: 0.7999 - val_loss: 0.5479 - val_accuracy: 0.7828
Epoch 25/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5047 - accuracy: 0.7986 - val_loss: 0.5473 - val_accuracy: 0.7828
Epoch 26/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5035 - accuracy: 0.8006 - val_loss: 0.5464 - val_accuracy: 0.7849
Epoch 27/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5032 - accuracy: 0.7990 - val_loss: 0.5461 - val_accuracy: 0.7849
Epoch 28/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5029 - accuracy: 0.8013 - val_loss: 0.5460 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5019 - accuracy: 0.8003 - val_loss: 0.5436 - val_accuracy: 0.7882
Epoch 30/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5022 - accuracy: 0.8006 - val_loss: 0.5447 - val_accuracy: 0.7882
Epoch 31/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5009 - accuracy: 0.8010 - val_loss: 0.5475 - val_accuracy: 0.7893
Epoch 32/35
257/257 [==============================] - 2s 6ms/step - loss: 0.5005 - accuracy: 0.8035 - val_loss: 0.5477 - val_accuracy: 0.7860
Epoch 33/35
257/257 [==============================] - 2s 5ms/step - loss: 0.5002 - accuracy: 0.8021 - val_loss: 0.5433 - val_accuracy: 0.7893
Epoch 34/35
257/257 [==============================] - 2s 5ms/step - loss: 0.4998 - accuracy: 0.7996 - val_loss: 0.5473 - val_accuracy: 0.7882
Epoch 35/35
257/257 [==============================] - 2s 5ms/step - loss: 0.4993 - accuracy: 0.8012 - val_loss: 0.5449 - val_accuracy: 0.7904
from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

(Figure: training curves of accuracy and loss for the train and validation sets.)

Evaluate the model

And let's see how the model performs. Two values are returned: loss (a number representing our error; lower is better) and accuracy.
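The two values `model.evaluate` reports can be sketched by hand: sparse categorical cross-entropy is the mean negative log-probability of the true class (computed from logits via softmax), and accuracy is the fraction of examples whose argmax logit matches the label. A small worked example on made-up logits and labels:

```python
import numpy as np

# Made-up logits for two examples over the three classes, and their labels.
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 1.5, 0.3]])
labels = np.array([0, 1])  # true class indices

# Softmax, then the negative log-likelihood of the true class per example.
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
loss = -np.log(probs[np.arange(len(labels)), labels]).mean()

# Accuracy: how often the highest logit is the true class.
accuracy = (logits.argmax(axis=1) == labels).mean()

print(loss)
print(accuracy)  # 1.0 here, since both argmaxes match
```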

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5342 - accuracy: 0.7929
loss: 0.534
accuracy: 0.793

We can see that the loss quickly decreases while the accuracy rapidly increases. Let's plot some examples to check how the predictions relate to the true labels:
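Turning model outputs into label strings works by taking the argmax of the logits per example and mapping the index through the class names. The `names` list below is hard-coded for the sketch (in the notebook it comes from `builder.info.features`, and the ordering here is illustrative only):

```python
import numpy as np

names = ['method', 'background', 'result']  # illustrative ordering

# Made-up logits for two examples.
logits = np.array([[0.2, 2.1, -0.3],
                   [1.7, 0.0, 0.4]])

pred_ids = np.argmax(logits, axis=-1)        # index of the largest logit
pred_labels = [names[i] for i in pred_ids]

print(pred_labels)  # ['background', 'method']
```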

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

# model.predict_classes() is deprecated; take the argmax over the logits instead.
predictions = [label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.

What's next?

Now that you've learned a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19-related academic texts.