Approximate Nearest Neighbor(ANN) 및 텍스트 임베딩을 사용한 의미론적 검색

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 보기 노트북 다운로드 TF Hub 모델 보기

이 튜토리얼에서는 입력 데이터가 제공된 TensorFlow Hub(TF-Hub) 모듈에서 임베딩을 생성하고 추출된 임베딩을 사용하여 approximate nearest neighbour(ANN) 인덱스를 빌드하는 방법을 보여줍니다. 그런 다음 이 인덱스를 실시간 유사성 일치 및 검색에 사용할 수 있습니다.

많은 양의 데이터를 처리할 때 전체 리포지토리를 스캔하여 주어진 쿼리와 가장 유사한 항목을 실시간으로 찾는 식으로 정확한 일치 작업을 수행하는 것은 효율적이지 않습니다. 따라서 속도를 크게 높이기 위해 정확한 nearest neighbor(NN) 일치를 찾을 때 약간의 정확성을 절충할 수 있는 근사 유사성 일치 알고리즘을 사용합니다.

이 튜토리얼에서는 쿼리와 가장 유사한 헤드라인을 찾기 위해 뉴스 헤드라인 자료의 텍스트를 실시간으로 검색하는 예를 보여줍니다. 키워드 검색과 달리 이 검색으로 텍스트 임베딩에 인코딩된 의미론적 유사성이 포착됩니다.

이 튜토리얼의 단계는 다음과 같습니다.

  1. 샘플 데이터를 다운로드합니다.
  2. TF-Hub 모델을 사용하여 데이터에 대한 임베딩을 생성합니다.
  3. 임베딩에 대한 ANN 인덱스를 빌드합니다.
  4. 유사성 일치에 인덱스를 사용합니다.

Apache Beam을 사용하여 TF-Hub 모델에서 임베딩을 생성합니다. 또한 Spotify의 ANNOY 라이브러리를 사용하여 근사적으로 최근접 인덱스를 빌드합니다.

더 많은 모델

아키텍처는 동일하지만 다른 언어로 훈련된 모델의 경우, 컬렉션을 참조하세요. 여기에서 현재 tfhub.dev에서 호스팅되는 모든 텍스트 임베딩을 찾을 수 있습니다.

설정

필요한 라이브러리를 설치합니다.

pip install apache_beam
pip install 'scikit_learn~=0.23.0'  # For gaussian_random_matrix.
pip install annoy

필요한 라이브러리를 가져옵니다.

import os
import sys
import pickle
from collections import namedtuple
from datetime import datetime
import numpy as np
import apache_beam as beam
from apache_beam.transforms import util
import tensorflow as tf
import tensorflow_hub as hub
import annoy
from sklearn.random_projection import gaussian_random_matrix
2022-12-14 22:23:13.120555: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:23:13.120681: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:23:13.120692: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
print('TF version: {}'.format(tf.__version__))
print('TF-Hub version: {}'.format(hub.__version__))
print('Apache Beam version: {}'.format(beam.__version__))
TF version: 2.11.0
TF-Hub version: 0.12.0
Apache Beam version: 2.43.0

1. 샘플 데이터 다운로드

A Million News Headlines 데이터세트에는 평판이 좋은 Australian Broadcasting Corp. (ABC)에서 공급한 15년치의 뉴스 헤드라인이 수록되어 있습니다. 이 뉴스 데이터세트에는 호주에 보다 세분화된 초점을 두고 2003년 초부터 2017년 말까지 전 세계적으로 일어난 주목할만한 사건에 대한 역사적 기록이 요약되어 있습니다.

형식: 탭으로 구분된 2열 데이터: 1) 발행일 및 2) 헤드라인 텍스트. 여기서는 헤드라인 텍스트에만 관심이 있습니다.

wget 'https://dataverse.harvard.edu/api/access/datafile/3450625?format=tab&gbrecs=true' -O raw.tsv
wc -l raw.tsv
head raw.tsv
--2022-12-14 22:23:14--  https://dataverse.harvard.edu/api/access/datafile/3450625?format=tab&gbrecs=true
Resolving dataverse.harvard.edu (dataverse.harvard.edu)... 44.213.44.146, 52.54.15.150, 52.23.87.139
Connecting to dataverse.harvard.edu (dataverse.harvard.edu)|44.213.44.146|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 57600231 (55M) [text/tab-separated-values]
Saving to: ‘raw.tsv’

raw.tsv             100%[===================>]  54.93M  64.6MB/s    in 0.8s    

