最近傍とテキスト埋め込みによるセマンティック検索

TensorFlow.org で表示 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

このチュートリアルでは、TensorFlow Hub(TF-Hub)が提供する入力データから埋め込みを生成し、抽出された埋め込みを使用して最近傍(ANN)インデックスを構築する方法を説明します。構築されたインデックスは、リアルタイムに類似性の一致と検索を行うために使用できます。

大規模なコーパスのデータを取り扱う場合、特定のクエリに対して最も類似するアイテムをリアルタイムで見つけるために、レポジトリ全体をスキャンして完全一致を行うというのは、効率的ではありません。そのため、おおよその類似性一致アルゴリズムを使用することで、正確な最近傍の一致を見つける際の精度を少しだけ犠牲にし、速度を大幅に向上させることができます。

このチュートリアルでは、ニュースの見出しのコーパスに対してリアルタイムテキスト検索を行い、クエリに最も類似する見出しを見つけ出す例を示します。この検索はキーワード検索とは異なり、テキスト埋め込みにエンコードされた意味的類似性をキャプチャします。

このチュートリアルの手順は次のとおりです。

  1. サンプルデータをダウンロードする。
  2. TF-Hub モジュールを使用して、データの埋め込みを生成する。
  3. 埋め込みの ANN インデックスを構築する。
  4. インデックスを使って、類似性の一致を実施する。

TF-Hub モデルから埋め込みを生成するには、Apache Beam を使用します。また、最近傍インデックスの構築には、Spotify の ANNOY ライブラリを使用します。

その他のモデル

アーキテクチャは同じであっても異なる言語でトレーニングされたモデルについては、こちらのコレクションを参照してください。こちらでは、現在 tfhub.dev にホストされているすべてのテキスト埋め込みを検索できます。

設定

必要なライブラリをインストールします。

pip install apache_beam
pip install 'scikit_learn~=0.23.0'  # For gaussian_random_matrix.
pip install annoy

必要なライブラリをインポートします。

import os
import sys
import pickle
from collections import namedtuple
from datetime import datetime
import numpy as np
import apache_beam as beam
from apache_beam.transforms import util
import tensorflow as tf
import tensorflow_hub as hub
import annoy
from sklearn.random_projection import gaussian_random_matrix
2022-12-14 20:43:32.239690: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:43:32.239789: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:43:32.239797: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
print('TF version: {}'.format(tf.__version__))
print('TF-Hub version: {}'.format(hub.__version__))
print('Apache Beam version: {}'.format(beam.__version__))
TF version: 2.11.0
TF-Hub version: 0.12.0
Apache Beam version: 2.43.0

1. サンプルデータをダウンロードする

A Million News Headlines データセットには、15 年にわたって発行されたニュースの見出しが含まれます。出典は、有名なオーストラリア放送協会(ABC)です。このニュースデータセットは、2003 年の始めから 2017 年の終わりまでの特筆すべき世界的なイベントについて、オーストラリアにより焦点を当てた記録が含まれます。

形式: 1)発行日と 2)見出しのテキストの 2 列をタブ区切りにしたデータ。このチュートリアルで関心があるのは、見出しのテキストのみです。

wget 'https://dataverse.harvard.edu/api/access/datafile/3450625?format=tab&gbrecs=true' -O raw.tsv
wc -l raw.tsv
head raw.tsv
--2022-12-14 20:43:33--  https://dataverse.harvard.edu/api/access/datafile/3450625?format=tab&gbrecs=true
Resolving dataverse.harvard.edu (dataverse.harvard.edu)... 44.213.44.146, 52.23.87.139, 52.54.15.150
Connecting to dataverse.harvard.edu (dataverse.harvard.edu)|44.213.44.146|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 57600231 (55M) [text/tab-separated-values]
Saving to: ‘raw.tsv’

raw.tsv             100%[===================>]  54.93M  75.1MB/s    in 0.7s    

2022-12-14 20:43:34 (75.1 MB/s) - ‘raw.tsv’ saved [57600231/57600231]

