Esta página foi traduzida pela API Cloud Translation.
Switch to English

Explorando o TF-Hub CORD-19 Swivel Embeddings

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

O módulo de incorporação de texto CORD-19 Swivel do TF-Hub (https: //tfhub.dev/tensorflow/cord-19/swivel-128d/1) foi criado para apoiar pesquisadores que analisam textos de línguas naturais relacionados a COVID-19. Esses embeddings foram treinados nos títulos, autores, resumos, textos do corpo e títulos de referência de artigos no conjunto de dados CORD-19 .

Nesta colab iremos:

  • Analise palavras semanticamente semelhantes no espaço de incorporação
  • Treine um classificador no conjunto de dados SciCite usando os embeddings CORD-19

Configuração

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity('ERROR')

import tensorflow_datasets as tfds
import tensorflow_hub as hub

try:
  from google.colab import data_table
  def display_df(df):
    return data_table.DataTable(df, include_index=False)
except ModuleNotFoundError:
  # If google-colab is not available, just display the raw DataFrame
  def display_df(df):
    return df

Analise os embeddings

Vamos começar analisando a incorporação calculando e plotando uma matriz de correlação entre os diferentes termos. Se o embedding aprender a capturar com sucesso o significado de palavras diferentes, os vetores de embedding de palavras semanticamente semelhantes devem estar próximos. Vamos dar uma olhada em alguns termos relacionados ao COVID-19.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)


with tf.Graph().as_default():
  # Load the module
  query_input = tf.placeholder(tf.string)
  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')
  embeddings = module(query_input)

  with tf.train.MonitoredTrainingSession() as sess:

    # Generate embeddings for some terms
    queries = [
        # Related viruses
        "coronavirus", "SARS", "MERS",
        # Regions
        "Italy", "Spain", "Europe",
        # Symptoms
        "cough", "fever", "throat"
    ]

    features = sess.run(embeddings, feed_dict={query_input: queries})
    plot_correlation(queries, features)

png

Podemos ver que a incorporação capturou com sucesso o significado dos diferentes termos. Cada palavra é semelhante às outras palavras do seu cluster (ou seja, "coronavirus" altamente correlacionado com "SARS" e "MERS"), embora sejam diferentes dos termos de outros clusters (ou seja, a semelhança entre "SARS" e "Espanha" é próximo a 0).

Agora vamos ver como podemos usar esses embeddings para resolver uma tarefa específica.

SciCite: Classificação da intenção de citação

Esta seção mostra como se pode usar a incorporação para tarefas posteriores, como classificação de texto. Usaremos o conjunto de dados SciCite dos conjuntos de dados TensorFlow para classificar intenções de citação em trabalhos acadêmicos. Dada uma frase com uma citação de um artigo acadêmico, classifique se o objetivo principal da citação é como informação de fundo, uso de métodos ou comparação de resultados.



class Dataset:
  """Build a dataset from a TFDS dataset."""
  def __init__(self, tfds_name, feature_name, label_name):
    self.dataset_builder = tfds.builder(tfds_name)
    self.dataset_builder.download_and_prepare()
    self.feature_name = feature_name
    self.label_name = label_name
  
  def get_data(self, for_eval):
    splits = THE_DATASET.dataset_builder.info.splits
    if tfds.Split.TEST in splits:
      split = tfds.Split.TEST if for_eval else tfds.Split.TRAIN
    else:
      SPLIT_PERCENT = 80
      split = "train[{}%:]".format(SPLIT_PERCENT) if for_eval else "train[:{}%]".format(SPLIT_PERCENT)
    return self.dataset_builder.as_dataset(split=split)

  def num_classes(self):
    return self.dataset_builder.info.features[self.label_name].num_classes

  def class_names(self):
    return self.dataset_builder.info.features[self.label_name].names

  def preprocess_fn(self, data):
    return data[self.feature_name], data[self.label_name]

  def example_fn(self, data):
    feature, label = self.preprocess_fn(data)
    return {'feature': feature, 'label': label}, label


def get_example_data(dataset, num_examples, **data_kw):
  """Show example data"""
  with tf.Session() as sess:
    batched_ds = dataset.get_data(**data_kw).take(num_examples).map(dataset.preprocess_fn).batch(num_examples)
    it = tf.data.make_one_shot_iterator(batched_ds).get_next()
    data = sess.run(it)
  return data


TFDS_NAME = 'scicite' 
TEXT_FEATURE_NAME = 'string' 
LABEL_NAME = 'label' 
THE_DATASET = Dataset(TFDS_NAME, TEXT_FEATURE_NAME, LABEL_NAME)
Downloading and preparing dataset scicite/1.0.0 (download: 22.12 MiB, generated: Unknown size, total: 22.12 MiB) to /home/kbuilder/tensorflow_datasets/scicite/1.0.0...

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)








HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteEETG6U/scicite-train.tfrecord

HBox(children=(FloatProgress(value=0.0, max=8194.0), HTML(value='')))
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteEETG6U/scicite-validation.tfrecord

HBox(children=(FloatProgress(value=0.0, max=916.0), HTML(value='')))
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteEETG6U/scicite-test.tfrecord

HBox(children=(FloatProgress(value=0.0, max=1859.0), HTML(value='')))
Dataset scicite downloaded and prepared to /home/kbuilder/tensorflow_datasets/scicite/1.0.0. Subsequent calls will reuse this data.


