This is a demo of question-answer retrieval of text with the Universal Encoder Multilingual Q&A model, illustrating the use of the model's question_encoder and response_encoder. We use sentences from SQuAD paragraphs as the demo dataset; each sentence and its context (the text surrounding the sentence) is encoded into a high-dimensional embedding with the response_encoder. These embeddings are stored in an index built with the simpleneighbors library for question-answer retrieval.
At retrieval time, a random question is selected from the SQuAD dataset and encoded into a high-dimensional embedding with the question_encoder; querying the simpleneighbors index then returns a list of the nearest neighbors in the semantic space.
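For orientation, the sketch below shows the dual-encoder pattern in isolation before the full demo: it loads the same module used later, encodes one question and one response, and scores them against each other. The question and response strings are made up purely for illustration.
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
import tensorflow_text  # noqa: F401, registers ops the model needs
# Same module that the rest of this demo uses.
model = hub.load(
    "https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/3")
# Encode a question (illustrative text).
question_embedding = model.signatures['question_encoder'](
    tf.constant(["How many clubs do students run?"]))['outputs']
# Encode a candidate answer sentence together with its surrounding context.
response_embedding = model.signatures['response_encoder'](
    input=tf.constant(["Students run over 400 clubs."]),
    context=tf.constant(["Students run over 400 clubs. These include "
                         "cultural and religious groups."]))['outputs']
# Both encoders map into the same semantic space, so the dot product of
# the two embeddings scores how well the response answers the question.
score = tf.reduce_sum(question_embedding * response_embedding, axis=1)
print(float(score[0]))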
More models
You can find all of the currently hosted text embedding models here, as well as all models that have been trained on SQuAD.
Setup
Setup Environment
%%capture
# Install TensorFlow Text (pulls in a matching TensorFlow version).
!pip install -q "tensorflow-text==2.8.*"
!pip install -q simpleneighbors[annoy]
!pip install -q nltk
!pip install -q tqdm
Setup common imports and functions
import json
import nltk
import os
import pprint
import random
import simpleneighbors
import urllib
from IPython.display import HTML, display
from tqdm.notebook import tqdm
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
from tensorflow_text import SentencepieceTokenizer
nltk.download('punkt')
def download_squad(url):
  # Fetch and parse the SQuAD JSON file from the given URL.
  return json.load(urllib.request.urlopen(url))

def extract_sentences_from_squad_json(squad):
  # Split every paragraph into sentences and pair each sentence with the
  # full paragraph text it came from.
  all_sentences = []
  for data in squad['data']:
    for paragraph in data['paragraphs']:
      sentences = nltk.tokenize.sent_tokenize(paragraph['context'])
      all_sentences.extend(zip(sentences, [paragraph['context']] * len(sentences)))
  return list(set(all_sentences))  # remove duplicates

def extract_questions_from_squad_json(squad):
  # Collect (question, first answer text) pairs, skipping unanswered questions.
  questions = []
  for data in squad['data']:
    for paragraph in data['paragraphs']:
      for qas in paragraph['qas']:
        if qas['answers']:
          questions.append((qas['question'], qas['answers'][0]['text']))
  return list(set(questions))

def output_with_highlight(text, highlight):
  # Wrap every occurrence of `highlight` in <b> tags and return the text
  # as an HTML list item.
  output = "<li> "
  i = text.find(highlight)
  while True:
    if i == -1:
      output += text
      break
    output += text[0:i]
    output += '<b>' + text[i:i + len(highlight)] + '</b>'
    text = text[i + len(highlight):]
    i = text.find(highlight)
  return output + "</li>\n"

def display_nearest_neighbors(query_text, answer_text=None):
  # Encode the question, look up its nearest neighbors in the index, and
  # render the results as HTML, highlighting the answer when provided.
  query_embedding = model.signatures['question_encoder'](
      tf.constant([query_text]))['outputs'][0]
  search_results = index.nearest(query_embedding, n=num_results)

  if answer_text:
    result_md = '''
    <p>Random Question from SQuAD:</p>
    <p> <b>%s</b></p>
    <p>Answer:</p>
    <p> <b>%s</b></p>
    ''' % (query_text, answer_text)
  else:
    result_md = '''
    <p>Question:</p>
    <p> <b>%s</b></p>
    ''' % query_text

  result_md += '''
    <p>Retrieved sentences:
    <ol>
  '''

  if answer_text:
    for s in search_results:
      result_md += output_with_highlight(s, answer_text)
  else:
    for s in search_results:
      result_md += '<li>' + s + '</li>\n'

  result_md += "</ol>"
  display(HTML(result_md))
[nltk_data] Downloading package punkt to /home/kbuilder/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Run the following code block to download the SQuAD dataset and extract it into:
- sentences: a list of (text, context) tuples. Each paragraph in the SQuAD dataset is split into sentences with the NLTK library, and each sentence together with its paragraph text forms a (text, context) tuple.
- questions: a list of (question, answer) tuples.
Note: You can use this demo to index the SQuAD train dataset or the smaller dev dataset (1.1 or 2.0) by selecting the squad_url below; a sketch of the alternative URLs follows this note.
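For reference, a sketch of the alternative squad_url values mentioned in the note; these follow the standard SQuAD-explorer download URL pattern, and the train sets are much larger and take correspondingly longer to index.
# Alternative datasets (uncomment one to index it instead):
# squad_url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json'
# squad_url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json'
# squad_url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json'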
Download and extract SQuAD data
squad_url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json'
squad_json = download_squad(squad_url)
sentences = extract_sentences_from_squad_json(squad_json)
questions = extract_questions_from_squad_json(squad_json)
print("%s sentences, %s questions extracted from SQuAD %s" % (len(sentences), len(questions), squad_url))
print("\nExample sentence and context:\n")
sentence = random.choice(sentences)
print("sentence:\n")
pprint.pprint(sentence[0])
print("\ncontext:\n")
pprint.pprint(sentence[1])
print()
10455 sentences, 10552 questions extracted from SQuAD https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

Example sentence and context:

sentence:

('Students at the University of Chicago run over 400 clubs and organizations '
 'known as Recognized Student Organizations (RSOs).')

context:

('Students at the University of Chicago run over 400 clubs and organizations '
 'known as Recognized Student Organizations (RSOs). These include cultural and '
 'religious groups, academic clubs and teams, and common-interest '
 'organizations. Notable extracurricular groups include the University of '
 'Chicago College Bowl Team, which has won 118 tournaments and 15 national '
 "championships, leading both categories internationally. The university's "
 'competitive Model United Nations team was the top ranked team in North '
 "America in 2013-14 and 2014-2015. Among notable RSOs are the nation's "
 'longest continuously running student film society Doc Films, organizing '
 'committee for the University of Chicago Scavenger Hunt, the twice-weekly '
 'student newspaper The Chicago Maroon, the alternative weekly student '
 "newspaper South Side Weekly, the nation's second oldest continuously running "
 'student improvisational theater troupe Off-Off Campus, and the '
 'university-owned radio station WHPK.')
The following code block loads the Universal Encoder Multilingual Q&A model, with its question_encoder and response_encoder signatures, from TensorFlow Hub.
Load model from tensorflow hub
module_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/3"
model = hub.load(module_url)
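As an optional sanity check (not part of the original demo), you can list the signatures the loaded module exposes and confirm the embedding dimensionality; the test string below is illustrative.
# The demo relies on the 'question_encoder' and 'response_encoder' signatures.
print(list(model.signatures.keys()))
# Encode one illustrative question and inspect the embedding shape
# (expected to be (1, 512) for this module).
test_embedding = model.signatures['question_encoder'](
    tf.constant(["test question"]))['outputs']
print(test_embedding.shape)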
The following code block computes embeddings for all of the (text, context) tuples with the response_encoder and stores them in a simpleneighbors index.
Compute embeddings and build simpleneighbors index
batch_size = 100

# Encode one example first to discover the embedding dimensionality,
# which the simpleneighbors index needs up front.
encodings = model.signatures['response_encoder'](
    input=tf.constant([sentences[0][0]]),
    context=tf.constant([sentences[0][1]]))
index = simpleneighbors.SimpleNeighbors(
    len(encodings['outputs'][0]), metric='angular')

print('Computing embeddings for %s sentences' % len(sentences))
# Group the sentences into batches; note that this grouping idiom drops
# any final batch smaller than batch_size.
slices = zip(*(iter(sentences),) * batch_size)
num_batches = int(len(sentences) / batch_size)
for s in tqdm(slices, total=num_batches):
  response_batch = list([r for r, c in s])
  context_batch = list([c for r, c in s])
  encodings = model.signatures['response_encoder'](
      input=tf.constant(response_batch),
      context=tf.constant(context_batch)
  )
  for batch_index, batch in enumerate(response_batch):
    index.add_one(batch, encodings['outputs'][batch_index])

index.build()
print('simpleneighbors index for %s sentences built.' % len(sentences))
Computing embeddings for 10455 sentences
simpleneighbors index for 10455 sentences built.
At retrieval time, the question is encoded with the question_encoder, and the resulting question embedding is used to query the simpleneighbors index.
Retrieve nearest neighbors for a random question from SQuAD
num_results = 25
query = random.choice(questions)
display_nearest_neighbors(query[0], query[1])
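You can also query the index with a question of your own by omitting the answer argument, in which case the retrieved sentences are shown without highlighting. The question below is illustrative; since the model is multilingual, a non-English question can also be used against this English index.
# Ask an arbitrary question (illustrative); no reference answer is passed,
# so no highlighting is applied to the results.
display_nearest_neighbors("What is the oldest building on campus?")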