ตัวเข้ารหัสประโยคสากล

ดูบน TensorFlow.org ทำงานใน Google Colab ดูบน GitHub ดาวน์โหลดโน๊ตบุ๊ค ดูรุ่น TF Hub

สมุดบันทึกนี้แสดงวิธีเข้าถึง Universal Sentence Encoder และใช้สำหรับงานที่มีความคล้ายคลึงของประโยคและการจัดประเภทประโยค

Universal Sentence Encoder ทำให้การฝังระดับประโยคทำได้ง่ายเหมือนกับการค้นหาการฝังสำหรับคำแต่ละคำในอดีต การฝังประโยคสามารถใช้เล็กน้อยในการคำนวณระดับประโยคซึ่งหมายถึงความคล้ายคลึงกัน และเพื่อให้มีประสิทธิภาพที่ดีขึ้นในงานจำแนกประเภทดาวน์สตรีมโดยใช้ข้อมูลการฝึกอบรมที่มีการดูแลน้อยกว่า

ติดตั้ง

ส่วนนี้ตั้งค่าสภาพแวดล้อมสำหรับการเข้าถึง Universal Sentence Encoder บน TF Hub และแสดงตัวอย่างการใช้ตัวเข้ารหัสกับคำ ประโยค และย่อหน้า

%%capture
!pip3 install seaborn

ข้อมูลรายละเอียดเพิ่มเติมเกี่ยวกับการติดตั้ง Tensorflow สามารถพบได้ที่ https://www.tensorflow.org/install/

โหลดโมดูล TF Hub ของ Universal Sentence Encoder

module https://tfhub.dev/google/universal-sentence-encoder/4 loaded

คำนวณการแสดงสำหรับแต่ละข้อความ โดยแสดงความยาวต่างๆ ที่รองรับ

Message: Elephant
Embedding size: 512
Embedding: [0.008344474248588085, 0.00048079612315632403, 0.06595245748758316, ...]

Message: I am a sentence for which I would like to get its embedding.
Embedding size: 512
Embedding: [0.05080860108137131, -0.016524313017725945, 0.015737781301140785, ...]

Message: Universal Sentence Encoder embeddings also support short paragraphs. There is no hard limit on how long the paragraph is. Roughly, the longer the more 'diluted' the embedding will be.
Embedding size: 512
Embedding: [-0.028332678601145744, -0.05586216226220131, -0.012941479682922363, ...]

ตัวอย่างงานความคล้ายคลึงกันของความหมาย

การฝังที่สร้างโดย Universal Sentence Encoder จะถูกทำให้เป็นมาตรฐานโดยประมาณ ความคล้ายคลึงกันทางความหมายของสองประโยคสามารถคำนวณได้เล็กน้อยในฐานะผลิตภัณฑ์ภายในของการเข้ารหัส

def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  sns.set(font_scale=1.2)
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlOrRd")
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)

ความคล้ายคลึงกันที่มองเห็นได้

ที่นี่เราแสดงความคล้ายคลึงกันในแผนที่ความร้อน กราฟสุดท้ายคือเมทริกซ์ 9x9 ที่แต่ละรายการ [i, j] เป็นสีขึ้นอยู่กับผลิตภัณฑ์ด้านการเข้ารหัสสำหรับประโยคที่ i และ j

messages = [
    # Smartphones
    "I like my phone",
    "My phone is not good.",
    "Your cellphone looks great.",

    # Weather
    "Will it snow tomorrow?",
    "Recently a lot of hurricanes have hit the US",
    "Global warming is real",

    # Food and health
    "An apple a day, keeps the doctors away",
    "Eating strawberries is healthy",
    "Is paleo better than keto?",

    # Asking about age
    "How old are you?",
    "what is your age?",
]

run_and_plot(messages)

png

การประเมินผล: เกณฑ์มาตรฐาน STS (ความคล้ายคลึงข้อความเชิงความหมาย)

STS เกณฑ์มาตรฐาน ให้การประเมินผลที่แท้จริงของระดับที่คะแนนความคล้ายคลึงกันคำนวณโดยใช้ embeddings ประโยคสอดคล้องกับคำตัดสินของมนุษย์ เกณฑ์มาตรฐานกำหนดให้ระบบแสดงผลคะแนนความคล้ายคลึงกันสำหรับการเลือกคู่ประโยคที่หลากหลาย เพียร์สันความสัมพันธ์ ที่ใช้แล้วเพื่อประเมินคุณภาพของคะแนนที่เครื่องคล้ายคลึงกันกับคำตัดสินของมนุษย์

ดาวน์โหลดข้อมูล

import pandas
import scipy
import math
import csv

sts_dataset = tf.keras.utils.get_file(
    fname="Stsbenchmark.tar.gz",
    origin="http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz",
    extract=True)
sts_dev = pandas.read_table(
    os.path.join(os.path.dirname(sts_dataset), "stsbenchmark", "sts-dev.csv"),
    error_bad_lines=False,
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
sts_test = pandas.read_table(
    os.path.join(
        os.path.dirname(sts_dataset), "stsbenchmark", "sts-test.csv"),
    error_bad_lines=False,
    quoting=csv.QUOTE_NONE,
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
# cleanup some NaN values in sts_dev
sts_dev = sts_dev[[isinstance(s, str) for s in sts_dev['sent_2']]]
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py:3444: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version.


  exec(code_obj, self.user_global_ns, self.user_ns)

ประเมินการฝังประโยค

sts_data = sts_dev

def run_sts_benchmark(batch):
  sts_encode1 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_1'].tolist())), axis=1)
  sts_encode2 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_2'].tolist())), axis=1)
  cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
  clip_cosine_similarities = tf.clip_by_value(cosine_similarities, -1.0, 1.0)
  scores = 1.0 - tf.acos(clip_cosine_similarities) / math.pi
  """Returns the similarity scores"""
  return scores

dev_scores = sts_data['sim'].tolist()
scores = []
for batch in np.array_split(sts_data, 10):
  scores.extend(run_sts_benchmark(batch))

pearson_correlation = scipy.stats.pearsonr(scores, dev_scores)
print('Pearson correlation coefficient = {0}\np-value = {1}'.format(
    pearson_correlation[0], pearson_correlation[1]))
Pearson correlation coefficient = 0.8036394630692778
p-value = 0.0