1103664 raw.tsv
publish_date    headline_text
20030219    "aba decides against community broadcasting licence"
20030219    "act fire witnesses must be aware of defamation"
20030219    "a g calls for infrastructure protection summit"
20030219    "air nz staff in aust strike for pay rise"
20030219    "air nz strike to affect australian travellers"
20030219    "ambitious olsson wins triple jump"
20030219    "antic delighted with record breaking barca"
20030219    "aussie qualifier stosur wastes four memphis match"
20030219    "aust addresses un security council over iraq"

単純化するため、見出しのテキストのみを維持し、発行日は削除します。

!rm -r corpus
!mkdir corpus

with open('corpus/text.txt', 'w') as out_file:
  with open('raw.tsv', 'r') as in_file:
    for line in in_file:
      headline = line.split('\t')[1].strip().strip('"')
      out_file.write(headline+"\n")
rm: cannot remove 'corpus': No such file or directory
tail corpus/text.txt
severe storms forecast for nye in south east queensland
snake catcher pleads for people not to kill reptiles
south australia prepares for party to welcome new year
strikers cool off the heat with big win in adelaide
stunning images from the sydney to hobart yacht
the ashes smiths warners near miss liven up boxing day test
timelapse: brisbanes new year fireworks
what 2017 meant to the kids of australia
what the papodopoulos meeting may mean for ausus
who is george papadopoulos the former trump campaign aide

2. データの埋め込みを生成する

このチュートリアルでは、ニューラルネットワーク言語モデル(NNLM)を使用して、見出しデータの埋め込みを生成します。その後で、文章レベルの意味の類似性を計算するために、文章埋め込みを簡単に使用することが可能となります。埋め込み生成プロセスは、Apache Beam を使用して実行します。

埋め込み抽出メソッド

embed_fn = None

def generate_embeddings(text, model_url, random_projection_matrix=None):
  # Beam will run this function in different processes that need to
  # import hub and load embed_fn (if not previously loaded)
  global embed_fn
  if embed_fn is None:
    embed_fn = hub.load(model_url)
  embedding = embed_fn(text).numpy()
  if random_projection_matrix is not None:
    embedding = embedding.dot(random_projection_matrix)
  return text, embedding

tf.Example メソッドへの変換

def to_tf_example(entries):
  examples = []

  text_list, embedding_list = entries
  for i in range(len(text_list)):
    text = text_list[i]
    embedding = embedding_list[i]

    features = {
        'text': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[text.encode('utf-8')])),
        'embedding': tf.train.Feature(
            float_list=tf.train.FloatList(value=embedding.tolist()))
    }

    example = tf.train.Example(
        features=tf.train.Features(
            feature=features)).SerializeToString(deterministic=True)

    examples.append(example)

  return examples

Beam パイプライン

def run_hub2emb(args):
  '''Runs the embedding generation pipeline'''

  options = beam.options.pipeline_options.PipelineOptions(**args)
  args = namedtuple("options", args.keys())(*args.values())

  with beam.Pipeline(args.runner, options=options) as pipeline:
    (
        pipeline
        | 'Read sentences from files' >> beam.io.ReadFromText(
            file_pattern=args.data_dir)
        | 'Batch elements' >> util.BatchElements(
            min_batch_size=args.batch_size, max_batch_size=args.batch_size)
        | 'Generate embeddings' >> beam.Map(
            generate_embeddings, args.model_url, args.random_projection_matrix)
        | 'Encode to tf example' >> beam.FlatMap(to_tf_example)
        | 'Write to TFRecords files' >> beam.io.WriteToTFRecord(
            file_path_prefix='{}/emb'.format(args.output_dir),
            file_name_suffix='.tfrecords')
    )

ランダムプロジェクションの重み行列を生成する

ランダムプロジェクションは、ユークリッド空間に存在する一連の点の次元を縮小するために使用される、単純でありながら高性能のテクニックです。理論的背景については、Johnson-Lindenstrauss の補題をご覧ください。

ランダムプロジェクションを使用して埋め込みの次元を縮小するということは、ANN インデックスの構築とクエリに必要となる時間を短縮できるということです。

このチュートリアルでは、Scikit-learn ライブラリのガウスランダムプロジェクションを使用します。

def generate_random_projection_weights(original_dim, projected_dim):
  random_projection_matrix = None
  random_projection_matrix = gaussian_random_matrix(
      n_components=projected_dim, n_features=original_dim).T
  print("A Gaussian random weight matrix was creates with shape of {}".format(random_projection_matrix.shape))
  print('Storing random projection matrix to disk...')
  with open('random_projection_matrix', 'wb') as handle:
    pickle.dump(random_projection_matrix, 
                handle, protocol=pickle.HIGHEST_PROTOCOL)

  return random_projection_matrix

