Multilingual Universal Sentence Encoder Q&A Retrieval


This is a demo of using the Universal Sentence Encoder Multilingual Q&A model for question-answer retrieval of text. It illustrates the use of the model's question_encoder and response_encoder. We use sentences from SQuAD paragraphs as the demo dataset. Each sentence and its context (the text surrounding the sentence) are encoded into high-dimensional embeddings with the response_encoder. These embeddings are stored in an index built with the simpleneighbors library for question-answer retrieval.

On retrieval, a random question is selected from the SQuAD dataset and encoded into a high-dimensional embedding with the question_encoder. That embedding is used to query the simpleneighbors index, which returns a list of approximate nearest neighbors in semantic space.

More models

You can find all currently hosted text embedding models, as well as all models that have been trained on SQuAD, on TF Hub.


Setup Environment

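# Install TensorFlow Text (which depends on TensorFlow) and the other libraries used in this demo.
!pip install -q tensorflow_text
# Quote the extras spec so the shell does not treat the brackets as a glob pattern.
!pip install -q "simpleneighbors[annoy]"
!pip install -q nltk
!pip install -q tqdm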

Set up common imports and functions


Run the following code block to download the SQuAD dataset and extract it into:

  • sentences is a list of (text, context) tuples: each paragraph from the SQuAD dataset is split into sentences using the nltk library, and each sentence together with its paragraph text forms a (text, context) tuple.
  • questions is a list of (question, answer) tuples.
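The extraction step described above can be sketched as follows. This is a minimal sketch, not the notebook's own helper code: it assumes the standard SQuAD v1.1 JSON layout (data → paragraphs → context and qas), and it swaps nltk's sentence tokenizer for a naive regex splitter so the example runs with the standard library alone. The names split_sentences, extract_sentences, extract_questions, and toy_squad are hypothetical.

```python
import re
from typing import Any, Dict, List, Tuple

# Naive stand-in for nltk.tokenize.sent_tokenize (assumption): split after
# ".", "!" or "?" when followed by whitespace. Real tokenizers also handle
# abbreviations, decimals, quotes, and so on.
_SENTENCE_END = re.compile(r"(?<=[.!?])\s+")


def split_sentences(paragraph: str) -> List[str]:
    """Hypothetical helper: split a paragraph into sentences."""
    return [s for s in _SENTENCE_END.split(paragraph.strip()) if s]


def extract_sentences(squad: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return unique (text, context) tuples: each sentence paired with its paragraph."""
    pairs = []
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for sentence in split_sentences(context):
                pairs.append((sentence, context))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(pairs))


def extract_questions(squad: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return unique (question, answer) tuples, using the first listed answer."""
    questions = []
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                if qa["answers"]:  # skip questions that have no answer
                    questions.append((qa["question"], qa["answers"][0]["text"]))
    return list(dict.fromkeys(questions))


# Tiny SQuAD-shaped input (made-up data, for illustration only).
toy_squad = {
    "data": [{
        "paragraphs": [{
            "context": "The city straddles the Vistula River. Its elevation is 100 metres.",
            "qas": [
                {"question": "Which river runs through the city?",
                 "answers": [{"text": "Vistula River", "answer_start": 23}]},
            ],
        }],
    }],
}

sentences = extract_sentences(toy_squad)
questions = extract_questions(toy_squad)
print(sentences[0][0])
print(questions)
```

Deduplicating with dict.fromkeys rather than a set is a design choice here: it keeps the output order deterministic, which makes results easier to reproduce.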

Download and extract SQuAD data

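import pprint
import random

# URL of the SQuAD JSON file to download.
squad_url = ''

squad_json = download_squad(squad_url)
sentences = extract_sentences_from_squad_json(squad_json)
questions = extract_questions_from_squad_json(squad_json)
print("%s sentences, %s questions extracted from SQuAD %s" % (len(sentences), len(questions), squad_url))

print("\nExample sentence and context:\n")
sentence = random.choice(sentences)
pprint.pprint(sentence[0])  # the sentence (text)
print()
pprint.pprint(sentence[1])  # the surrounding paragraph (context)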
10455 sentences, 10552 questions extracted from SQuAD

Example sentence and context:


('Warsaw Uprising Hill (121 metres (397.0 ft)), Szczęśliwice hill (138 metres '
 '(452.8 ft) – the highest point of Warsaw in general).')


('Warsaw lies in east-central Poland about 300 km (190 mi) from the Carpathian '
 'Mountains and about 260 km (160 mi) from the Baltic Sea, 523 km (325 mi) '
 'east of Berlin, Germany. The city straddles the Vistula River. It is located '
 'in the heartland of the Masovian Plain, and its average elevation is 100 '
 'metres (330 ft) above sea level. The highest point on the left side of the '
 'city lies at a height of 115.7 metres (379.6 ft) ("Redutowa" bus depot, '
 'district of Wola), on the right side – 122.1 metres (400.6 ft) ("Groszówka" '
 'estate, district of Wesoła, by the eastern border). The lowest point lies at '
 'a height 75.6 metres (248.0 ft) (at the right bank of the Vistula, by the '
 'eastern border of Warsaw). There are some hills (mostly artificial) located '
 'within the confines of the city – e.g. Warsaw Uprising Hill (121 metres '
 '(397.0 ft)), Szczęśliwice hill (138 metres (452.8 ft) – the highest point of '
 'Warsaw in general).')

The following code block sets up the TensorFlow graph g and session with the Universal Sentence Encoder Multilingual Q&A model's question_encoder and response_encoder signatures.

Load model from TensorFlow Hub

The following code block computes the embeddings for all the (text, context) tuples with the response_encoder and stores them in a simpleneighbors index.

Compute embeddings and build simpleneighbors index

Computing embeddings for 10455 sentences


simpleneighbors index for 10455 sentences built.
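The batching and indexing step above can be sketched as follows, under stated assumptions. The callable passed as encode stands in for the response_encoder signature (assumed interface: a list of texts and a list of contexts in, one vector per pair out). A brute-force cosine-similarity index replaces simpleneighbors/Annoy so the example runs without extra dependencies, which means it performs exact rather than approximate search. BruteForceIndex, build_index, and toy_encoder are hypothetical names, and the toy encoder is not a real embedding model.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Assumed encoder interface: (texts, contexts) -> one embedding per pair.
ResponseEncoder = Callable[[Sequence[str], Sequence[str]], List[Vector]]


def _normalize(v: Sequence[float]) -> Vector:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm else list(v)


class BruteForceIndex:
    """Exact cosine-similarity search; a stand-in for an approximate index."""

    def __init__(self) -> None:
        self.items: List[str] = []
        self.vectors: List[Vector] = []

    def add_one(self, item: str, vector: Sequence[float]) -> None:
        self.items.append(item)
        self.vectors.append(_normalize(vector))

    def nearest(self, query: Sequence[float], n: int = 10) -> List[str]:
        q = _normalize(query)
        # On unit vectors, the dot product equals the cosine similarity.
        scores = [sum(a * b for a, b in zip(q, v)) for v in self.vectors]
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return [self.items[i] for i in order[:n]]


def build_index(pairs: List[Tuple[str, str]], encode: ResponseEncoder,
                batch_size: int = 100) -> BruteForceIndex:
    """Encode (text, context) pairs batch by batch and index each text."""
    index = BruteForceIndex()
    # range() visits every batch, including a final partial one.
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start:start + batch_size]
        texts = [text for text, _ in batch]
        contexts = [context for _, context in batch]
        for text, vector in zip(texts, encode(texts, contexts)):
            index.add_one(text, vector)
    return index


def toy_encoder(texts: Sequence[str], contexts: Sequence[str]) -> List[Vector]:
    """Hypothetical 3-d 'embedding': keyword counts in text plus context."""
    keywords = ("river", "hill", "city")
    return [[float((t + " " + c).lower().count(k)) for k in keywords]
            for t, c in zip(texts, contexts)]


pairs = [("The Vistula is a river.", "Geography"),
         ("Szczesliwice is a hill.", "Geography"),
         ("Warsaw is a city.", "Overview")]
index = build_index(pairs, toy_encoder, batch_size=2)  # batches of 2 and 1
print(len(index.items))  # 3: the partial batch is not dropped
print(index.nearest([1.0, 0.0, 0.0], n=1))
```

Stepping with range() over the full length is a deliberate choice: batching idioms that zip an iterator into fixed-size groups silently discard the last partial batch, so the indexed count can fall short of the number of sentences.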

On retrieval, the question is encoded using the question_encoder and the question embedding is used to query the simpleneighbors index.

Retrieve nearest neighbors for a random question from SQuAD

num_results = 25

query = random.choice(questions)
display_nearest_neighbors(query[0], query[1])
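The display_nearest_neighbors helper itself is not shown in this copy. As a hedged sketch of the retrieval it performs, the code below encodes the question, ranks the indexed sentences, keeps the top num_results, and flags results that contain the known answer. The encode_question callable is an assumed stand-in for the question_encoder signature. Ranking here is brute-force cosine similarity rather than a simpleneighbors query. The names retrieve_with_answer and toy_question_encoder are hypothetical, and the embeddings are made up.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_with_answer(question: str, answer: str,
                         encode_question: Callable[[str], Vector],
                         indexed: List[Tuple[str, Vector]],
                         num_results: int = 25) -> List[Tuple[str, bool]]:
    """Rank indexed sentences against the question; flag those containing the answer."""
    q = encode_question(question)
    ranked = sorted(indexed, key=lambda pair: cosine(q, pair[1]), reverse=True)
    return [(sentence, answer.lower() in sentence.lower())
            for sentence, _ in ranked[:num_results]]


def toy_question_encoder(question: str) -> Vector:
    """Hypothetical 2-d question 'embedding', for illustration only."""
    return [1.0, 0.0] if "river" in question.lower() else [0.0, 1.0]


# Made-up sentence embeddings, for illustration only.
indexed = [("The city straddles the Vistula River.", [0.9, 0.1]),
           ("Its average elevation is 100 metres.", [0.1, 0.9])]

for sentence, has_answer in retrieve_with_answer(
        "Which river runs through the city?", "Vistula River",
        toy_question_encoder, indexed, num_results=2):
    print("*" if has_answer else " ", sentence)
```

Checking whether the answer string appears in each retrieved sentence gives a quick, if rough, signal of retrieval quality for SQuAD questions, because the gold answer is a span of the source paragraph.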