Salva la data! Google I / O ritorna dal 18 al 20 maggio Registrati ora
Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Esplorando il TF-Hub CORD-19 Swivel Embedding

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino Vedi modello TF Hub

Il modulo di incorporamento del testo girevole CORD-19 da TF-Hub (https: //tfhub.dev/tensorflow/cord-19/swivel-128d/1) è stato creato per supportare i ricercatori che analizzano il testo in lingue naturali correlato a COVID-19. Questi incorporamenti sono stati formati su titoli, autori, abstract, testi del corpo e titoli di riferimento di articoli nel set di dati CORD-19 .

In questa colab faremo:

  • Analizza parole semanticamente simili nello spazio di incorporamento
  • Addestra un classificatore sul set di dati SciCite utilizzando gli incorporamenti CORD-19

Impostare

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity('ERROR')

import tensorflow_datasets as tfds
import tensorflow_hub as hub

try:
  from google.colab import data_table
  def display_df(df):
    return data_table.DataTable(df, include_index=False)
except ModuleNotFoundError:
  # If google-colab is not available, just display the raw DataFrame
  def display_df(df):
    return df

Analizza gli incorporamenti

Cominciamo analizzando l'incorporamento calcolando e tracciando una matrice di correlazione tra termini diversi. Se l'incorporamento ha imparato a catturare con successo il significato di parole diverse, i vettori di incorporamento di parole semanticamente simili dovrebbero essere vicini tra loro. Diamo un'occhiata ad alcuni termini correlati a COVID-19.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)


with tf.Graph().as_default():
  # Load the module
  query_input = tf.placeholder(tf.string)
  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')
  embeddings = module(query_input)

  with tf.train.MonitoredTrainingSession() as sess:

    # Generate embeddings for some terms
    queries = [
        # Related viruses
        "coronavirus", "SARS", "MERS",
        # Regions
        "Italy", "Spain", "Europe",
        # Symptoms
        "cough", "fever", "throat"
    ]

    features = sess.run(embeddings, feed_dict={query_input: queries})
    plot_correlation(queries, features)

png

Possiamo vedere che l'incorporamento ha catturato con successo il significato dei diversi termini. Ogni parola è simile alle altre parole del suo cluster (cioè "coronavirus" è altamente correlato con "SARS" e "MERS"), mentre sono diverse dai termini di altri cluster (cioè la somiglianza tra "SARS" e "Spagna" vicino a 0).

Ora vediamo come possiamo utilizzare questi incorporamenti per risolvere un'attività specifica.

SciCite: Citation Intent Classification

Questa sezione mostra come utilizzare l'incorporamento per attività a valle come la classificazione del testo. Useremo il set di dati SciCite dei set di dati TensorFlow per classificare gli intenti di citazione nei documenti accademici. Data una frase con una citazione da un articolo accademico, classificare se l'intento principale della citazione è come informazione di base, uso di metodi o confronto dei risultati.

Imposta il set di dati da TFDS

Downloading and preparing dataset scicite/1.0.0 (download: 22.12 MiB, generated: Unknown size, total: 22.12 MiB) to /home/kbuilder/tensorflow_datasets/scicite/1.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-validation.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/scicite/1.0.0.incompleteHWK5SE/scicite-test.tfrecord
Dataset scicite downloaded and prepared to /home/kbuilder/tensorflow_datasets/scicite/1.0.0. Subsequent calls will reuse this data.

Diamo un'occhiata ad alcuni esempi etichettati dal set di formazione

Formazione di un classificatore di intenti citaton

Addestreremo un classificatore sul set di dati SciCite utilizzando uno stimatore. Impostiamo input_fns per leggere il set di dati nel modello

def preprocessed_input_fn(for_eval):
  data = THE_DATASET.get_data(for_eval=for_eval)
  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)
  return data