2022-12-14 22:23:15 (64.6 MB/s) - ‘raw.tsv’ saved [57600231/57600231]

1103664 raw.tsv
publish_date    headline_text
20030219    "aba decides against community broadcasting licence"
20030219    "act fire witnesses must be aware of defamation"
20030219    "a g calls for infrastructure protection summit"
20030219    "air nz staff in aust strike for pay rise"
20030219    "air nz strike to affect australian travellers"
20030219    "ambitious olsson wins triple jump"
20030219    "antic delighted with record breaking barca"
20030219    "aussie qualifier stosur wastes four memphis match"
20030219    "aust addresses un security council over iraq"

단순화를 위해 헤드라인 텍스트만 유지하고 발행일은 제거합니다.

!rm -r corpus
!mkdir corpus

with open('corpus/text.txt', 'w') as out_file:
  with open('raw.tsv', 'r') as in_file:
    for line in in_file:
      headline = line.split('\t')[1].strip().strip('"')
      out_file.write(headline+"\n")
rm: cannot remove 'corpus': No such file or directory
tail corpus/text.txt
severe storms forecast for nye in south east queensland
snake catcher pleads for people not to kill reptiles
south australia prepares for party to welcome new year
strikers cool off the heat with big win in adelaide
stunning images from the sydney to hobart yacht
the ashes smiths warners near miss liven up boxing day test
timelapse: brisbanes new year fireworks
what 2017 meant to the kids of australia
what the papodopoulos meeting may mean for ausus
who is george papadopoulos the former trump campaign aide

2. 데이터에 대한 임베딩 생성하기

이 튜토리얼에서는 Neural Network Language Model(NNLM)를 사용하여 헤드라인 데이터에 대한 임베딩을 생성합니다. 그런 다음 문장 임베딩을 사용하여 문장 수준의 의미 유사성을 쉽게 계산할 수 있습니다. Apache Beam을 사용하여 임베딩 생성 프로세스를 실행합니다.

임베딩 추출 메서드

embed_fn = None

def generate_embeddings(text, model_url, random_projection_matrix=None):
  # Beam will run this function in different processes that need to
  # import hub and load embed_fn (if not previously loaded)
  global embed_fn
  if embed_fn is None:
    embed_fn = hub.load(model_url)
  embedding = embed_fn(text).numpy()
  if random_projection_matrix is not None:
    embedding = embedding.dot(random_projection_matrix)
  return text, embedding

tf.Example 메서드로 변환하기

def to_tf_example(entries):
  examples = []

  text_list, embedding_list = entries
  for i in range(len(text_list)):
    text = text_list[i]
    embedding = embedding_list[i]

    features = {
        'text': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[text.encode('utf-8')])),
        'embedding': tf.train.Feature(
            float_list=tf.train.FloatList(value=embedding.tolist()))
    }

    example = tf.train.Example(
        features=tf.train.Features(
            feature=features)).SerializeToString(deterministic=True)

    examples.append(example)

  return examples

Beam 파이프라인

def run_hub2emb(args):
  '''Runs the embedding generation pipeline'''

  options = beam.options.pipeline_options.PipelineOptions(**args)
  args = namedtuple("options", args.keys())(*args.values())

  with beam.Pipeline(args.runner, options=options) as pipeline:
    (
        pipeline
        | 'Read sentences from files' >> beam.io.ReadFromText(
            file_pattern=args.data_dir)
        | 'Batch elements' >> util.BatchElements(
            min_batch_size=args.batch_size, max_batch_size=args.batch_size)
        | 'Generate embeddings' >> beam.Map(
            generate_embeddings, args.model_url, args.random_projection_matrix)
        | 'Encode to tf example' >> beam.FlatMap(to_tf_example)
        | 'Write to TFRecords files' >> beam.io.WriteToTFRecord(
            file_path_prefix='{}/emb'.format(args.output_dir),
            file_name_suffix='.tfrecords')
    )

무작위 투영 가중치 행렬 생성하기

무작위 투영은 유클리드 공간에 있는 점 집합의 차원을 줄이는 데 사용되는 간단하지만 강력한 기술입니다. 이론적 배경은 Johnson-Lindenstrauss 보조 정리를 참조하세요.

무작위 투영으로 임베딩의 차원을 줄이면 ANN 인덱스를 빌드하고 쿼리하는 데 필요한 시간이 줄어듭니다.

이 튜토리얼에서는 Scikit-learn 라이브러리의 가우스 무작위 투영을 사용합니다.

