
Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start off by analyzing the embeddings by calculating and plotting a correlation matrix between different terms. If the embeddings learned to successfully capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)

[Figure: heatmap of pairwise inner-product similarities between the query terms]

We can see that the embeddings successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (i.e. "coronavirus" highly correlates with "SARS" and "MERS"), while they are different from terms of other clusters (i.e. the similarity between "SARS" and "Spain" is close to 0).
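
If you prefer individual numbers over the heatmap, here is a minimal sketch (reusing module and the same normalization as plot_correlation) that prints two of the pairwise scores mentioned above:

# Print the normalized inner-product similarity for two term pairs.
vecs = module(['coronavirus', 'SARS', 'Spain']).numpy()
sims = np.inner(vecs, vecs)
sims /= np.max(sims)
print('coronavirus vs SARS: %.2f' % sims[0, 1])
print('SARS vs Spain: %.2f' % sims[1, 2])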

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how one can use the embeddings for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence with a citation from an academic paper, classify whether the main intent of the citation is as background information, use of methods, or comparing results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)
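
Before sampling examples, it can help to check how much data each split contains and what the label set looks like; a small sketch using standard tensorflow_datasets metadata:

# Inspect split sizes and the label vocabulary from the dataset metadata.
for split in ('train', 'validation', 'test'):
  print(split, builder.info.splits[split].num_examples)
print(builder.info.features[builder.info.supervised_keys[1]].names)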

Let's take a look at a few labeled examples from the training set

NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]

def label2str(numeric_label):
  m = builder.info.features[LABEL_NAME].names
  return m[numeric_label]

data = next(iter(train_data.batch(NUM_EXAMPLES)))


pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
})

Training a citation intent classifier

We'll use Keras to train a classifier on the SciCite dataset. Let's build a model which uses the CORD-19 embeddings with a classification layer on top.

Hyperparameters

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3' 
TRAINABLE_MODULE = False 

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[], 
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
# The Dense layer outputs raw logits for the 3 citation-intent classes,
# so the loss below is configured with from_logits=True.
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7fe7fc16cf28> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 128)               17301632  
_________________________________________________________________
dense (Dense)                (None, 3)                 387       
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
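
As a quick sanity check on the summary above, the trainable parameter count is just the dense layer's weight matrix plus its biases (the 17.3M embedding-table parameters stay frozen because TRAINABLE_MODULE is False):

# 128-dimensional embeddings feeding 3 output classes: weights + biases.
assert 128 * 3 + 3 == 387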

Train and evaluate the model

Let's train and evaluate the model to see how it performs on the SciCite task.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 1s 5ms/step - loss: 0.8978 - accuracy: 0.6190 - val_loss: 0.7777 - val_accuracy: 0.6736
Epoch 2/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6920 - accuracy: 0.7241 - val_loss: 0.6764 - val_accuracy: 0.7227
Epoch 3/35
257/257 [==============================] - 1s 5ms/step - loss: 0.6215 - accuracy: 0.7588 - val_loss: 0.6296 - val_accuracy: 0.7434
Epoch 4/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5867 - accuracy: 0.7757 - val_loss: 0.6071 - val_accuracy: 0.7489
Epoch 5/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5674 - accuracy: 0.7808 - val_loss: 0.5923 - val_accuracy: 0.7576
Epoch 6/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5549 - accuracy: 0.7874 - val_loss: 0.5818 - val_accuracy: 0.7664
Epoch 7/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5457 - accuracy: 0.7890 - val_loss: 0.5756 - val_accuracy: 0.7664
Epoch 8/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5392 - accuracy: 0.7883 - val_loss: 0.5698 - val_accuracy: 0.7740
Epoch 9/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5335 - accuracy: 0.7907 - val_loss: 0.5643 - val_accuracy: 0.7707
Epoch 10/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5294 - accuracy: 0.7934 - val_loss: 0.5609 - val_accuracy: 0.7751
Epoch 11/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5260 - accuracy: 0.7933 - val_loss: 0.5591 - val_accuracy: 0.7762
Epoch 12/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5231 - accuracy: 0.7929 - val_loss: 0.5578 - val_accuracy: 0.7773
Epoch 13/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5202 - accuracy: 0.7933 - val_loss: 0.5586 - val_accuracy: 0.7751
Epoch 14/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5181 - accuracy: 0.7938 - val_loss: 0.5536 - val_accuracy: 0.7806
Epoch 15/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5164 - accuracy: 0.7944 - val_loss: 0.5514 - val_accuracy: 0.7806
Epoch 16/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5144 - accuracy: 0.7961 - val_loss: 0.5513 - val_accuracy: 0.7751
Epoch 17/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5128 - accuracy: 0.7967 - val_loss: 0.5526 - val_accuracy: 0.7773
Epoch 18/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5113 - accuracy: 0.7974 - val_loss: 0.5524 - val_accuracy: 0.7784
Epoch 19/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5104 - accuracy: 0.7972 - val_loss: 0.5480 - val_accuracy: 0.7817
Epoch 20/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5094 - accuracy: 0.7977 - val_loss: 0.5479 - val_accuracy: 0.7817
Epoch 21/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5077 - accuracy: 0.7981 - val_loss: 0.5480 - val_accuracy: 0.7838
Epoch 22/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5072 - accuracy: 0.7983 - val_loss: 0.5450 - val_accuracy: 0.7817
Epoch 23/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5059 - accuracy: 0.7994 - val_loss: 0.5495 - val_accuracy: 0.7795
Epoch 24/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5056 - accuracy: 0.7991 - val_loss: 0.5471 - val_accuracy: 0.7817
Epoch 25/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5048 - accuracy: 0.7983 - val_loss: 0.5474 - val_accuracy: 0.7806
Epoch 26/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5042 - accuracy: 0.8000 - val_loss: 0.5481 - val_accuracy: 0.7828
Epoch 27/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5031 - accuracy: 0.7996 - val_loss: 0.5447 - val_accuracy: 0.7882
Epoch 28/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5027 - accuracy: 0.8007 - val_loss: 0.5440 - val_accuracy: 0.7860
Epoch 29/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5018 - accuracy: 0.8024 - val_loss: 0.5460 - val_accuracy: 0.7871
Epoch 30/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5017 - accuracy: 0.8002 - val_loss: 0.5454 - val_accuracy: 0.7893
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5013 - accuracy: 0.8008 - val_loss: 0.5456 - val_accuracy: 0.7828
Epoch 32/35
257/257 [==============================] - 1s 5ms/step - loss: 0.5003 - accuracy: 0.8024 - val_loss: 0.5467 - val_accuracy: 0.7871
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8007 - val_loss: 0.5445 - val_accuracy: 0.7915
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4999 - accuracy: 0.8025 - val_loss: 0.5445 - val_accuracy: 0.7893
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4995 - accuracy: 0.8010 - val_loss: 0.5453 - val_accuracy: 0.7828
from matplotlib import pyplot as plt
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

[Figure: training and validation curves for accuracy and loss]

Evaluate the model

And let's see how the model performs. Two values will be returned: loss (a number which represents our error, lower values are better) and accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5357 - accuracy: 0.7918
loss: 0.536
accuracy: 0.792

We can see that the loss quickly decreases while especially the accuracy rapidly increases. Let's plot some examples to check how the prediction relates to the true labels:

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

# predict_classes is deprecated; take the argmax over the logits instead.
predictions = [label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]


pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences pretty well.
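
To quantify this beyond eyeballing the table, a small sketch (reusing the variables from the cells above) that converts the logits to probabilities and tallies agreement on the sampled batch:

# Softmax turns the raw logits into per-class probabilities.
probs = tf.nn.softmax(model.predict(prediction_texts), axis=-1).numpy()
print('mean top-class probability: %.2f' % np.max(probs, axis=-1).mean())

# Fraction of sampled sentences where the prediction matches the label.
matches = sum(p == t for p, t in zip(predictions, prediction_labels))
print('sample agreement: %.2f' % (matches / len(predictions)))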

What's next?

Now that you've learned a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19 related academic texts.