Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Erkundung der schwenkbaren Einbettungen des TF-Hub CORD-19

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Das CORD-19 Swivel-Text-Einbettungsmodul von TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) wurde entwickelt, um Forscher bei der Analyse von Text in natürlichen Sprachen zu unterstützen COVID-19. Diese Einbettungen wurden auf die Titel, Autoren, Abstracts, Körpertexte und Referenztitel von Artikeln im CORD-19-Datensatz trainiert.

In diesem Colab werden wir:

  • Analysieren Sie semantisch ähnliche Wörter im Einbettungsbereich
  • Trainieren Sie einen Klassifikator im SciCite-Dataset mithilfe der CORD-19-Einbettungen

Konfiguration

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analysieren Sie die Einbettungen

Beginnen wir mit der Analyse der Einbettung, indem wir eine Korrelationsmatrix zwischen verschiedenen Begriffen berechnen und zeichnen. Wenn die Einbettung gelernt hat, die Bedeutung verschiedener Wörter erfolgreich zu erfassen, sollten die Einbettungsvektoren semantisch ähnlicher Wörter nahe beieinander liegen. Schauen wir uns einige COVID-19-bezogene Begriffe an.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

png

Wir können sehen, dass die Einbettung die Bedeutung der verschiedenen Begriffe erfolgreich erfasst hat. Jedes Wort ähnelt den anderen Wörtern seines Clusters (dh "Coronavirus" korreliert stark mit "SARS" und "MERS"), während sie sich von Begriffen anderer Cluster unterscheiden (dh die Ähnlichkeit zwischen "SARS" und "Spanien" ist nahe 0).

Nun wollen wir sehen, wie wir diese Einbettungen verwenden können, um eine bestimmte Aufgabe zu lösen.

SciCite: Citation Intent Classification

Dieser Abschnitt zeigt, wie man die Einbettung für nachgelagerte Aufgaben wie die Textklassifizierung verwenden kann. Wir werden den SciCite-Datensatz aus TensorFlow-Datensätzen verwenden, um Zitierabsichten in wissenschaftlichen Arbeiten zu klassifizieren. Klassifizieren Sie anhand eines Satzes mit einem Zitat aus einer wissenschaftlichen Arbeit, ob die Hauptabsicht des Zitats Hintergrundinformationen, die Verwendung von Methoden oder der Vergleich von Ergebnissen sind.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

NUM_EXAMPLES =   10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})

Training eines Citaton Intent Classifier

Wir werden einen Klassifikator für das SciCite-Dataset mit Keras trainieren . Erstellen wir ein Modell, das die CORD-19-Einbettungen mit einer Klassifizierungsebene darüber verwendet.



EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'  
TRAINABLE_MODULE = False  

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Warning:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________

Trainieren und bewerten Sie das Modell

Lassen Sie uns das Modell trainieren und bewerten, um die Leistung der SciCite-Aufgabe zu sehen

