TF-Hub の BERT エキスパート

コレクションでコンテンツを整理 必要に応じて、コンテンツの保存と分類を行います。

TensorFlow.org で実行 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

この Colab では、以下の方法を実演します。

  • MNLI、SQuAD、PubMed など、さまざまなタスクでトレーニング済みの BERT モデルを TensorFlow Hub から読み込みます。
  • 一致する事前処理モデルを使用して、未加工のテキストをトークン化して ID に変換します。
  • 読み込んだモデルを使用して、トークン入力 ID からプールされたシーケンス出力を生成します。
  • 異なる文のプールされた出力に見られる意味的類似性を確認します。

注意: この Colab の実行には GPU ランタイムを使用してください。

セットアップとインポート

pip install --quiet "tensorflow-text==2.8.*"
import seaborn as sns
from sklearn.metrics import pairwise

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # Imports TF ops for preprocessing.

Configure the model

文章

Wikipedia から取得した、モデルで実行するための文章を見てみましょう。

sentences = [
  "Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
  "The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
  "Among the singles released from the album were the songs \"Be My Lover\" and \"Hard To Stay Awake\".",
  "Riccardo Zegna is an Italian jazz musician.",
  "Rajko Maksimović is a composer, writer, and music pedagogue.",
  "One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.",
  "Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
  "A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
  "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]

モデルの実行

TF-Hub から BERT モデルを読み込み、TF-Hub の一致する事前処理モデルを使用して文章をトークン化し、そのトークン化された文章をモデルにフィードします。この Colab を端的に進められるよう、GPU で実行することをお勧めします。

RuntimeChange runtime type に移動して、GPU が選択されていることを確認します。

preprocess = hub.load(PREPROCESS_MODEL)
bert = hub.load(BERT_MODEL)
inputs = preprocess(sentences)
outputs = bert(inputs)
print("Sentences:")
print(sentences)

print("\nBERT inputs:")
print(inputs)

print("\nPooled embeddings:")
print(outputs["pooled_output"])

print("\nPer token embeddings:")
print(outputs["sequence_output"])
Sentences:
["Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.", 'The album went straight to number one on the Norwegian album chart, and sold to double platinum.', 'Among the singles released from the album were the songs "Be My Lover" and "Hard To Stay Awake".', 'Riccardo Zegna is an Italian jazz musician.', 'Rajko Maksimović is a composer, writer, and music pedagogue.', 'One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.', 'Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum', 'A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.', "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth."]

BERT inputs:
{'input_type_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'input_word_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[  101,  2182,  2057, ...,     0,     0,     0],
       [  101,  1996,  2201, ...,     0,     0,     0],
       [  101,  2426,  1996, ...,     0,     0,     0],
       ...,
       [  101, 16447,  6714, ...,     0,     0,     0],
       [  101,  1037,  5943, ...,     0,     0,     0],
       [  101,  1037,  7704, ...,     0,     0,     0]], dtype=int32)>}

Pooled embeddings:
tf.Tensor(
[[ 0.7975983  -0.4858047   0.49781665 ... -0.34488207  0.3972758
  -0.20639578]
 [ 0.5712035  -0.41205317  0.70489097 ... -0.35185057  0.19032365
  -0.40419084]
 [-0.6993837   0.1586686   0.06569945 ... -0.06232291 -0.81550217
  -0.07923597]
 ...
 [-0.3572722   0.77089787  0.15756367 ...  0.44185576 -0.86448324
   0.04504809]
 [ 0.9107701   0.41501644  0.5606342  ... -0.49263844  0.3964056
  -0.05036103]
 [ 0.90502876 -0.15505227  0.72672117 ... -0.34734455  0.50526446
  -0.19542967]], shape=(9, 768), dtype=float32)