NUM_EXAMPLES = 20  
data = get_example_data(THE_DATASET, NUM_EXAMPLES, for_eval=False)
display_df(
    pd.DataFrame({
        TEXT_FEATURE_NAME: [ex.decode('utf8') for ex in data[0]],
        LABEL_NAME: [THE_DATASET.class_names()[x] for x in data[1]]
    }))

Treinando um classificador de intenção de citaton

Vamos treinar um classificador no conjunto de dados SciCite usando um Estimator. Vamos configurar o input_fns para ler o conjunto de dados no modelo

def preprocessed_input_fn(for_eval):
  data = THE_DATASET.get_data(for_eval=for_eval)
  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)
  return data


def input_fn_train(params):
  data = preprocessed_input_fn(for_eval=False)
  data = data.repeat(None)
  data = data.shuffle(1024)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_eval(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.repeat(1)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_predict(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.batch(batch_size=params['batch_size'])
  return data

Vamos construir um modelo que use os embeddings CORD-19 com uma camada de classificação no topo.

def model_fn(features, labels, mode, params):
  # Embed the text
  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])
  embeddings = embed(features['feature'])

  # Add a linear layer on top
  logits = tf.layers.dense(
      embeddings, units=THE_DATASET.num_classes(), activation=None)
  predictions = tf.argmax(input=logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'predictions': predictions,
            'features': features['feature'],
            'labels': features['label']
        })
  
  # Set up a multi-class classification head
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  elif mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    precision = tf.metrics.precision(labels=labels, predictions=predictions)
    recall = tf.metrics.recall(labels=labels, predictions=predictions)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
        })



EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/1'  
TRAINABLE_MODULE = False  
STEPS =   8000
EVAL_EVERY = 200  
BATCH_SIZE = 10  
LEARNING_RATE = 0.01  

params = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'module_name': EMBEDDING,
    'trainable_module': TRAINABLE_MODULE
}

Treine e avalie o modelo

Vamos treinar e avaliar o modelo para ver o desempenho na tarefa SciCite

estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))
metrics = []

for step in range(0, STEPS, EVAL_EVERY):
  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)
  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))
  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))
  metrics.append(step_metrics)
Global step 0: loss 0.858, accuracy 0.622
Global step 200: loss 0.757, accuracy 0.676
Global step 400: loss 0.697, accuracy 0.713
Global step 600: loss 0.684, accuracy 0.711
Global step 800: loss 0.653, accuracy 0.737
Global step 1000: loss 0.621, accuracy 0.751
Global step 1200: loss 0.616, accuracy 0.754
Global step 1400: loss 0.599, accuracy 0.760
Global step 1600: loss 0.586, accuracy 0.769
Global step 1800: loss 0.588, accuracy 0.768
Global step 2000: loss 0.581, accuracy 0.765
Global step 2200: loss 0.574, accuracy 0.772
Global step 2400: loss 0.587, accuracy 0.768
Global step 2600: loss 0.581, accuracy 0.768
Global step 2800: loss 0.579, accuracy 0.770
Global step 3000: loss 0.569, accuracy 0.772
Global step 3200: loss 0.572, accuracy 0.774
Global step 3400: loss 0.566, accuracy 0.775
Global step 3600: loss 0.571, accuracy 0.769
Global step 3800: loss 0.555, accuracy 0.779
Global step 4000: loss 0.557, accuracy 0.774
Global step 4200: loss 0.548, accuracy 0.787
Global step 4400: loss 0.551, accuracy 0.782
Global step 4600: loss 0.558, accuracy 0.776
Global step 4800: loss 0.554, accuracy 0.781
Global step 5000: loss 0.548, accuracy 0.784
Global step 5200: loss 0.550, accuracy 0.780
Global step 5400: loss 0.552, accuracy 0.782
Global step 5600: loss 0.545, accuracy 0.784
Global step 5800: loss 0.542, accuracy 0.791
Global step 6000: loss 0.540, accuracy 0.791
Global step 6200: loss 0.542, accuracy 0.789
Global step 6400: loss 0.542, accuracy 0.789
Global step 6600: loss 0.542, accuracy 0.791
Global step 6800: loss 0.545, accuracy 0.790
Global step 7000: loss 0.540, accuracy 0.792
Global step 7200: loss 0.542, accuracy 0.789
Global step 7400: loss 0.540, accuracy 0.790
Global step 7600: loss 0.537, accuracy 0.792
Global step 7800: loss 0.538, accuracy 0.793

global_steps = [x['global_step'] for x in metrics]
fig, axes = plt.subplots(ncols=2, figsize=(20,8))

for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],
                                            ['loss']]):
  for metric_name in metric_names:
    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)
  axes[axes_index].legend()
  axes[axes_index].set_xlabel("Global Step")

png

Podemos ver que a perda diminui rapidamente enquanto, especialmente, a precisão aumenta rapidamente. Vamos plotar alguns exemplos para verificar como a previsão se relaciona com os rótulos verdadeiros:

predictions = estimator.predict(functools.partial(input_fn_predict, params))
first_10_predictions = list(itertools.islice(predictions, 10))

display_df(
  pd.DataFrame({
      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],
      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],
      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]
  }))

Podemos ver que, para essa amostra aleatória, o modelo prevê o rótulo correto na maioria das vezes, indicando que ele pode incorporar frases científicas muito bem.

Qual é o próximo?

Agora que você aprendeu um pouco mais sobre os embeddings CORD-19 Swivel do TF-Hub, incentivamos você a participar da competição CORD-19 Kaggle para contribuir com a obtenção de conhecimentos científicos de textos acadêmicos relacionados ao COVID-19.