EPOCHS =   35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.9005 - accuracy: 0.6104 - val_loss: 0.7803 - val_accuracy: 0.6769
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6957 - accuracy: 0.7224 - val_loss: 0.6740 - val_accuracy: 0.7424
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6224 - accuracy: 0.7571 - val_loss: 0.6289 - val_accuracy: 0.7522
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5868 - accuracy: 0.7734 - val_loss: 0.6084 - val_accuracy: 0.7533
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5676 - accuracy: 0.7795 - val_loss: 0.5894 - val_accuracy: 0.7609
Epoch 6/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5549 - accuracy: 0.7823 - val_loss: 0.5802 - val_accuracy: 0.7609
Epoch 7/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5458 - accuracy: 0.7873 - val_loss: 0.5805 - val_accuracy: 0.7620
Epoch 8/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5393 - accuracy: 0.7875 - val_loss: 0.5697 - val_accuracy: 0.7729
Epoch 9/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5337 - accuracy: 0.7906 - val_loss: 0.5659 - val_accuracy: 0.7751
Epoch 10/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5295 - accuracy: 0.7907 - val_loss: 0.5614 - val_accuracy: 0.7784
Epoch 11/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5257 - accuracy: 0.7942 - val_loss: 0.5630 - val_accuracy: 0.7740
Epoch 12/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5229 - accuracy: 0.7945 - val_loss: 0.5582 - val_accuracy: 0.7762
Epoch 13/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5202 - accuracy: 0.7944 - val_loss: 0.5576 - val_accuracy: 0.7762
Epoch 14/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5183 - accuracy: 0.7947 - val_loss: 0.5530 - val_accuracy: 0.7762
Epoch 15/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5158 - accuracy: 0.7966 - val_loss: 0.5537 - val_accuracy: 0.7762
Epoch 16/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5144 - accuracy: 0.7972 - val_loss: 0.5525 - val_accuracy: 0.7784
Epoch 17/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5129 - accuracy: 0.7968 - val_loss: 0.5496 - val_accuracy: 0.7795
Epoch 18/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5115 - accuracy: 0.7978 - val_loss: 0.5525 - val_accuracy: 0.7795
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5104 - accuracy: 0.7985 - val_loss: 0.5482 - val_accuracy: 0.7795
Epoch 20/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5091 - accuracy: 0.7981 - val_loss: 0.5490 - val_accuracy: 0.7806
Epoch 21/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5082 - accuracy: 0.7985 - val_loss: 0.5489 - val_accuracy: 0.7838
Epoch 22/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5073 - accuracy: 0.7990 - val_loss: 0.5487 - val_accuracy: 0.7806
Epoch 23/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5063 - accuracy: 0.7989 - val_loss: 0.5459 - val_accuracy: 0.7860
Epoch 24/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5055 - accuracy: 0.7990 - val_loss: 0.5469 - val_accuracy: 0.7849
Epoch 25/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5045 - accuracy: 0.7989 - val_loss: 0.5463 - val_accuracy: 0.7849
Epoch 26/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5042 - accuracy: 0.8001 - val_loss: 0.5463 - val_accuracy: 0.7806
Epoch 27/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5032 - accuracy: 0.8002 - val_loss: 0.5484 - val_accuracy: 0.7882
Epoch 28/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5033 - accuracy: 0.8012 - val_loss: 0.5483 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5019 - accuracy: 0.8016 - val_loss: 0.5455 - val_accuracy: 0.7838
Epoch 30/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5014 - accuracy: 0.8024 - val_loss: 0.5477 - val_accuracy: 0.7828
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5014 - accuracy: 0.8017 - val_loss: 0.5495 - val_accuracy: 0.7871
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5010 - accuracy: 0.8013 - val_loss: 0.5438 - val_accuracy: 0.7882
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5003 - accuracy: 0.8016 - val_loss: 0.5452 - val_accuracy: 0.7893
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5002 - accuracy: 0.8016 - val_loss: 0.5450 - val_accuracy: 0.7871
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4995 - accuracy: 0.8018 - val_loss: 0.5461 - val_accuracy: 0.7893

from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

png

Bewerten Sie das Modell

Und mal sehen, wie das Modell funktioniert. Es werden zwei Werte zurückgegeben. Verlust (eine Zahl, die unseren Fehler darstellt, niedrigere Werte sind besser) und Genauigkeit.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5344 - accuracy: 0.7924
loss: 0.534
accuracy: 0.792

Wir können sehen, dass der Verlust schnell abnimmt, während insbesondere die Genauigkeit schnell zunimmt. Lassen Sie uns einige Beispiele darstellen, um zu überprüfen, wie sich die Vorhersage auf die wahren Bezeichnungen bezieht:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [label2str(x) for x in model.predict_classes(prediction_texts)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
WARNING:tensorflow:From <ipython-input-11-0e10f5eff104>:6: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).

Warning:tensorflow:From <ipython-input-11-0e10f5eff104>:6: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).

Wir können sehen, dass das Modell für diese Zufallsstichprobe meistens die richtige Bezeichnung vorhersagt, was darauf hinweist, dass es wissenschaftliche Sätze ziemlich gut einbetten kann.

Was kommt als nächstes?

Nachdem Sie etwas mehr über die CORD-19 Swivel-Einbettungen von TF-Hub erfahren haben, empfehlen wir Ihnen, am CORD-19 Kaggle-Wettbewerb teilzunehmen, um wissenschaftliche Erkenntnisse aus COVID-19-bezogenen akademischen Texten zu gewinnen.