Per token embeddings:
tf.Tensor(
[[[ 1.0919763e+00 -5.3055435e-01  5.4639924e-01 ... -3.5962319e-01
    4.2041004e-01 -2.0940384e-01]
  [ 1.0143832e+00  7.8078997e-01  8.5375911e-01 ...  5.5282390e-01
   -1.1245768e+00  5.6027830e-01]
  [ 7.8862834e-01  7.7776447e-02  9.5150828e-01 ... -1.9075394e-01
    5.9206229e-01  6.1910677e-01]
  ...
  [-3.2203096e-01 -4.2521316e-01 -1.2823755e-01 ... -3.9094931e-01
   -7.9097426e-01  4.2236397e-01]
  [-3.1037472e-02  2.3985589e-01 -2.1994336e-01 ... -1.1440081e-01
   -1.2680490e+00 -1.6136405e-01]
  [-4.2063668e-01  5.4972923e-01 -3.2444507e-01 ... -1.8478569e-01
   -1.1342961e+00 -5.8976438e-02]]

 [[ 6.4930725e-01 -4.3808180e-01  8.7695575e-01 ... -3.6755425e-01
    1.9267297e-01 -4.2864799e-01]
  [-1.1248751e+00  2.9931432e-01  1.1799647e+00 ...  4.8729539e-01
    5.3400397e-01  2.2836086e-01]
  [-2.7057484e-01  3.2353774e-02  1.0425684e+00 ...  5.8993781e-01
    1.5367906e+00  5.8425695e-01]
  ...
  [-1.4762504e+00  1.8239306e-01  5.5877924e-02 ... -1.6733217e+00
   -6.7398900e-01 -7.2449714e-01]
  [-1.5138137e+00  5.8184761e-01  1.6141929e-01 ... -1.2640836e+00
   -4.0272185e-01 -9.7197187e-01]
  [-4.7152787e-01  2.2817361e-01  5.2776086e-01 ... -7.5483733e-01
   -9.0903133e-01 -1.6954741e-01]]

 [[-8.6609292e-01  1.6002062e-01  6.5794230e-02 ... -6.2403791e-02
   -1.1432397e+00 -7.9402432e-02]
  [ 7.7118009e-01  7.0804596e-01  1.1350013e-01 ...  7.8830987e-01
   -3.1438011e-01 -9.7487241e-01]
  [-4.4002396e-01 -3.0059844e-01  3.5479474e-01 ...  7.9736769e-02
   -4.7393358e-01 -1.1001850e+00]
  ...
  [-1.0205296e+00  2.6938295e-01 -4.7310317e-01 ... -6.6319406e-01
   -1.4579906e+00 -3.4665293e-01]
  [-9.7003269e-01 -4.5014530e-02 -5.9779799e-01 ... -3.0526215e-01
   -1.2744255e+00 -2.8051612e-01]
  [-7.3144299e-01  1.7699258e-01 -4.6257949e-01 ... -1.6062324e-01
   -1.6346085e+00 -3.2060498e-01]]

 ...

 [[-3.7375548e-01  1.0225370e+00  1.5888736e-01 ...  4.7453445e-01
   -1.3108220e+00  4.5078602e-02]
  [-4.1589195e-01  5.0019342e-01 -4.5844358e-01 ...  4.1482633e-01
   -6.2065941e-01 -7.1554971e-01]
  [-1.2504396e+00  5.0936830e-01 -5.7103878e-01 ...  3.5491806e-01
    2.4368122e-01 -2.0577202e+00]
  ...
  [ 1.3393565e-01  1.1859145e+00 -2.2170596e-01 ... -8.1946641e-01
   -1.6737353e+00 -3.9692396e-01]
  [-3.3662772e-01  1.6556194e+00 -3.7813133e-01 ... -9.6745455e-01
   -1.4801090e+00 -8.3330792e-01]
  [-2.2649661e-01  1.6178432e+00 -6.7044818e-01 ... -4.9078292e-01
   -1.4535757e+00 -7.1707249e-01]]

 [[ 1.5320230e+00  4.4165635e-01  6.3375759e-01 ... -5.3953838e-01
    4.1937724e-01 -5.0403673e-02]
  [ 8.9377761e-01  8.9395475e-01  3.0627429e-02 ...  5.9038877e-02
   -2.0649567e-01 -8.4811318e-01]
  [-1.8558376e-02  1.0479058e+00 -1.3329605e+00 ... -1.3869658e-01
   -3.7879506e-01 -4.9068686e-01]
  ...
  [ 1.4275625e+00  1.0696868e-01 -4.0634036e-02 ... -3.1777412e-02
   -4.1459864e-01  7.0036912e-01]
  [ 1.1286640e+00  1.4547867e-01 -6.1372513e-01 ...  4.7491822e-01
   -3.9852142e-01  4.3124473e-01]
  [ 1.4393290e+00  1.8030715e-01 -4.2854571e-01 ... -2.5022799e-01
   -1.0000539e+00  3.5985443e-01]]

 [[ 1.4993387e+00 -1.5631306e-01  9.2174339e-01 ... -3.6242083e-01
    5.5635023e-01 -1.9797631e-01]
  [ 1.1110525e+00  3.6651248e-01  3.5505861e-01 ... -5.4297489e-01
    1.4471433e-01 -3.1676081e-01]
  [ 2.4048671e-01  3.8116074e-01 -5.9182751e-01 ...  3.7410957e-01
   -5.9829539e-01 -1.0166274e+00]
  ...
  [ 1.0158602e+00  5.0260085e-01  1.0736975e-01 ... -9.5642674e-01
   -4.1039643e-01 -2.6760373e-01]
  [ 1.1848910e+00  6.5479511e-01  1.0155141e-03 ... -8.6154616e-01
   -8.8041753e-02 -3.0636895e-01]
  [ 1.2669089e+00  4.7767794e-01  6.6289604e-03 ... -1.1585804e+00
   -7.0679039e-02 -1.8678637e-01]]], shape=(9, 128, 768), dtype=float32)

意味的類似性

では、文章の pooled_output 埋め込みを確認し、文章間でどれくらい類似性があるのかを比較しましょう。

Helper functions

plot_similarity(outputs["pooled_output"], sentences)

png

詳細情報

  • その他の BERT モデルは、TensorFlow Hub をご覧ください。
  • このノートブックでは BERT を使った単純な推論を実演しています。BERT のファインチューニングに関するより高度なチュートリアルについては、tensorflow.org/official_models/fine_tuning_bert をご覧ください。
  • モデルの実行には、GPU チップを 1 つしか使用していません。tf.distribute を使ったモデルの読み込み方法については、tensorflow.org/tutorials/distribute/save_and_load をご覧ください。