def generate_random_projection_weights(original_dim, projected_dim):
  random_projection_matrix = None
  random_projection_matrix = gaussian_random_matrix(
      n_components=projected_dim, n_features=original_dim).T
  print("A Gaussian random weight matrix was creates with shape of {}".format(random_projection_matrix.shape))
  print('Storing random projection matrix to disk...')
  with open('random_projection_matrix', 'wb') as handle:
    pickle.dump(random_projection_matrix, 
                handle, protocol=pickle.HIGHEST_PROTOCOL)

  return random_projection_matrix

매개변수 설정하기

무작위 투영 없이 원래 임베딩 공간을 사용하여 인덱스를 빌드하려면 projected_dim 매개변수를 None으로 설정합니다. 그러면 높은 차원의 임베딩에 대한 인덱싱 스텝이 느려집니다.

파이프라인 실행하기

import tempfile

output_dir = tempfile.mkdtemp()
original_dim = hub.load(model_url)(['']).shape[1]
random_projection_matrix = None

if projected_dim:
  random_projection_matrix = generate_random_projection_weights(
      original_dim, projected_dim)

args = {
    'job_name': 'hub2emb-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S')),
    'runner': 'DirectRunner',
    'batch_size': 1024,
    'data_dir': 'corpus/*.txt',
    'output_dir': output_dir,
    'model_url': model_url,
    'random_projection_matrix': random_projection_matrix,
}

print("Pipeline args are set.")
args
A Gaussian random weight matrix was creates with shape of (128, 64)
Storing random projection matrix to disk...
Pipeline args are set.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:86: FutureWarning: Function gaussian_random_matrix is deprecated; gaussian_random_matrix is deprecated in 0.22 and will be removed in version 0.24.
  warnings.warn(msg, category=FutureWarning)
{'job_name': 'hub2emb-221214-222326',
 'runner': 'DirectRunner',
 'batch_size': 1024,
 'data_dir': 'corpus/*.txt',
 'output_dir': '/tmpfs/tmp/tmprec1hqj3',
 'model_url': 'https://tfhub.dev/google/nnlm-en-dim128/2',
 'random_projection_matrix': array([[-0.06535804,  0.01865796, -0.06064978, ..., -0.05015114,
         -0.13352614, -0.09640936],
        [-0.07904339, -0.15842449, -0.22100735, ...,  0.00894605,
         -0.14853026, -0.24176866],
        [-0.02937062, -0.26708865, -0.10866261, ...,  0.07562045,
          0.06678005,  0.22298068],
        ...,
        [-0.10069443, -0.16601407, -0.00384482, ..., -0.08650758,
         -0.05380853,  0.11275177],
        [ 0.15525274, -0.13222374,  0.13430359, ...,  0.02847312,
         -0.15499098, -0.20251122],
        [ 0.04419047,  0.17619267, -0.02210374, ..., -0.12165692,
         -0.08402846, -0.22452598]])}
print("Running pipeline...")
%time run_hub2emb(args)
print("Pipeline is done.")
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
Running pipeline...
WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['-f', '/tmpfs/tmp/tmpbmlzolse.json', '--HistoryManager.hist_file=:memory:']
WARNING:apache_beam.options.pipeline_options:Discarding invalid overrides: {'batch_size': 1024, 'data_dir': 'corpus/*.txt', 'output_dir': '/tmpfs/tmp/tmprec1hqj3', 'model_url': 'https://tfhub.dev/google/nnlm-en-dim128/2', 'random_projection_matrix': array([[-0.06535804,  0.01865796, -0.06064978, ..., -0.05015114,
        -0.13352614, -0.09640936],
       [-0.07904339, -0.15842449, -0.22100735, ...,  0.00894605,
        -0.14853026, -0.24176866],
       [-0.02937062, -0.26708865, -0.10866261, ...,  0.07562045,
         0.06678005,  0.22298068],
       ...,
       [-0.10069443, -0.16601407, -0.00384482, ..., -0.08650758,
        -0.05380853,  0.11275177],
       [ 0.15525274, -0.13222374,  0.13430359, ...,  0.02847312,
        -0.15499098, -0.20251122],
       [ 0.04419047,  0.17619267, -0.02210374, ..., -0.12165692,
        -0.08402846, -0.22452598]])}
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
CPU times: user 32min 11s, sys: 35min 41s, total: 1h 7min 52s
Wall time: 2min 38s
Pipeline is done.
ls {output_dir}
emb-00000-of-00001.tfrecords