パラメータの設定

ランダムプロジェクションを使用せずに、元の埋め込み空間を使用してインデックスを構築する場合は、projected_dim パラメータを None に設定します。これにより、高次元埋め込みのインデックス作成ステップが減速することに注意してください。

パイプラインの実行

import tempfile

output_dir = tempfile.mkdtemp()
original_dim = hub.load(model_url)(['']).shape[1]
random_projection_matrix = None

if projected_dim:
  random_projection_matrix = generate_random_projection_weights(
      original_dim, projected_dim)

args = {
    'job_name': 'hub2emb-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S')),
    'runner': 'DirectRunner',
    'batch_size': 1024,
    'data_dir': 'corpus/*.txt',
    'output_dir': output_dir,
    'model_url': model_url,
    'random_projection_matrix': random_projection_matrix,
}

print("Pipeline args are set.")
args
A Gaussian random weight matrix was creates with shape of (128, 64)
Storing random projection matrix to disk...
Pipeline args are set.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:86: FutureWarning: Function gaussian_random_matrix is deprecated; gaussian_random_matrix is deprecated in 0.22 and will be removed in version 0.24.
  warnings.warn(msg, category=FutureWarning)
{'job_name': 'hub2emb-221214-204344',
 'runner': 'DirectRunner',
 'batch_size': 1024,
 'data_dir': 'corpus/*.txt',
 'output_dir': '/tmpfs/tmp/tmp416d2_rg',
 'model_url': 'https://tfhub.dev/google/nnlm-en-dim128/2',
 'random_projection_matrix': array([[-0.03414318,  0.16726298, -0.00830948, ...,  0.14487795,
          0.09178456, -0.02400007],
        [ 0.0295565 , -0.2500408 ,  0.0449117 , ...,  0.00851299,
         -0.00206989,  0.01522875],
        [ 0.01691577,  0.04981695,  0.18249518, ...,  0.30985875,
         -0.11174661, -0.16737778],
        ...,
        [ 0.0669922 , -0.20554815,  0.10698919, ..., -0.15748863,
         -0.1854706 ,  0.29059386],
        [-0.07822566, -0.00124182, -0.16006914, ...,  0.18500444,
         -0.07912207, -0.02134945],
        [ 0.09786977,  0.00253127,  0.0866723 , ...,  0.10495674,
         -0.27430763,  0.14429715]])}
print("Running pipeline...")
%time run_hub2emb(args)
print("Pipeline is done.")
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
Running pipeline...
WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['-f', '/tmpfs/tmp/tmpwn2jsjq9.json', '--HistoryManager.hist_file=:memory:']
WARNING:apache_beam.options.pipeline_options:Discarding invalid overrides: {'batch_size': 1024, 'data_dir': 'corpus/*.txt', 'output_dir': '/tmpfs/tmp/tmp416d2_rg', 'model_url': 'https://tfhub.dev/google/nnlm-en-dim128/2', 'random_projection_matrix': array([[-0.03414318,  0.16726298, -0.00830948, ...,  0.14487795,
         0.09178456, -0.02400007],
       [ 0.0295565 , -0.2500408 ,  0.0449117 , ...,  0.00851299,
        -0.00206989,  0.01522875],
       [ 0.01691577,  0.04981695,  0.18249518, ...,  0.30985875,
        -0.11174661, -0.16737778],
       ...,
       [ 0.0669922 , -0.20554815,  0.10698919, ..., -0.15748863,
        -0.1854706 ,  0.29059386],
       [-0.07822566, -0.00124182, -0.16006914, ...,  0.18500444,
        -0.07912207, -0.02134945],
       [ 0.09786977,  0.00253127,  0.0866723 , ...,  0.10495674,
        -0.27430763,  0.14429715]])}
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
CPU times: user 32min 6s, sys: 35min 39s, total: 1h 7min 45s
Wall time: 2min 35s
Pipeline is done.
ls {output_dir}
emb-00000-of-00001.tfrecords

生成された埋め込みをいくつか読み取ります。

