RSVP для вашего местного мероприятия TensorFlow Everywhere сегодня!
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Изучение поворотных встраиваний TF-Hub CORD-19

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть на GitHub Скачать блокнот См. Модель TF Hub

Модуль встраивания текста CORD-19 Swivel от TF-Hub (https: //tfhub.dev/tensorflow/cord-19/swivel-128d/3) был создан для поддержки исследователей, анализирующих текст на естественном языке, связанный с COVID-19. Эти вложения были обучены названиям, авторам, рефератам, основным текстам и ссылочным названиям статей в наборе данных CORD-19 .

В этом колабе мы:

  • Анализировать семантически похожие слова в пространстве вложения
  • Обучите классификатор на наборе данных SciCite, используя вложения CORD-19

Настроить

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Анализируйте вложения

Давайте начнем с анализа встраивания путем вычисления и построения корреляционной матрицы между различными терминами. Если встраивание научилось успешно улавливать значение разных слов, векторы встраивания семантически похожих слов должны быть близко друг к другу. Давайте посмотрим на некоторые термины, связанные с COVID-19.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

PNG

Мы видим, что встраивание успешно уловило значение различных терминов. Каждое слово похоже на другие слова своего кластера (например, «коронавирус» сильно коррелирует с «SARS» и «MERS»), в то время как они отличаются от терминов других кластеров (например, сходство между «SARS» и «Испания» составляет близко к 0).

Теперь давайте посмотрим, как мы можем использовать эти вложения для решения конкретной задачи.

SciCite: классификация намерений цитирования

В этом разделе показано, как можно использовать встраивание для последующих задач, таких как классификация текста. Мы будем использовать набор данных SciCite из TensorFlow Datasets для классификации целей цитирования в академических статьях. Учитывая предложение с цитатой из академической статьи, классифицируйте, является ли основная цель цитирования справочной информацией, использованием методов или сравнением результатов.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Давайте посмотрим на несколько помеченных примеров из обучающей выборки.

NUM_EXAMPLES =   10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})

Обучение классификатора намерений citaton

Мы обучим классификатор на наборе данных SciCite с помощью Keras . Давайте построим модель, которая использует вложения CORD-19 с классификационным слоем наверху.

Гиперпараметры

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3' 
TRAINABLE_MODULE = False 

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7fe7fc16cf28> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7fe7fc16cf28> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________

Обучите и оцените модель

Давайте обучим и оценим модель, чтобы увидеть производительность задачи SciCite.

EPOCHS =   35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.8978 - accuracy: 0.6190 - val_loss: 0.7777 - val_accuracy: 0.6736
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6920 - accuracy: 0.7241 - val_loss: 0.6764 - val_accuracy: 0.7227
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6215 - accuracy: 0.7588 - val_loss: 0.6296 - val_accuracy: 0.7434
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5867 - accuracy: 0.7757 - val_loss: 0.6071 - val_accuracy: 0.7489
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5674 - accuracy: 0.7808 - val_loss: 0.5923 - val_accuracy: 0.7576
Epoch 6/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5549 - accuracy: 0.7874 - val_loss: 0.5818 - val_accuracy: 0.7664
Epoch 7/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5457 - accuracy: 0.7890 - val_loss: 0.5756 - val_accuracy: 0.7664
Epoch 8/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5392 - accuracy: 0.7883 - val_loss: 0.5698 - val_accuracy: 0.7740
Epoch 9/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5335 - accuracy: 0.7907 - val_loss: 0.5643 - val_accuracy: 0.7707
Epoch 10/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5294 - accuracy: 0.7934 - val_loss: 0.5609 - val_accuracy: 0.7751
Epoch 11/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5260 - accuracy: 0.7933 - val_loss: 0.5591 - val_accuracy: 0.7762
Epoch 12/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5231 - accuracy: 0.7929 - val_loss: 0.5578 - val_accuracy: 0.7773
Epoch 13/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5202 - accuracy: 0.7933 - val_loss: 0.5586 - val_accuracy: 0.7751
Epoch 14/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5181 - accuracy: 0.7938 - val_loss: 0.5536 - val_accuracy: 0.7806
Epoch 15/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5164 - accuracy: 0.7944 - val_loss: 0.5514 - val_accuracy: 0.7806
Epoch 16/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5144 - accuracy: 0.7961 - val_loss: 0.5513 - val_accuracy: 0.7751
Epoch 17/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5128 - accuracy: 0.7967 - val_loss: 0.5526 - val_accuracy: 0.7773
Epoch 18/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5113 - accuracy: 0.7974 - val_loss: 0.5524 - val_accuracy: 0.7784
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5104 - accuracy: 0.7972 - val_loss: 0.5480 - val_accuracy: 0.7817
Epoch 20/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5094 - accuracy: 0.7977 - val_loss: 0.5479 - val_accuracy: 0.7817
Epoch 21/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5077 - accuracy: 0.7981 - val_loss: 0.5480 - val_accuracy: 0.7838
Epoch 22/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5072 - accuracy: 0.7983 - val_loss: 0.5450 - val_accuracy: 0.7817
Epoch 23/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5059 - accuracy: 0.7994 - val_loss: 0.5495 - val_accuracy: 0.7795
Epoch 24/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5056 - accuracy: 0.7991 - val_loss: 0.5471 - val_accuracy: 0.7817
Epoch 25/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5048 - accuracy: 0.7983 - val_loss: 0.5474 - val_accuracy: 0.7806
Epoch 26/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5042 - accuracy: 0.8000 - val_loss: 0.5481 - val_accuracy: 0.7828
Epoch 27/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5031 - accuracy: 0.7996 - val_loss: 0.5447 - val_accuracy: 0.7882
Epoch 28/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5027 - accuracy: 0.8007 - val_loss: 0.5440 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5018 - accuracy: 0.8024 - val_loss: 0.5460 - val_accuracy: 0.7871
Epoch 30/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5017 - accuracy: 0.8002 - val_loss: 0.5454 - val_accuracy: 0.7893
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5013 - accuracy: 0.8008 - val_loss: 0.5456 - val_accuracy: 0.7828
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5003 - accuracy: 0.8024 - val_loss: 0.5467 - val_accuracy: 0.7871
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8007 - val_loss: 0.5445 - val_accuracy: 0.7915
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8025 - val_loss: 0.5445 - val_accuracy: 0.7893
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4995 - accuracy: 0.8010 - val_loss: 0.5453 - val_accuracy: 0.7828

from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

PNG

Оцените модель

И посмотрим, как модель работает. Будут возвращены два значения. Потеря (число, которое представляет нашу ошибку, меньшие значения лучше) и точность.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5357 - accuracy: 0.7918
loss: 0.536
accuracy: 0.792

Мы видим, что потери быстро уменьшаются, особенно быстро увеличивается точность. Давайте построим несколько примеров, чтобы проверить, как прогноз соотносится с истинными метками:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [label2str(x) for x in model.predict_classes(prediction_texts)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})
WARNING:tensorflow:From <ipython-input-1-0e10f5eff104>:6: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).

WARNING:tensorflow:From <ipython-input-1-0e10f5eff104>:6: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).

Мы видим, что для этой случайной выборки модель в большинстве случаев предсказывает правильную метку, что указывает на то, что она может довольно хорошо встраивать научные предложения.

Что дальше?

Теперь, когда вы узнали немного больше о встраиваниях CORD-19 Swivel от TF-Hub, мы приглашаем вас принять участие в конкурсе CORD-19 Kaggle, чтобы внести свой вклад в получение научных идей из академических текстов, связанных с COVID-19.