생성된 임베딩의 일부를 읽습니다.

embed_file = os.path.join(output_dir, 'emb-00000-of-00001.tfrecords')
sample = 5

# Create a description of the features.
feature_description = {
    'text': tf.io.FixedLenFeature([], tf.string),
    'embedding': tf.io.FixedLenFeature([projected_dim], tf.float32)
}

def _parse_example(example):
  # Parse the input `tf.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example, feature_description)

dataset = tf.data.TFRecordDataset(embed_file)
for record in dataset.take(sample).map(_parse_example):
  print("{}: {}".format(record['text'].numpy().decode('utf-8'), record['embedding'].numpy()[:10]))
headline_text: [ 0.14292637  0.02568962  0.11015486  0.17237647  0.05198897  0.0906129
 -0.05171417  0.04421039 -0.31293297 -0.11931964]
aba decides against community broadcasting licence: [ 0.15418081 -0.10138009 -0.11134841  0.0442351   0.11009002  0.15682779
  0.06588996  0.05799522 -0.01081489 -0.0627873 ]
act fire witnesses must be aware of defamation: [-0.01711873 -0.35322243  0.07454095 -0.06503551  0.07442788 -0.02514055
 -0.18165863  0.08065555 -0.11257842 -0.05815782]
a g calls for infrastructure protection summit: [ 0.08052633 -0.27404144  0.31831867  0.08820877  0.09934927  0.06961517
 -0.03280839 -0.04204937  0.01932975 -0.01367122]
air nz staff in aust strike for pay rise: [-0.01595928 -0.40558478 -0.04225743 -0.12859102  0.22375152 -0.04432241
  0.17953755 -0.37328407  0.18178187 -0.11260154]

3. 임베딩을 위한 ANN 인덱스 빌드하기

ANNOY(Approximate Nearest Neighbors Oh Yeah)는 주어진 쿼리 지점에 가까운 공간의 지점을 검색하기 위한 Python 바인딩이 있는 C++ 라이브러리입니다. 또한 메모리에 매핑되는 대규모 읽기 전용 파일 기반 데이터 구조를 생성합니다. 이는 음악 추천을 위해 Spotify에서 구축하고 사용합니다. 관심이 있으면 NGT, FAISS 등과 같은 ANNOY의 다른 대안도 시도해볼 수 있습니다.

def build_index(embedding_files_pattern, index_filename, vector_length, 
    metric='angular', num_trees=100):
  '''Builds an ANNOY index'''

  annoy_index = annoy.AnnoyIndex(vector_length, metric=metric)
  # Mapping between the item and its identifier in the index
  mapping = {}

  embed_files = tf.io.gfile.glob(embedding_files_pattern)
  num_files = len(embed_files)
  print('Found {} embedding file(s).'.format(num_files))

  item_counter = 0
  for i, embed_file in enumerate(embed_files):
    print('Loading embeddings in file {} of {}...'.format(i+1, num_files))
    dataset = tf.data.TFRecordDataset(embed_file)
    for record in dataset.map(_parse_example):
      text = record['text'].numpy().decode("utf-8")
      embedding = record['embedding'].numpy()
      mapping[item_counter] = text
      annoy_index.add_item(item_counter, embedding)
      item_counter += 1
      if item_counter % 100000 == 0:
        print('{} items loaded to the index'.format(item_counter))

  print('A total of {} items added to the index'.format(item_counter))

  print('Building the index with {} trees...'.format(num_trees))
  annoy_index.build(n_trees=num_trees)
  print('Index is successfully built.')

  print('Saving index to disk...')
  annoy_index.save(index_filename)
  print('Index is saved to disk.')
  print("Index file size: {} GB".format(
    round(os.path.getsize(index_filename) / float(1024 ** 3), 2)))
  annoy_index.unload()

  print('Saving mapping to disk...')
  with open(index_filename + '.mapping', 'wb') as handle:
    pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)
  print('Mapping is saved to disk.')
  print("Mapping file size: {} MB".format(
    round(os.path.getsize(index_filename + '.mapping') / float(1024 ** 2), 2)))
embedding_files = "{}/emb-*.tfrecords".format(output_dir)
embedding_dimension = projected_dim
index_filename = "index"

!rm {index_filename}
!rm {index_filename}.mapping

