
Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start off by analyzing the embeddings: we compute and plot a correlation matrix between different terms. If the embedding successfully captures the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

[Figure: heatmap of pairwise similarities between the query terms]

We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words of its cluster (i.e. "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from the terms of other clusters (i.e. the similarity between "SARS" and "Spain" is close to 0).
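To read off individual pairs numerically rather than from the heatmap, here is a minimal sketch reusing the loaded module (the helper function and the term pairs below are illustrative, not part of the original notebook); normalizing the vectors makes the values proper cosine similarities:

def cosine_similarity(term_a, term_b):
  # Embed both terms and normalize so the inner product is a cosine similarity
  vecs = module([term_a, term_b]).numpy()
  vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
  return np.inner(vecs[0], vecs[1])

print('SARS vs. MERS : %.3f' % cosine_similarity('SARS', 'MERS'))
print('SARS vs. Spain: %.3f' % cosine_similarity('SARS', 'Spain'))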

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation intent classification

This section shows how to use the embeddings for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence with a citation from an academic paper, classify whether the main intent of the citation is background information, use of methods, or comparing results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

NUM_EXAMPLES = 10  # Number of training examples to preview

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})

Training a citation intent classifier

We'll train a classifier on the SciCite dataset using Keras. Let's build a model which uses the CORD-19 embeddings with a classification layer on top.



EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'  # TF-Hub embedding module
TRAINABLE_MODULE = False  # Whether to fine-tune the embedding weights

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7f3380141158> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

Warning:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.




Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
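Note that the final Dense(3) layer outputs raw logits, which is why the loss above uses from_logits=True. To read the outputs as class probabilities, apply a softmax; a minimal sketch with a made-up citation sentence (only meaningful once the model has been trained below):

# Turn the model's logits into per-class probabilities (illustrative sentence)
example = tf.constant(['We follow the evaluation protocol described in prior work.'])
probabilities = tf.nn.softmax(model(example), axis=-1)
print(dict(zip(builder.info.features[LABEL_NAME].names, probabilities.numpy()[0])))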

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.9008 - accuracy: 0.6061 - val_loss: 0.7859 - val_accuracy: 0.6834
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6953 - accuracy: 0.7257 - val_loss: 0.6806 - val_accuracy: 0.7282
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6240 - accuracy: 0.7556 - val_loss: 0.6368 - val_accuracy: 0.7533
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5892 - accuracy: 0.7720 - val_loss: 0.6102 - val_accuracy: 0.7544
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5691 - accuracy: 0.7769 - val_loss: 0.5952 - val_accuracy: 0.7631
Epoch 6/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5554 - accuracy: 0.7840 - val_loss: 0.5885 - val_accuracy: 0.7642
Epoch 7/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5466 - accuracy: 0.7856 - val_loss: 0.5779 - val_accuracy: 0.7675
Epoch 8/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5400 - accuracy: 0.7875 - val_loss: 0.5715 - val_accuracy: 0.7751
Epoch 9/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5344 - accuracy: 0.7911 - val_loss: 0.5681 - val_accuracy: 0.7740
Epoch 10/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5299 - accuracy: 0.7929 - val_loss: 0.5665 - val_accuracy: 0.7707
Epoch 11/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5264 - accuracy: 0.7922 - val_loss: 0.5617 - val_accuracy: 0.7762
Epoch 12/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5234 - accuracy: 0.7938 - val_loss: 0.5620 - val_accuracy: 0.7762
Epoch 13/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5202 - accuracy: 0.7950 - val_loss: 0.5603 - val_accuracy: 0.7740
Epoch 14/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5182 - accuracy: 0.7962 - val_loss: 0.5574 - val_accuracy: 0.7762
Epoch 15/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5165 - accuracy: 0.7959 - val_loss: 0.5561 - val_accuracy: 0.7828
Epoch 16/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5147 - accuracy: 0.7977 - val_loss: 0.5537 - val_accuracy: 0.7784
Epoch 17/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5132 - accuracy: 0.7974 - val_loss: 0.5541 - val_accuracy: 0.7817
Epoch 18/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5118 - accuracy: 0.7967 - val_loss: 0.5546 - val_accuracy: 0.7838
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5103 - accuracy: 0.7992 - val_loss: 0.5527 - val_accuracy: 0.7871
Epoch 20/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5093 - accuracy: 0.7996 - val_loss: 0.5498 - val_accuracy: 0.7806
Epoch 21/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5076 - accuracy: 0.7988 - val_loss: 0.5527 - val_accuracy: 0.7784
Epoch 22/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5071 - accuracy: 0.7984 - val_loss: 0.5492 - val_accuracy: 0.7849
Epoch 23/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5063 - accuracy: 0.7969 - val_loss: 0.5484 - val_accuracy: 0.7828
Epoch 24/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5051 - accuracy: 0.7996 - val_loss: 0.5500 - val_accuracy: 0.7849
Epoch 25/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5052 - accuracy: 0.7977 - val_loss: 0.5475 - val_accuracy: 0.7860
Epoch 26/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5040 - accuracy: 0.8001 - val_loss: 0.5475 - val_accuracy: 0.7828
Epoch 27/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5034 - accuracy: 0.8008 - val_loss: 0.5488 - val_accuracy: 0.7838
Epoch 28/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5023 - accuracy: 0.8017 - val_loss: 0.5505 - val_accuracy: 0.7817
Epoch 29/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5024 - accuracy: 0.8014 - val_loss: 0.5486 - val_accuracy: 0.7849
Epoch 30/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5019 - accuracy: 0.8022 - val_loss: 0.5457 - val_accuracy: 0.7860
Epoch 31/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5014 - accuracy: 0.8021 - val_loss: 0.5448 - val_accuracy: 0.7838
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5004 - accuracy: 0.8005 - val_loss: 0.5451 - val_accuracy: 0.7893
Epoch 33/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5004 - accuracy: 0.8014 - val_loss: 0.5484 - val_accuracy: 0.7871
Epoch 34/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5000 - accuracy: 0.8006 - val_loss: 0.5462 - val_accuracy: 0.7882
Epoch 35/35
257/257 [==============================] - 1s 5ms/step - loss: 0.4995 - accuracy: 0.8016 - val_loss: 0.5451 - val_accuracy: 0.7882

from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

[Figure: training and validation accuracy (top) and loss (bottom) curves over the 35 epochs]

Evaluate the model

And let's see how the model performs. Two values will be returned: the loss (a number representing the error; lower values are better) and accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5349 - accuracy: 0.7854
loss: 0.535
accuracy: 0.785

We can see that the loss decreases quickly while the accuracy increases rapidly. Let's plot a few examples to check how the predictions relate to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [label2str(x) for x in model.predict_classes(prediction_texts)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
WARNING:tensorflow:From <ipython-input-1-0e10f5eff104>:6: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).

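As the deprecation warning suggests, predict_classes can be replaced by an explicit argmax over the model's predictions. A minimal equivalent sketch for this multi-class case:

# Equivalent to predict_classes for a multi-class model: argmax over the logits
predicted_ids = np.argmax(model.predict(prediction_texts), axis=-1)
predictions = [label2str(x) for x in predicted_ids]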

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences fairly well.

What's next?

Now that you know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19-related academic texts.