embed_file = os.path.join(output_dir, 'emb-00000-of-00001.tfrecords')
sample = 5

# Create a description of the features.
feature_description = {
    'text': tf.io.FixedLenFeature([], tf.string),
    'embedding': tf.io.FixedLenFeature([projected_dim], tf.float32)
}

def _parse_example(example):
  # Parse the input `tf.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example, feature_description)

dataset = tf.data.TFRecordDataset(embed_file)
for record in dataset.take(sample).map(_parse_example):
  print("{}: {}".format(record['text'].numpy().decode('utf-8'), record['embedding'].numpy()[:10]))
headline_text: [-0.03838544 -0.04770774  0.27030632 -0.1190206   0.00493926  0.04791806
  0.07101743  0.1289899  -0.12889388 -0.05041549]
aba decides against community broadcasting licence: [ 0.06278004  0.06718227 -0.14733256  0.00860509  0.17096     0.01168742
 -0.03256782  0.05830066 -0.04695004  0.00944522]
act fire witnesses must be aware of defamation: [-0.01783587 -0.06580362 -0.06844856 -0.00639954  0.08169422  0.00793928
 -0.04115272 -0.08234683  0.05548596 -0.09776891]
a g calls for infrastructure protection summit: [-0.01583667 -0.05274586 -0.05948982  0.00670321 -0.00244301 -0.07551357
 -0.03519824  0.22936538 -0.10490779 -0.28691873]
air nz staff in aust strike for pay rise: [ 0.04377367  0.03826201  0.0856059   0.31610894 -0.0145447   0.02677747
  0.1210634  -0.0205738  -0.02379816 -0.2290168 ]

3. 埋め込みの ANN インデックスを構築する

ANNOY(Approximate Nearest Neighbors Oh Yeah)は、特定のクエリ点に近い空間内のポイントを検索するための、Python バインディングを使った C++ ライブラリです。メモリにマッピングされた、大規模な読み取り専用ファイルベースのデータ構造も作成します。Spotify が構築したもので、おすすめの音楽に使用されています。興味があれば、NGTFAISS などの ANNOY に代わるライブラリを使用してみてください。

def build_index(embedding_files_pattern, index_filename, vector_length, 
    metric='angular', num_trees=100):
  '''Builds an ANNOY index'''

  annoy_index = annoy.AnnoyIndex(vector_length, metric=metric)
  # Mapping between the item and its identifier in the index
  mapping = {}

  embed_files = tf.io.gfile.glob(embedding_files_pattern)
  num_files = len(embed_files)
  print('Found {} embedding file(s).'.format(num_files))

  item_counter = 0
  for i, embed_file in enumerate(embed_files):
    print('Loading embeddings in file {} of {}...'.format(i+1, num_files))
    dataset = tf.data.TFRecordDataset(embed_file)
    for record in dataset.map(_parse_example):
      text = record['text'].numpy().decode("utf-8")
      embedding = record['embedding'].numpy()
      mapping[item_counter] = text
      annoy_index.add_item(item_counter, embedding)
      item_counter += 1
      if item_counter % 100000 == 0:
        print('{} items loaded to the index'.format(item_counter))

  print('A total of {} items added to the index'.format(item_counter))

  print('Building the index with {} trees...'.format(num_trees))
  annoy_index.build(n_trees=num_trees)
  print('Index is successfully built.')

  print('Saving index to disk...')
  annoy_index.save(index_filename)
  print('Index is saved to disk.')
  print("Index file size: {} GB".format(
    round(os.path.getsize(index_filename) / float(1024 ** 3), 2)))
  annoy_index.unload()

  print('Saving mapping to disk...')
  with open(index_filename + '.mapping', 'wb') as handle:
    pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)
  print('Mapping is saved to disk.')
  print("Mapping file size: {} MB".format(
    round(os.path.getsize(index_filename + '.mapping') / float(1024 ** 2), 2)))
embedding_files = "{}/emb-*.tfrecords".format(output_dir)
embedding_dimension = projected_dim
index_filename = "index"

!rm {index_filename}
!rm {index_filename}.mapping