def input_fn_train(params):
  data = preprocessed_input_fn(for_eval=False)
  data = data.repeat(None)
  data = data.shuffle(1024)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_eval(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.repeat(1)
  data = data.batch(batch_size=params['batch_size'])
  return data


def input_fn_predict(params):
  data = preprocessed_input_fn(for_eval=True)
  data = data.batch(batch_size=params['batch_size'])
  return data

Costruiamo un modello che utilizzi gli incorporamenti CORD-19 con uno strato di classificazione in cima.

def model_fn(features, labels, mode, params):
  # Embed the text
  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])
  embeddings = embed(features['feature'])

  # Add a linear layer on top
  logits = tf.layers.dense(
      embeddings, units=THE_DATASET.num_classes(), activation=None)
  predictions = tf.argmax(input=logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'predictions': predictions,
            'features': features['feature'],
            'labels': features['label']
        })

  # Set up a multi-class classification head
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  elif mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    precision = tf.metrics.precision(labels=labels, predictions=predictions)
    recall = tf.metrics.recall(labels=labels, predictions=predictions)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
        })

Hyperparmeters

Addestra e valuta il modello

Addestriamo e valutiamo il modello per vedere le prestazioni dell'attività SciCite

estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))
metrics = []

for step in range(0, STEPS, EVAL_EVERY):
  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)
  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))
  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))
  metrics.append(step_metrics)
Global step 0: loss 0.796, accuracy 0.670
Global step 200: loss 0.701, accuracy 0.732
Global step 400: loss 0.682, accuracy 0.719
Global step 600: loss 0.650, accuracy 0.747
Global step 800: loss 0.620, accuracy 0.762
Global step 1000: loss 0.609, accuracy 0.762
Global step 1200: loss 0.605, accuracy 0.762
Global step 1400: loss 0.585, accuracy 0.783
Global step 1600: loss 0.586, accuracy 0.768
Global step 1800: loss 0.577, accuracy 0.774
Global step 2000: loss 0.584, accuracy 0.765
Global step 2200: loss 0.565, accuracy 0.778
Global step 2400: loss 0.570, accuracy 0.776
Global step 2600: loss 0.556, accuracy 0.789
Global step 2800: loss 0.563, accuracy 0.778
Global step 3000: loss 0.557, accuracy 0.784
Global step 3200: loss 0.566, accuracy 0.774
Global step 3400: loss 0.552, accuracy 0.782
Global step 3600: loss 0.551, accuracy 0.785
Global step 3800: loss 0.547, accuracy 0.788
Global step 4000: loss 0.549, accuracy 0.784
Global step 4200: loss 0.548, accuracy 0.785
Global step 4400: loss 0.553, accuracy 0.783
Global step 4600: loss 0.543, accuracy 0.786
Global step 4800: loss 0.548, accuracy 0.783
Global step 5000: loss 0.547, accuracy 0.785
Global step 5200: loss 0.539, accuracy 0.791
Global step 5400: loss 0.546, accuracy 0.782
Global step 5600: loss 0.548, accuracy 0.781
Global step 5800: loss 0.540, accuracy 0.791
Global step 6000: loss 0.542, accuracy 0.790
Global step 6200: loss 0.539, accuracy 0.792
Global step 6400: loss 0.545, accuracy 0.788
Global step 6600: loss 0.552, accuracy 0.781
Global step 6800: loss 0.549, accuracy 0.783
Global step 7000: loss 0.540, accuracy 0.788
Global step 7200: loss 0.543, accuracy 0.782
Global step 7400: loss 0.541, accuracy 0.787
Global step 7600: loss 0.532, accuracy 0.790
Global step 7800: loss 0.537, accuracy 0.792
global_steps = [x['global_step'] for x in metrics]
fig, axes = plt.subplots(ncols=2, figsize=(20,8))

for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],
                                            ['loss']]):
  for metric_name in metric_names:
    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)
  axes[axes_index].legend()
  axes[axes_index].set_xlabel("Global Step")

png

Possiamo vedere che la perdita diminuisce rapidamente mentre, soprattutto, la precisione aumenta rapidamente. Tracciamo alcuni esempi per verificare come la previsione si riferisce alle vere etichette:

predictions = estimator.predict(functools.partial(input_fn_predict, params))
first_10_predictions = list(itertools.islice(predictions, 10))

display_df(
  pd.DataFrame({
      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],
      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],
      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]
  }))

Possiamo vedere che per questo campione casuale, il modello predice l'etichetta corretta la maggior parte delle volte, indicando che può incorporare abbastanza bene frasi scientifiche.

Qual è il prossimo?

Ora che hai imparato un po 'di più sugli incorporamenti CORD-19 Swivel di TF-Hub, ti invitiamo a partecipare al concorso CORD-19 Kaggle per contribuire ad ottenere approfondimenti scientifici dai testi accademici relativi a COVID-19.