%time build_index(embedding_files, index_filename, embedding_dimension)
rm: cannot remove 'index': No such file or directory
rm: cannot remove 'index.mapping': No such file or directory
Found 1 embedding file(s).
Loading embeddings in file 1 of 1...
100000 items loaded to the index
200000 items loaded to the index
300000 items loaded to the index
400000 items loaded to the index
500000 items loaded to the index
600000 items loaded to the index
700000 items loaded to the index
800000 items loaded to the index
900000 items loaded to the index
1000000 items loaded to the index
1100000 items loaded to the index
A total of 1103664 items added to the index
Building the index with 100 trees...
Index is successfully built.
Saving index to disk...
Index is saved to disk.
Index file size: 1.61 GB
Saving mapping to disk...
Mapping is saved to disk.
Mapping file size: 50.61 MB
CPU times: user 9min 50s, sys: 1min, total: 10min 51s
Wall time: 3min 54s
ls
corpus         random_projection_matrix
index          raw.tsv
index.mapping  tf2_semantic_approximate_nearest_neighbors.ipynb

4. 유사성 일치에 인덱스 사용하기

이제 ANN 인덱스를 사용하여 의미상 입력 쿼리에 가까운 뉴스 헤드라인을 찾을 수 있습니다.

인덱스 및 매핑 파일 로드하기

index = annoy.AnnoyIndex(embedding_dimension)
index.load(index_filename, prefault=True)
print('Annoy index is loaded.')
with open(index_filename + '.mapping', 'rb') as handle:
  mapping = pickle.load(handle)
print('Mapping file is loaded.')
Annoy index is loaded.
/tmpfs/tmp/ipykernel_123860/1659470767.py:1: FutureWarning: The default argument for metric will be removed in future version of Annoy. Please pass metric='angular' explicitly.
  index = annoy.AnnoyIndex(embedding_dimension)
Mapping file is loaded.

유사성 일치 메서드

def find_similar_items(embedding, num_matches=5):
  '''Finds similar items to a given embedding in the ANN index'''
  ids = index.get_nns_by_vector(
  embedding, num_matches, search_k=-1, include_distances=False)
  items = [mapping[i] for i in ids]
  return items

주어진 쿼리에서 임베딩 추출하기

# Load the TF-Hub model
print("Loading the TF-Hub model...")
%time embed_fn = hub.load(model_url)
print("TF-Hub model is loaded.")

random_projection_matrix = None
if os.path.exists('random_projection_matrix'):
  print("Loading random projection matrix...")
  with open('random_projection_matrix', 'rb') as handle:
    random_projection_matrix = pickle.load(handle)
  print('random projection matrix is loaded.')

def extract_embeddings(query):
  '''Generates the embedding for the query'''
  query_embedding =  embed_fn([query])[0].numpy()
  if random_projection_matrix is not None:
    query_embedding = query_embedding.dot(random_projection_matrix)
  return query_embedding
Loading the TF-Hub model...
CPU times: user 541 ms, sys: 364 ms, total: 905 ms
Wall time: 935 ms
TF-Hub model is loaded.
Loading random projection matrix...
random projection matrix is loaded.
extract_embeddings("Hello Machine Learning!")[:10]
array([ 0.02059224,  0.23770048, -0.14262952, -0.05028363, -0.12286275,
        0.04908662, -0.01153165,  0.02967124, -0.00196674,  0.029905  ])

가장 유사한 항목을 찾기 위한 쿼리 입력하기

Generating embedding for the query...
CPU times: user 2.53 ms, sys: 2.2 ms, total: 4.73 ms
Wall time: 2.46 ms

Finding relevant items in the index...
CPU times: user 488 µs, sys: 0 ns, total: 488 µs
Wall time: 500 µs

Results:
=========
confronting global challenges
emerging nations to help struggling global economy
wa indigenous still facing barriers
old wisdom unites to solve global dilemmas
old wisdom unites to solve global dilemmas
conference examines challenges facing major cities
emerging currencies face looming threat from
fears global events may spark tourism war
pm warns military of difficult challenges ahead
mining conference considers green challenges

더 자세히 알고 싶나요?

tensorflow.org에서 TensorFlow에 대해 자세히 알아보고 tensorflow.org/hub에서 TF-Hub API 설명서를 확인할 수 있습니다. 추가적인 텍스트 임베딩 모듈 및 이미지 특성 벡터 모듈을 포함해 tfhub.dev에서 사용 가능한 TensorFlow Hub 모델을 찾아보세요.

빠르게 진행되는 Google의 머신러닝 실무 개요 과정인 머신러닝 집중 과정도 확인해 보세요.