%time build_index(embedding_files, index_filename, embedding_dimension)
rm: cannot remove 'index': No such file or directory
rm: cannot remove 'index.mapping': No such file or directory
Found 1 embedding file(s).
Loading embeddings in file 1 of 1...
100000 items loaded to the index
200000 items loaded to the index
300000 items loaded to the index
400000 items loaded to the index
500000 items loaded to the index
600000 items loaded to the index
700000 items loaded to the index
800000 items loaded to the index
900000 items loaded to the index
1000000 items loaded to the index
1100000 items loaded to the index
A total of 1103664 items added to the index
Building the index with 100 trees...
Index is successfully built.
Saving index to disk...
Index is saved to disk.
Index file size: 1.6 GB
Saving mapping to disk...
Mapping is saved to disk.
Mapping file size: 50.61 MB
CPU times: user 9min 17s, sys: 52.2 s, total: 10min 9s
Wall time: 3min 39s
ls
corpus         random_projection_matrix
index          raw.tsv
index.mapping  tf2_semantic_approximate_nearest_neighbors.ipynb

4. インデックスを使って、類似性の一致を実施する

ANN インデックスを使用して、入力クエリに意味的に近いニュースの見出しを検索できるようになりました。

インデックスとマッピングファイルを読み込む

index = annoy.AnnoyIndex(embedding_dimension)
index.load(index_filename, prefault=True)
print('Annoy index is loaded.')
with open(index_filename + '.mapping', 'rb') as handle:
  mapping = pickle.load(handle)
print('Mapping file is loaded.')
Annoy index is loaded.
/tmpfs/tmp/ipykernel_42348/1659470767.py:1: FutureWarning: The default argument for metric will be removed in future version of Annoy. Please pass metric='angular' explicitly.
  index = annoy.AnnoyIndex(embedding_dimension)
Mapping file is loaded.

類似性の一致メソッド

def find_similar_items(embedding, num_matches=5):
  '''Finds similar items to a given embedding in the ANN index'''
  ids = index.get_nns_by_vector(
  embedding, num_matches, search_k=-1, include_distances=False)
  items = [mapping[i] for i in ids]
  return items

特定のクエリから埋め込みを抽出する

# Load the TF-Hub model
print("Loading the TF-Hub model...")
%time embed_fn = hub.load(model_url)
print("TF-Hub model is loaded.")

random_projection_matrix = None
if os.path.exists('random_projection_matrix'):
  print("Loading random projection matrix...")
  with open('random_projection_matrix', 'rb') as handle:
    random_projection_matrix = pickle.load(handle)
  print('random projection matrix is loaded.')

def extract_embeddings(query):
  '''Generates the embedding for the query'''
  query_embedding =  embed_fn([query])[0].numpy()
  if random_projection_matrix is not None:
    query_embedding = query_embedding.dot(random_projection_matrix)
  return query_embedding
Loading the TF-Hub model...
CPU times: user 468 ms, sys: 391 ms, total: 859 ms
Wall time: 888 ms
TF-Hub model is loaded.
Loading random projection matrix...
random projection matrix is loaded.
extract_embeddings("Hello Machine Learning!")[:10]
array([-0.06940051, -0.02383568, -0.1976899 , -0.06788831, -0.00921661,
        0.03445504,  0.0885061 ,  0.03274157, -0.13551793,  0.03087829])

クエリを入力して、類似性の最も高いアイテムを検索する

Generating embedding for the query...
CPU times: user 4.61 ms, sys: 513 µs, total: 5.12 ms
Wall time: 2.34 ms

Finding relevant items in the index...
CPU times: user 239 µs, sys: 211 µs, total: 450 µs
Wall time: 461 µs

Results:
=========
confronting global challenges
world aids day highlights global challenge
advertising faces modern challenges
global response
conference examines challenges facing major cities
national approach sought to combat domestic
bluescope ponders global challenges
asian giants unite to tackle global crisis
old wisdom unites to solve global dilemmas
old wisdom unites to solve global dilemmas

今後の学習

tensorflow.org/hub では、TensorFlow についてさらに学習し、TF-Hub API ドキュメントを確認することができます。また、tfhub.dev では、その他のテキスト埋め込みモデルや画像特徴量ベクトルモデルなど、利用可能な TensorFlow Hub モジュールを検索することができます。

さらに、Google の Machine Learning Crash Course もご覧ください。機械学習の実用的な導入をテンポよく学習できます。