![]() | ![]() | ![]() | ![]() | ![]() |
この例では、tensorflow-modelsPIPパッケージを使用してBERTモデルを微調整します。
このチュートリアルのベースとなる事前トレーニング済みのBERTモデルは、 TensorFlowハブでも利用できます。使用方法については、ハブの付録を参照してください。
セットアップ
TensorFlow Model Gardenpipパッケージをインストールします
tf-models-official
は、安定したModelGardenパッケージです。tensorflow_models
リポジトリの最新の変更が含まれていない可能性があることに注意してください。最新の変更を含めるには、tf-models-nightly
インストールします。これは、毎日自動的に作成される夜間のModelGardenパッケージです。- pipは、すべてのモデルと依存関係を自動的にインストールします。
pip install -q tf-models-official==2.3.0
WARNING: You are using pip version 20.2.3; however, version 20.2.4 is available. You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
輸入
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
from official.modeling import tf_utils
from official import nlp
from official.nlp import bert
# Load the required submodules
import official.nlp.optimization
import official.nlp.bert.bert_models
import official.nlp.bert.configs
import official.nlp.bert.run_classifier
import official.nlp.bert.tokenization
import official.nlp.data.classifier_data_lib
import official.nlp.modeling.losses
import official.nlp.modeling.models
import official.nlp.modeling.networks
リソース
このディレクトリには、このチュートリアルで使用される構成、語彙、および事前にトレーニングされたチェックポイントが含まれています。
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12"
tf.io.gfile.listdir(gs_folder_bert)
['bert_config.json', 'bert_model.ckpt.data-00000-of-00001', 'bert_model.ckpt.index', 'vocab.txt']
TensorFlowHubから事前トレーニング済みのBERTエンコーダーを入手できます。
hub_url_bert = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2"
データ
この例では、 TFDSのGLUEMRPCデータセットを使用しました。
このデータセットは、BERTモデルに直接入力できるように設定されていないため、このセクションでは必要な前処理も処理します。
TensorFlowデータセットからデータセットを取得します
Microsoft Research Paraphrase Corpus(Dolan&Brockett、2005)は、オンラインニュースソースから自動的に抽出された文のペアのコーパスであり、ペアの文が意味的に同等であるかどうかを人間が注釈で示しています。
- ラベルの数:2。
- トレーニングデータセットのサイズ:3668。
- 評価データセットのサイズ:408。
- トレーニングおよび評価データセットの最大シーケンス長:128。
glue, info = tfds.load('glue/mrpc', with_info=True,
# It's small, load the whole dataset
batch_size=-1)
Downloading and preparing dataset glue/mrpc/1.0.0 (download: 1.43 MiB, generated: Unknown size, total: 1.43 MiB) to /home/kbuilder/tensorflow_datasets/glue/mrpc/1.0.0... Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/glue/mrpc/1.0.0.incompleteKZIBN9/glue-train.tfrecord Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/glue/mrpc/1.0.0.incompleteKZIBN9/glue-validation.tfrecord Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/glue/mrpc/1.0.0.incompleteKZIBN9/glue-test.tfrecord Dataset glue downloaded and prepared to /home/kbuilder/tensorflow_datasets/glue/mrpc/1.0.0. Subsequent calls will reuse this data.
list(glue.keys())
['test', 'train', 'validation']
info
オブジェクトは、データセットとその機能を説明します。
info.features
FeaturesDict({ 'idx': tf.int32, 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2), 'sentence1': Text(shape=(), dtype=tf.string), 'sentence2': Text(shape=(), dtype=tf.string), })
2つのクラスは次のとおりです。
info.features['label'].names
['not_equivalent', 'equivalent']
トレーニングセットの例を次に示します。
glue_train = glue['train']
for key, value in glue_train.items():
print(f"{key:9s}: {value[0].numpy()}")
idx : 1680 label : 0 sentence1: b'The identical rovers will act as robotic geologists , searching for evidence of past water .' sentence2: b'The rovers act as robotic geologists , moving on six wheels .'
BERTトークナイザー
事前トレーニング済みのモデルを微調整するには、トレーニング中に使用したものとまったく同じトークン化、語彙、およびインデックスマッピングを使用していることを確認する必要があります。
このチュートリアルで使用されるBERTトークナイザーは、純粋なPythonで記述されています(TensorFlow opsから構築されたものではありません)。だから、同じようにあなたのモデルにプラグインすることはできませんkeras.layer
あなたがでできるようにpreprocessing.TextVectorization
。
次のコードは、基本モデルで使用されていたトークナイザーを再構築します。
# Set up tokenizer to generate Tensorflow dataset
tokenizer = bert.tokenization.FullTokenizer(
vocab_file=os.path.join(gs_folder_bert, "vocab.txt"),
do_lower_case=True)
print("Vocab size:", len(tokenizer.vocab))
Vocab size: 30522
文をトークン化する:
tokens = tokenizer.tokenize("Hello TensorFlow!")
print(tokens)
ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)
['hello', 'tensor', '##flow', '!'] [7592, 23435, 12314, 999]
データを前処理する
このセクションでは、データセットをモデルで期待される形式に手動で前処理しました。
このデータセットは小さいため、メモリ内で前処理をすばやく簡単に実行できます。より大きなデータセットの場合、 tf_models
ライブラリには、データセットを前処理および再シリアル化するためのツールがいくつか含まれています。詳細については、付録:大規模なデータセットの再エンコードを参照してください。
文をエンコードする
モデルは、2つの入力文が連結されることを想定しています。この入力は、 [CLS]
「これは分類の問題です」トークンで始まると予想され、各文は[SEP]
「セパレーター」トークンで終わる必要があります。
tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
[101, 102]
[SEP]
トークンを追加しながらすべての文をエンコードし、それらを不規則テンソルにパックすることから始めます。
def encode_sentence(s):
tokens = list(tokenizer.tokenize(s.numpy()))
tokens.append('[SEP]')
return tokenizer.convert_tokens_to_ids(tokens)
sentence1 = tf.ragged.constant([
encode_sentence(s) for s in glue_train["sentence1"]])
sentence2 = tf.ragged.constant([
encode_sentence(s) for s in glue_train["sentence2"]])
print("Sentence1 shape:", sentence1.shape.as_list())
print("Sentence2 shape:", sentence2.shape.as_list())
Sentence1 shape: [3668, None] Sentence2 shape: [3668, None]
次に、 [CLS]
トークンをinput_word_ids
し、不規則なテンソルを連結して、例ごとに1つのinput_word_ids
テンソルを形成します。 RaggedTensor.to_tensor()
は、最長のシーケンスにRaggedTensor.to_tensor()
ます。
cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]
input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)
_ = plt.pcolormesh(input_word_ids.to_tensor())
マスクと入力タイプ
モデルは、2つの追加入力を想定しています。
- 入力マスク
- 入力タイプ
マスクを使用すると、モデルでコンテンツとパディングを明確に区別できます。マスクはinput_word_ids
と同じ形状であり、 input_word_ids
がパディングされていない場所には1
が含まれます。
input_mask = tf.ones_like(input_word_ids).to_tensor()
plt.pcolormesh(input_mask)
<matplotlib.collections.QuadMesh at 0x7fad1c07ed30>
「入力タイプ」も同じ形状ですが、パディングされていない領域内に、トークンがどの文に含まれているかを示す0
または1
が含まれています。
type_cls = tf.zeros_like(cls)
type_s1 = tf.zeros_like(sentence1)
type_s2 = tf.ones_like(sentence2)
input_type_ids = tf.concat([type_cls, type_s1, type_s2], axis=-1).to_tensor()
plt.pcolormesh(input_type_ids)
<matplotlib.collections.QuadMesh at 0x7fad143c1710>
すべてをまとめる
上記のテキスト解析コードを1つの関数に収集し、 glue/mrpc
データセットの各分割に適用します。
def encode_sentence(s, tokenizer):
tokens = list(tokenizer.tokenize(s))
tokens.append('[SEP]')
return tokenizer.convert_tokens_to_ids(tokens)
def bert_encode(glue_dict, tokenizer):
num_examples = len(glue_dict["sentence1"])
sentence1 = tf.ragged.constant([
encode_sentence(s, tokenizer)
for s in np.array(glue_dict["sentence1"])])
sentence2 = tf.ragged.constant([
encode_sentence(s, tokenizer)
for s in np.array(glue_dict["sentence2"])])
cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]
input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)
input_mask = tf.ones_like(input_word_ids).to_tensor()
type_cls = tf.zeros_like(cls)
type_s1 = tf.zeros_like(sentence1)
type_s2 = tf.ones_like(sentence2)
input_type_ids = tf.concat(
[type_cls, type_s1, type_s2], axis=-1).to_tensor()
inputs = {
'input_word_ids': input_word_ids.to_tensor(),
'input_mask': input_mask,
'input_type_ids': input_type_ids}
return inputs
glue_train = bert_encode(glue['train'], tokenizer)
glue_train_labels = glue['train']['label']
glue_validation = bert_encode(glue['validation'], tokenizer)
glue_validation_labels = glue['validation']['label']
glue_test = bert_encode(glue['test'], tokenizer)
glue_test_labels = glue['test']['label']
データの各サブセットは、機能の辞書とラベルのセットに変換されています。入力辞書の各特徴は同じ形状であり、ラベルの数は一致する必要があります。
for key, value in glue_train.items():
print(f'{key:15s} shape: {value.shape}')
print(f'glue_train_labels shape: {glue_train_labels.shape}')
input_word_ids shape: (3668, 103) input_mask shape: (3668, 103) input_type_ids shape: (3668, 103) glue_train_labels shape: (3668,)
モデル
モデルを構築する
最初のステップは、事前にトレーニングされたモデルの構成をダウンロードすることです。
import json
bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
bert_config = bert.configs.BertConfig.from_dict(config_dict)
config_dict
{'attention_probs_dropout_prob': 0.1, 'hidden_act': 'gelu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'intermediate_size': 3072, 'max_position_embeddings': 512, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'type_vocab_size': 2, 'vocab_size': 30522}
config
コアの出力を予測するKerasモデルであるBERTモデル、定義num_classes
最大の系列長を有する入力からmax_seq_length
。
この関数は、エンコーダーと分類子の両方を返します。
bert_classifier, bert_encoder = bert.bert_models.classifier_model(
bert_config, num_labels=2)
分類器には、3つの入力と1つの出力があります。
tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)
トレーニングセットからの10例のデータのテストバッチで実行します。出力は、2つのクラスのロジットです。
glue_batch = {key: val[:10] for key, val in glue_train.items()}
bert_classifier(
glue_batch, training=True
).numpy()
array([[ 0.08382261, 0.34465584], [ 0.02057236, 0.24053624], [ 0.04930754, 0.1117427 ], [ 0.17041089, 0.20810834], [ 0.21667874, 0.2840511 ], [ 0.02325345, 0.33799925], [-0.06198866, 0.13532838], [ 0.084592 , 0.20711854], [-0.04323687, 0.17096342], [ 0.23759182, 0.16801538]], dtype=float32)
上記の分類子の中央にあるTransformerEncoder
は、 bert_encoder
です。
エンコーダーを調べると、同じ3つの入力に接続されたTransformer
レイヤーのスタックがわかります。
tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)
エンコーダの重みを復元する
構築されると、エンコーダはランダムに初期化されます。チェックポイントからエンコーダの重みを復元します。
checkpoint = tf.train.Checkpoint(model=bert_encoder)
checkpoint.restore(
os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fad4580ffd0>
オプティマイザを設定します
BERTは、ウェイトディケイを備えたAdamオプティマイザー(別名「 AdamW 」)を採用しています。また、最初に0からウォームアップし、次に0に減衰する学習率スケジュールを採用しています。
# Set up epochs and steps
epochs = 3
batch_size = 32
eval_batch_size = 32
train_data_size = len(glue_train_labels)
steps_per_epoch = int(train_data_size / batch_size)
num_train_steps = steps_per_epoch * epochs
warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)
# creates an optimizer with learning rate schedule
optimizer = nlp.optimization.create_optimizer(
2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)
これにより、学習率スケジュールが設定されたAdamWeightDecay
オプティマイザーが返されます。
type(optimizer)
official.nlp.optimization.AdamWeightDecay
オプティマイザーとそのスケジュールをカスタマイズする方法の例については、オプティマイザーのスケジュールの付録を参照してください。
モデルをトレーニングする
メトリックは精度であり、損失としてスパースカテゴリクロスエントロピーを使用します。
metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
bert_classifier.compile(
optimizer=optimizer,
loss=loss,
metrics=metrics)
bert_classifier.fit(
glue_train, glue_train_labels,
validation_data=(glue_validation, glue_validation_labels),
batch_size=32,
epochs=epochs)
Epoch 1/3 115/115 [==============================] - 26s 222ms/step - loss: 0.6151 - accuracy: 0.6611 - val_loss: 0.5462 - val_accuracy: 0.7451 Epoch 2/3 115/115 [==============================] - 24s 212ms/step - loss: 0.4447 - accuracy: 0.8010 - val_loss: 0.4150 - val_accuracy: 0.8309 Epoch 3/3 115/115 [==============================] - 24s 213ms/step - loss: 0.2830 - accuracy: 0.8964 - val_loss: 0.3697 - val_accuracy: 0.8480 <tensorflow.python.keras.callbacks.History at 0x7fad000ebda0>
次に、カスタムサンプルで微調整されたモデルを実行して、機能することを確認します。
いくつかの文のペアをエンコードすることから始めます。
my_examples = bert_encode(
glue_dict = {
'sentence1':[
'The rain in Spain falls mainly on the plain.',
'Look I fine tuned BERT.'],
'sentence2':[
'It mostly rains on the flat lands of Spain.',
'Is it working? This does not match.']
},
tokenizer=tokenizer)
モデルは、最初の例ではクラス1
「一致」を報告し、2番目の例ではクラス0
「不一致」を報告する必要があります。
result = bert_classifier(my_examples, training=False)
result = tf.argmax(result).numpy()
result
array([1, 0])
np.array(info.features['label'].names)[result]
array(['equivalent', 'not_equivalent'], dtype='<U14')
モデルを保存します
多くの場合、モデルのトレーニングの目的は、モデルを何かに使用することです。そのため、モデルをエクスポートしてから復元し、確実に機能するようにします。
export_dir='./saved_model'
tf.saved_model.save(bert_classifier, export_dir=export_dir)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: ./saved_model/assets INFO:tensorflow:Assets written to: ./saved_model/assets
reloaded = tf.saved_model.load(export_dir)
reloaded_result = reloaded([my_examples['input_word_ids'],
my_examples['input_mask'],
my_examples['input_type_ids']], training=False)
original_result = bert_classifier(my_examples, training=False)
# The results are (nearly) identical:
print(original_result.numpy())
print()
print(reloaded_result.numpy())
[[-0.95450354 1.1227685 ] [ 0.40344787 -0.58954155]] [[-0.95450354 1.1227684 ] [ 0.4034478 -0.5895414 ]]
付録
大規模なデータセットの再エンコード
このチュートリアルでは、わかりやすくするために、データセットをメモリに再エンコードしました。
これが可能だったのは、 glue/mrpc
が非常に小さいデータセットであるためです。より大きなデータセットを処理するために、 tf_models
ライブラリには、効率的なトレーニングのためにデータセットを処理および再エンコードするためのツールがいくつか含まれています。
最初のステップは、データセットのどの機能を変換する必要があるかを説明することです。
processor = nlp.data.classifier_data_lib.TfdsProcessor(
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2",
process_text_fn=bert.tokenization.convert_to_unicode)
次に、変換を適用して新しいTFRecordファイルを生成します。
# Set up output of training and evaluation Tensorflow dataset
train_data_output_path="./mrpc_train.tf_record"
eval_data_output_path="./mrpc_eval.tf_record"
max_seq_length = 128
batch_size = 32
eval_batch_size = 32
# Generate and save training data into a tf record file
input_meta_data = (
nlp.data.classifier_data_lib.generate_tf_record_from_data_file(
processor=processor,
data_dir=None, # It is `None` because data is from tfds, not local dir.
tokenizer=tokenizer,
train_data_output_path=train_data_output_path,
eval_data_output_path=eval_data_output_path,
max_seq_length=max_seq_length))
最後に、これらのTFRecordファイルからtf.data
入力パイプラインを作成します。
training_dataset = bert.run_classifier.get_dataset_fn(
train_data_output_path,
max_seq_length,
batch_size,
is_training=True)()
evaluation_dataset = bert.run_classifier.get_dataset_fn(
eval_data_output_path,
max_seq_length,
eval_batch_size,
is_training=False)()
結果のtf.data.Datasets
(features, labels)
keras.Model.fit
期待されるように(features, labels)
ペアをkeras.Model.fit
ます。
training_dataset.element_spec
({'input_word_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), 'input_mask': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), 'input_type_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None)}, TensorSpec(shape=(32,), dtype=tf.int32, name=None))
トレーニングと評価のためにtf.data.Datasetを作成します
データの読み込みを変更する必要がある場合は、ここに開始するためのコードがあります。
def create_classifier_dataset(file_path, seq_length, batch_size, is_training):
"""Creates input dataset from (tf)records files for train/eval."""
dataset = tf.data.TFRecordDataset(file_path)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
def decode_record(record):
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
}
return tf.io.parse_single_example(record, name_to_features)
def _select_data_from_record(record):
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
y = record['label_ids']
return (x, y)
dataset = dataset.map(decode_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
# Set up batch sizes
batch_size = 32
eval_batch_size = 32
# Return Tensorflow dataset
training_dataset = create_classifier_dataset(
train_data_output_path,
input_meta_data['max_seq_length'],
batch_size,
is_training=True)
evaluation_dataset = create_classifier_dataset(
eval_data_output_path,
input_meta_data['max_seq_length'],
eval_batch_size,
is_training=False)
training_dataset.element_spec
({'input_word_ids': TensorSpec(shape=(32, 128), dtype=tf.int64, name=None), 'input_mask': TensorSpec(shape=(32, 128), dtype=tf.int64, name=None), 'input_type_ids': TensorSpec(shape=(32, 128), dtype=tf.int64, name=None)}, TensorSpec(shape=(32,), dtype=tf.int64, name=None))
TFHub上のTFModelsBERT
TFHubからBERTモデルをすぐに入手できます。このhub.KerasLayer
上に分類ヘッドを追加するのは難しくありませんhub.KerasLayer
# Note: 350MB download.
import tensorflow_hub as hub
hub_model_name = "bert_en_uncased_L-12_H-768_A-12"
0733065400 The Hub encoder has 199 trainable variables
データのバッチでテスト実行します。
result = hub_encoder(
inputs=[glue_train['input_word_ids'][:10],
glue_train['input_mask'][:10],
glue_train['input_type_ids'][:10],],
training=False,
)
print("Pooled output shape:", result[0].shape)
print("Sequence output shape:", result[1].shape)
Pooled output shape: (10, 768) Sequence output shape: (10, 103, 768)
この時点で、分類ヘッドを自分で追加するのは簡単です。
bert_models.classifier_model
関数は、TensorFlowハブからエンコーダーに分類子を構築することもできます。
hub_classifier, hub_encoder = bert.bert_models.classifier_model(
# Caution: Most of `bert_config` is ignored if you pass a hub url.
bert_config=bert_config, hub_module_url=hub_url_bert, num_labels=2)
TFHubからこのモデルをロードすることの1つの欠点は、内部のkerasレイヤーの構造が復元されないことです。そのため、モデルを検査または変更することはより困難です。 TransformerEncoder
モデルは単一レイヤーになりました。
tf.keras.utils.plot_model(hub_classifier, show_shapes=True, dpi=64)
try:
tf.keras.utils.plot_model(hub_encoder, show_shapes=True, dpi=64)
assert False
except Exception as e:
print(f"{type(e).__name__}: {e}")
AttributeError: 'KerasLayer' object has no attribute 'layers'
低レベルのモデル構築
モデルの構築をさらに制御する必要がある場合は、以前に使用したclassifier_model
関数が、実際にはnlp.modeling.models.BertClassifier
クラスとnlp.modeling.models.BertClassifier
クラスの単なる薄いラッパーであることにnlp.modeling.networks.TransformerEncoder
してnlp.modeling.models.BertClassifier
。アーキテクチャの変更を開始した場合、事前にトレーニングされたチェックポイントをリロードすることが正しくないか、不可能である可能性があるため、最初から再トレーニングする必要があることに注意してください。
エンコーダーを作成します。
transformer_config = config_dict.copy()
# You need to rename a few fields to make this work:
transformer_config['attention_dropout_rate'] = transformer_config.pop('attention_probs_dropout_prob')
transformer_config['activation'] = tf_utils.get_activation(transformer_config.pop('hidden_act'))
transformer_config['dropout_rate'] = transformer_config.pop('hidden_dropout_prob')
transformer_config['initializer'] = tf.keras.initializers.TruncatedNormal(
stddev=transformer_config.pop('initializer_range'))
transformer_config['max_sequence_length'] = transformer_config.pop('max_position_embeddings')
transformer_config['num_layers'] = transformer_config.pop('num_hidden_layers')
transformer_config
{'hidden_size': 768, 'intermediate_size': 3072, 'num_attention_heads': 12, 'type_vocab_size': 2, 'vocab_size': 30522, 'attention_dropout_rate': 0.1, 'activation': <function official.modeling.activations.gelu.gelu(x)>, 'dropout_rate': 0.1, 'initializer': <tensorflow.python.keras.initializers.initializers_v2.TruncatedNormal at 0x7fac08046e10>, 'max_sequence_length': 512, 'num_layers': 12}
manual_encoder = nlp.modeling.networks.TransformerEncoder(**transformer_config)
重みを復元します。
checkpoint = tf.train.Checkpoint(model=manual_encoder)
checkpoint.restore(
os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fabefa596d8>
テスト実行:
result = manual_encoder(my_examples, training=True)
print("Sequence output shape:", result[0].shape)
print("Pooled output shape:", result[1].shape)
Sequence output shape: (2, 23, 768) Pooled output shape: (2, 768)
分類子でラップします。
manual_classifier = nlp.modeling.models.BertClassifier(
bert_encoder,
num_classes=2,
dropout_rate=transformer_config['dropout_rate'],
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
manual_classifier(my_examples, training=True).numpy()
array([[ 0.07863025, -0.02940944], [ 0.30274656, 0.27299827]], dtype=float32)
オプティマイザーとスケジュール
モデルのトレーニングに使用されるオプティマイザーは、 nlp.optimization.create_optimizer
関数を使用して作成されました。
optimizer = nlp.optimization.create_optimizer(
2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)
その高レベルのラッパーは、学習率のスケジュールとオプティマイザーを設定します。
ここで使用される基本学習率スケジュールは、トレーニングの実行中にゼロに直線的に減衰します。
epochs = 3
batch_size = 32
eval_batch_size = 32
train_data_size = len(glue_train_labels)
steps_per_epoch = int(train_data_size / batch_size)
num_train_steps = steps_per_epoch * epochs
decay_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=2e-5,
decay_steps=num_train_steps,
end_learning_rate=0)
plt.plot([decay_schedule(n) for n in range(num_train_steps)])
[<matplotlib.lines.Line2D at 0x7fabef5e69e8>]
これは、次に、トレーニングの最初の10%にわたって学習率を目標値まで直線的に増加させるWarmUp
スケジュールにラップされます。
warmup_steps = num_train_steps * 0.1
warmup_schedule = nlp.optimization.WarmUp(
initial_learning_rate=2e-5,
decay_schedule_fn=decay_schedule,
warmup_steps=warmup_steps)
# The warmup overshoots, because it warms up to the `initial_learning_rate`
# following the original implementation. You can set
# `initial_learning_rate=decay_schedule(warmup_steps)` if you don't like the
# overshoot.
plt.plot([warmup_schedule(n) for n in range(num_train_steps)])
[<matplotlib.lines.Line2D at 0x7fabef559630>]
次に、BERTモデル用に構成されたそのスケジュールを使用してnlp.optimization.AdamWeightDecay
を作成します。
optimizer = nlp.optimization.AdamWeightDecay(
learning_rate=warmup_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])