Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural-language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start by analyzing the embedding: we compute and plot a correlation matrix between different terms. If the embedding has successfully learned to capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19-related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

[Figure: correlation heatmap of the query terms]

We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e., "coronavirus" correlates highly with "SARS" and "MERS"), while it differs from terms in other clusters (i.e., the similarity between "SARS" and "Spain" is close to 0).
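
To put numbers on these claims, here is a minimal sketch (an addition to the tutorial, reusing the `module` loaded above; note that the heatmap uses a max-normalized inner product rather than the cosine similarity computed here):

# Cosine similarity between individual term pairs (sketch, not part of the
# original notebook; `module` is the embedding loaded above).
def cosine_similarity(a, b):
  a = tf.nn.l2_normalize(a, axis=-1)
  b = tf.nn.l2_normalize(b, axis=-1)
  return float(tf.reduce_sum(a * b))

sars, mers, spain = module(['SARS', 'MERS', 'Spain'])
print('SARS vs MERS :', cosine_similarity(sars, mers))
print('SARS vs Spain:', cosine_similarity(sars, spain))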

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how the embedding can be used for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence containing a citation from an academic paper, classify whether the main intent of the citation is background information, use of methods, or comparison of results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set.

NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})
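
Before training a classifier, it is worth checking how balanced the three classes are. A small sketch (an addition to the tutorial, using only objects already defined above):

# Count how often each citation-intent label appears in the training split.
class_names = builder.info.features[LABEL_NAME].names
label_counts = np.zeros(len(class_names), dtype=int)
for _, label in train_data:
  label_counts[label.numpy()] += 1
for name, count in zip(class_names, label_counts):
  print(f'{name}: {count}')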

Training a citation intent classifier

We'll train a classifier on the SciCite dataset using Keras. Let's build a model that uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3' 
TRAINABLE_MODULE = False 

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7fe7fc16cf28> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
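
The summary confirms that only the 387 parameters of the classification head are trainable. As a quick sanity check (a minimal sketch, assuming the cells above have run), the hub layer maps raw strings directly to 128-dimensional vectors:

# Sanity check: the hub layer embeds a batch of raw strings as 128-d vectors.
sample = tf.constant(['This method was proposed in prior work.'])
print(hub_layer(sample).shape)  # expected: (1, 128)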

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.8978 - accuracy: 0.6190 - val_loss: 0.7777 - val_accuracy: 0.6736
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6920 - accuracy: 0.7241 - val_loss: 0.6764 - val_accuracy: 0.7227
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6215 - accuracy: 0.7588 - val_loss: 0.6296 - val_accuracy: 0.7434
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5867 - accuracy: 0.7757 - val_loss: 0.6071 - val_accuracy: 0.7489
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5674 - accuracy: 0.7808 - val_loss: 0.5923 - val_accuracy: 0.7576
Epoch 6/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5549 - accuracy: 0.7874 - val_loss: 0.5818 - val_accuracy: 0.7664
Epoch 7/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5457 - accuracy: 0.7890 - val_loss: 0.5756 - val_accuracy: 0.7664
Epoch 8/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5392 - accuracy: 0.7883 - val_loss: 0.5698 - val_accuracy: 0.7740
Epoch 9/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5335 - accuracy: 0.7907 - val_loss: 0.5643 - val_accuracy: 0.7707
Epoch 10/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5294 - accuracy: 0.7934 - val_loss: 0.5609 - val_accuracy: 0.7751
Epoch 11/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5260 - accuracy: 0.7933 - val_loss: 0.5591 - val_accuracy: 0.7762
Epoch 12/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5231 - accuracy: 0.7929 - val_loss: 0.5578 - val_accuracy: 0.7773
Epoch 13/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5202 - accuracy: 0.7933 - val_loss: 0.5586 - val_accuracy: 0.7751
Epoch 14/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5181 - accuracy: 0.7938 - val_loss: 0.5536 - val_accuracy: 0.7806
Epoch 15/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5164 - accuracy: 0.7944 - val_loss: 0.5514 - val_accuracy: 0.7806
Epoch 16/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5144 - accuracy: 0.7961 - val_loss: 0.5513 - val_accuracy: 0.7751
Epoch 17/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5128 - accuracy: 0.7967 - val_loss: 0.5526 - val_accuracy: 0.7773
Epoch 18/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5113 - accuracy: 0.7974 - val_loss: 0.5524 - val_accuracy: 0.7784
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5104 - accuracy: 0.7972 - val_loss: 0.5480 - val_accuracy: 0.7817
Epoch 20/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5094 - accuracy: 0.7977 - val_loss: 0.5479 - val_accuracy: 0.7817
Epoch 21/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5077 - accuracy: 0.7981 - val_loss: 0.5480 - val_accuracy: 0.7838
Epoch 22/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5072 - accuracy: 0.7983 - val_loss: 0.5450 - val_accuracy: 0.7817
Epoch 23/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5059 - accuracy: 0.7994 - val_loss: 0.5495 - val_accuracy: 0.7795
Epoch 24/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5056 - accuracy: 0.7991 - val_loss: 0.5471 - val_accuracy: 0.7817
Epoch 25/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5048 - accuracy: 0.7983 - val_loss: 0.5474 - val_accuracy: 0.7806
Epoch 26/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5042 - accuracy: 0.8000 - val_loss: 0.5481 - val_accuracy: 0.7828
Epoch 27/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5031 - accuracy: 0.7996 - val_loss: 0.5447 - val_accuracy: 0.7882
Epoch 28/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5027 - accuracy: 0.8007 - val_loss: 0.5440 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5018 - accuracy: 0.8024 - val_loss: 0.5460 - val_accuracy: 0.7871
Epoch 30/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5017 - accuracy: 0.8002 - val_loss: 0.5454 - val_accuracy: 0.7893
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5013 - accuracy: 0.8008 - val_loss: 0.5456 - val_accuracy: 0.7828
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5003 - accuracy: 0.8024 - val_loss: 0.5467 - val_accuracy: 0.7871
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8007 - val_loss: 0.5445 - val_accuracy: 0.7915
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8025 - val_loss: 0.5445 - val_accuracy: 0.7893
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4995 - accuracy: 0.8010 - val_loss: 0.5453 - val_accuracy: 0.7828
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

[Figure: training and validation accuracy and loss curves]
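
The validation curves flatten out well before epoch 35, so training could optionally be cut short. A hedged sketch using Keras's EarlyStopping callback (not part of the original tutorial; assumes the model is rebuilt and recompiled before rerunning fit):

# Optional: stop training once validation loss stops improving.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)
history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    callbacks=[early_stopping],
                    verbose=0)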

Evaluate the model

Let's see how the model performs. It returns two values: the loss (a number representing our error; lower is better) and accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5357 - accuracy: 0.7918
loss: 0.536
accuracy: 0.792
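
For context, it helps to compare this against a trivial baseline. A sketch (an addition to the tutorial, assuming `test_data` as defined earlier) that always predicts the most frequent class:

# Majority-class baseline on the test split, for comparison with the
# ~0.79 accuracy above.
test_labels = np.array([int(label.numpy()) for _, label in test_data])
majority_class = np.bincount(test_labels).argmax()
print('majority-class accuracy: %.3f' % np.mean(test_labels == majority_class))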

We can see that the loss decreases quickly and, in particular, the accuracy increases rapidly. Let's plot some examples to check how the predictions relate to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

# predict_classes is deprecated; take the argmax over the predicted logits instead
predicted_ids = np.argmax(model.predict(prediction_texts), axis=-1)
predictions = [label2str(x) for x in predicted_ids]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.
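
To see which intents the model tends to confuse, here is a sketch of a confusion matrix over the full test split (an addition to the tutorial; assumes `model`, `test_data`, and `builder` from the cells above):

# Confusion matrix: rows are true labels, columns are predicted labels.
class_names = builder.info.features[LABEL_NAME].names
true_ids, pred_ids = [], []
for texts, labels in test_data.batch(512):
  true_ids.extend(labels.numpy())
  pred_ids.extend(np.argmax(model.predict(texts), axis=-1))
cm = tf.math.confusion_matrix(true_ids, pred_ids).numpy()
print(pd.DataFrame(cm, index=class_names, columns=class_names))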

What's next?

Now that you've learned a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19-related academic texts.