บันทึกวันที่! Google I / O ส่งคืนวันที่ 18-20 พฤษภาคม ลงทะเบียนตอนนี้
หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

การจัดประเภทข้อความด้วย TensorFlow Lite Model Maker

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

ไลบรารี TensorFlow Lite Model Maker ช่วยลดความยุ่งยากในกระบวนการปรับและแปลงโมเดล TensorFlow เป็นข้อมูลอินพุตเฉพาะเมื่อปรับใช้โมเดลนี้สำหรับแอปพลิเคชัน ML บนอุปกรณ์

สมุดบันทึกนี้แสดงตัวอย่างแบบ end-to-end ที่ใช้ไลบรารี Model Maker เพื่อแสดงให้เห็นถึงการดัดแปลงและการแปลงรูปแบบการจัดหมวดหมู่ข้อความที่ใช้กันทั่วไปเพื่อจัดประเภทบทวิจารณ์ภาพยนตร์บนอุปกรณ์เคลื่อนที่ รูปแบบการจัดประเภทข้อความจะแบ่งประเภทข้อความเป็นประเภทที่กำหนดไว้ล่วงหน้า อินพุตควรเป็นข้อความที่ประมวลผลล่วงหน้าและผลลัพธ์คือความน่าจะเป็นของหมวดหมู่ ชุดข้อมูลที่ใช้ในบทแนะนำนี้คือบทวิจารณ์ภาพยนตร์ในเชิงบวกและเชิงลบ

ข้อกำหนดเบื้องต้น

ติดตั้งแพ็คเกจที่ต้องการ

ในการรันตัวอย่างนี้ให้ติดตั้งแพ็กเกจที่จำเป็นรวมถึงแพ็กเกจ Model Maker จากที่เก็บ GitHub

หากคุณเรียกใช้สมุดบันทึกนี้บน Colab คุณอาจเห็นข้อความแสดงข้อผิดพลาดเกี่ยวกับความเข้ากันไม่ได้ของเวอร์ชัน tensorflowjs และ tensorflow-hub ปลอดภัยที่จะเพิกเฉยต่อข้อผิดพลาดนี้เนื่องจากเราไม่ได้ใช้ tensorflowjs ในเวิร์กโฟลว์นี้

pip install -q tflite-model-maker

นำเข้าแพ็คเกจที่ต้องการ

import numpy as np
import os

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')

ดาวน์โหลดข้อมูลการฝึกอบรมตัวอย่าง

ในบทช่วยสอนนี้เราจะใช้ SST-2 (Stanford Sentiment Treebank) ซึ่งเป็นหนึ่งในงานในเกณฑ์มาตรฐาน GLUE ประกอบด้วยบทวิจารณ์ภาพยนตร์ 67,349 บทสำหรับการฝึกอบรมและบทวิจารณ์ภาพยนตร์ 872 เรื่องสำหรับการทดสอบ ชุดข้อมูลมี 2 ชั้น ได้แก่ บทวิจารณ์ภาพยนตร์เชิงบวกและเชิงลบ

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip
7446528/7439277 [==============================] - 2s 0us/step

ชุดข้อมูล SST-2 ถูกจัดเก็บในรูปแบบ TSV ข้อแตกต่างระหว่าง TSV และ CSV คือว่า TSV ใช้แท็บ \t ตัวอักษรเป็นตัวคั่นแทนของเครื่องหมายจุลภาค , ในรูปแบบรูปแบบ CSV

นี่คือ 5 บรรทัดแรกของชุดข้อมูลการฝึกอบรม label = 0 หมายถึงค่าลบป้ายกำกับ = 1 หมายถึงค่าบวก

ประโยค ฉลาก
ซ่อนสารคัดหลั่งใหม่จากหน่วยผู้ปกครอง 0
ไม่มีปัญญามี แต่มุขที่ใช้แรงงาน 0
ที่รักตัวละครและสื่อถึงสิ่งที่สวยงามเกี่ยวกับธรรมชาติของมนุษย์ 1
ยังคงพึงพอใจอย่างเต็มที่ที่จะยังคงเหมือนเดิมตลอดไป 0
เกี่ยวกับความคิดโบราณที่เลวร้ายที่สุดในการแก้แค้นของคนขี้เบื่อที่ทีมผู้สร้างสามารถขุดลอกขึ้นมาได้ 0

ต่อไปเราจะโหลดชุดข้อมูลลงในดาต้าเฟรมของ Pandas และเปลี่ยนชื่อป้ายกำกับปัจจุบัน ( 0 และ 1 ) เป็นชื่อที่มนุษย์อ่านได้มากขึ้น ( negative และ positive ) และใช้สำหรับการฝึกโมเดล

import pandas as pd

def replace_label(original_file, new_file):
  # Load the original file to pandas. We need to specify the separator as
  # '\t' as the training data is stored in TSV format
  df = pd.read_csv(original_file, sep='\t')

  # Define how we want to change the label name
  label_map = {0: 'negative', 1: 'positive'}

  # Excute the label change
  df.replace({'label': label_map}, inplace=True)

  # Write the updated dataset to a new file
  df.to_csv(new_file)

# Replace the label name for both the training and test dataset. Then write the
# updated CSV dataset to the current folder.
replace_label(os.path.join(os.path.join(data_dir, 'train.tsv')), 'train.csv')
replace_label(os.path.join(os.path.join(data_dir, 'dev.tsv')), 'dev.csv')

เริ่มต้นอย่างรวดเร็ว

มีห้าขั้นตอนในการฝึกโมเดลการจัดประเภทข้อความ:

ขั้นตอนที่ 1. เลือกสถาปัตยกรรมแบบจำลองการจัดประเภทข้อความ

ที่นี่เราใช้คำโดยเฉลี่ยสถาปัตยกรรมแบบจำลองการฝังซึ่งจะสร้างแบบจำลองขนาดเล็กและรวดเร็วด้วยความแม่นยำที่เหมาะสม

spec = model_spec.get('average_word_vec')

Model Maker ยังรองรับสถาปัตยกรรมโมเดลอื่น ๆ เช่น BERT หากคุณสนใจที่จะเรียนรู้เกี่ยวกับสถาปัตยกรรมอื่น ๆ โปรดดูส่วน เลือกสถาปัตยกรรมแบบจำลองสำหรับตัวจำแนกข้อความ ด้านล่าง

ขั้นตอนที่ 2. โหลดข้อมูลการฝึกอบรมและการทดสอบจากนั้นประมวลผลล่วงหน้าตาม model_spec ระบุ

Model Maker สามารถรับข้อมูลอินพุตในรูปแบบ CSV เราจะโหลดชุดข้อมูลการฝึกอบรมและการทดสอบด้วยชื่อฉลากที่มนุษย์อ่านได้ซึ่งสร้างขึ้นก่อนหน้านี้

สถาปัตยกรรมแต่ละโมเดลต้องการข้อมูลอินพุตเพื่อประมวลผลในลักษณะเฉพาะ TextClassifierDataLoader อ่านข้อกำหนดจาก model_spec และเรียกใช้การประมวลผลล่วงหน้าที่จำเป็นโดยอัตโนมัติ

train_data = TextClassifierDataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=False)

ขั้นตอนที่ 3 ฝึกโมเดล TensorFlow ด้วยข้อมูลการฝึกอบรม

โมเดลการฝังคำโดยเฉลี่ยใช้ batch_size = 32 โดยค่าเริ่มต้น ดังนั้นคุณจะเห็นว่าต้องใช้ 2104 ขั้นตอนในการผ่าน 67,349 ประโยคในชุดข้อมูลการฝึกอบรม เราจะฝึกโมเดล 10 ยุคซึ่งหมายถึงการผ่านชุดข้อมูลการฝึก 10 ครั้ง

model = text_classifier.create(train_data, model_spec=spec, epochs=10)
Epoch 1/10
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6830 - accuracy: 0.5595
Epoch 2/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.5781 - accuracy: 0.7091
Epoch 3/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4452 - accuracy: 0.7967
Epoch 4/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3921 - accuracy: 0.8253
Epoch 5/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3665 - accuracy: 0.8409
Epoch 6/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3516 - accuracy: 0.8478
Epoch 7/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3397 - accuracy: 0.8542
Epoch 8/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3332 - accuracy: 0.8622
Epoch 9/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3261 - accuracy: 0.8644
Epoch 10/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3216 - accuracy: 0.8662

ขั้นตอนที่ 4. ประเมินแบบจำลองด้วยข้อมูลการทดสอบ

หลังจากฝึกรูปแบบการจำแนกข้อความโดยใช้ประโยคในชุดข้อมูลการฝึกอบรมเราจะใช้ 872 ประโยคที่เหลือในชุดข้อมูลทดสอบเพื่อประเมินว่าแบบจำลองทำงานอย่างไรกับข้อมูลใหม่ที่ไม่เคยเห็นมาก่อน

เนื่องจากขนาดแบตช์เริ่มต้นคือ 32 จึงต้องใช้ 28 ขั้นตอนเพื่อผ่าน 872 ประโยคในชุดข้อมูลทดสอบ

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.5162 - accuracy: 0.8303

ขั้นตอนที่ 5. ส่งออกเป็นโมเดล TensorFlow Lite

เรามาส่งออกการจัดประเภทข้อความที่เราได้รับการฝึกฝนในรูปแบบ TensorFlow Lite เราจะระบุโฟลเดอร์ที่จะส่งออกโมเดล

คุณอาจเห็นคำเตือนเกี่ยวกับไฟล์ vocab.txt ไม่มีอยู่ในข้อมูลเมตา แต่สามารถละเว้นได้อย่างปลอดภัย

model.export(export_dir='average_word_vec')
Finished populating metadata and associated file to the model:
average_word_vec/model.tflite
The metadata json file has been saved to:
average_word_vec/model.json
The associated file that has been been packed to the model is:
['vocab.txt', 'labels.txt']
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_lite_support/metadata/python/metadata.py:344: UserWarning: File, 'vocab.txt', does not exsit in the metadata. But packing it to tflite model is still allowed.
  "tflite model is still allowed.".format(f))

คุณสามารถดาวน์โหลดไฟล์โมเดล TensorFlow Lite โดยใช้แถบด้านข้างทางซ้ายของ Colab ไปที่โฟลเดอร์ average_word_vec ตามที่เราระบุไว้ในพารามิเตอร์ export_dir ด้านบนคลิกขวาที่ไฟล์ model.tflite แล้วเลือก Download เพื่อดาวน์โหลดลงในเครื่องคอมพิวเตอร์ของคุณ

โมเดลนี้สามารถรวมเข้ากับแอพ Android หรือ iOS โดยใช้ NLClassifier API ของ ไลบรารีงาน TensorFlow Lite

ดู ตัวอย่างแอป TFLite Text Classification สำหรับรายละเอียดเพิ่มเติมเกี่ยวกับวิธีใช้โมเดลในแอปที่ใช้งานได้

หมายเหตุ 1: Android Studio Model Binding ยังไม่รองรับการจัดประเภทข้อความดังนั้นโปรดใช้ไลบรารีงาน TensorFlow Lite

หมายเหตุ 2: มีไฟล์ model.json อยู่ในโฟลเดอร์เดียวกันกับโมเดล TFLite ประกอบด้วยการแสดง JSON ของ ข้อมูลเมตาที่ รวมอยู่ในโมเดล TensorFlow Lite ข้อมูลเมตาของโมเดลช่วยให้ไลบรารีงาน TFLite รู้ว่าโมเดลทำอะไรและวิธีการประมวลผลข้อมูลล่วงหน้า / หลังการประมวลผลสำหรับโมเดล คุณไม่จำเป็นต้องดาวน์โหลดไฟล์ model.json เนื่องจากมีวัตถุประสงค์เพื่อให้ข้อมูลเท่านั้นและเนื้อหาของไฟล์นั้นอยู่ในไฟล์ TFLite แล้ว

หมายเหตุ 3: หากคุณฝึกโมเดลการจัดประเภทข้อความโดยใช้สถาปัตยกรรม MobileBERT หรือ BERT-Base คุณจะต้องใช้ BertNLClassifier API แทนเพื่อรวมโมเดลที่ผ่านการฝึกอบรมเข้ากับแอปบนอุปกรณ์เคลื่อนที่

ส่วนต่อไปนี้จะอธิบายถึงตัวอย่างทีละขั้นตอนเพื่อแสดงรายละเอียดเพิ่มเติม

เลือกสถาปัตยกรรมแบบจำลองสำหรับ Text Classifier

ออบเจ็กต์ model_spec แต่ละอ็อบเจ็กต์แสดงโมเดลเฉพาะสำหรับตัวจำแนกข้อความ ปัจจุบัน TensorFlow Lite Model Maker รองรับ MobileBERT โดยเฉลี่ยการฝังคำและโมเดล BERT-Base

รุ่นที่รองรับ ชื่อ model_spec คำอธิบายรุ่น ขนาดโมเดล
การฝังคำโดยเฉลี่ย 'average_word_vec' ค่าเฉลี่ยการฝังคำข้อความด้วยการเปิดใช้งาน RELU <1MB
MobileBERT 'mobilebert_classifier' เล็กกว่า 4.3 เท่าและเร็วกว่า BERT-Base 5.5 เท่าในขณะที่บรรลุผลการแข่งขันเหมาะสำหรับการใช้งานบนอุปกรณ์ 25MB พร้อมการหาปริมาณ
100MB w / o quantization
BERT- ฐาน 'bert_classifier' แบบจำลองมาตรฐาน BERT ที่ใช้กันอย่างแพร่หลายในงาน NLP 300MB

ในการเริ่มต้นอย่างรวดเร็วเราได้ใช้โมเดลการฝังคำโดยเฉลี่ย ลองเปลี่ยนมาใช้ MobileBERT เพื่อฝึกโมเดลที่มีความแม่นยำสูงขึ้น

mb_spec = model_spec.get('mobilebert_classifier')

โหลดข้อมูลการฝึกอบรม

คุณสามารถอัปโหลดชุดข้อมูลของคุณเองเพื่อทำงานผ่านบทช่วยสอนนี้ อัปโหลดชุดข้อมูลของคุณโดยใช้แถบด้านข้างทางซ้ายใน Colab

อัปโหลดไฟล์

หากคุณไม่ต้องการอัปโหลดชุดข้อมูลของคุณไปยังระบบคลาวด์คุณสามารถเรียกใช้ไลบรารีในเครื่องได้โดยทำตาม คำแนะนำ

เพื่อให้ง่ายเราจะนำชุดข้อมูล SST-2 ที่ดาวน์โหลดมาก่อนหน้านี้กลับมาใช้ใหม่ มาใช้เมธอด TestClassifierDataLoader.from_csv เพื่อโหลดข้อมูล

โปรดทราบว่าในขณะที่เราเปลี่ยนสถาปัตยกรรมโมเดลเราจะต้องโหลดการฝึกอบรมและทดสอบชุดข้อมูลซ้ำเพื่อใช้ตรรกะก่อนการประมวลผลใหม่

train_data = TextClassifierDataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=mb_spec,
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=mb_spec,
      is_training=False)

ไลบรารี Model Maker ยังรองรับ from_folder() การโหลดข้อมูล ถือว่าข้อมูลข้อความของคลาสเดียวกันอยู่ในไดเร็กทอรีย่อยเดียวกันและชื่อโฟลเดอร์ย่อยเป็นชื่อคลาส ไฟล์ข้อความแต่ละไฟล์มีตัวอย่างบทวิจารณ์ภาพยนตร์หนึ่งรายการ พารามิเตอร์ class_labels ใช้เพื่อระบุโฟลเดอร์ย่อย

ฝึกโมเดล TensorFlow

ฝึกโมเดลการจำแนกข้อความโดยใช้ข้อมูลการฝึกอบรม

model = text_classifier.create(train_data, model_spec=mb_spec, epochs=3)
Epoch 1/3
1403/1403 [==============================] - 309s 183ms/step - loss: 0.6354 - test_accuracy: 0.7444
Epoch 2/3
1403/1403 [==============================] - 244s 174ms/step - loss: 0.1467 - test_accuracy: 0.9465
Epoch 3/3
1403/1403 [==============================] - 245s 174ms/step - loss: 0.0833 - test_accuracy: 0.9727

ตรวจสอบโครงสร้างแบบจำลองโดยละเอียด

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_word_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
hub_keras_layer_v1v2 (HubKerasL (None, 512)          24581888    input_word_ids[0][0]             
                                                                 input_mask[0][0]                 
                                                                 input_type_ids[0][0]             
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512)          0           hub_keras_layer_v1v2[0][0]       
__________________________________________________________________________________________________
output (Dense)                  (None, 2)            1026        dropout_1[0][0]                  
==================================================================================================
Total params: 24,582,914
Trainable params: 24,582,914
Non-trainable params: 0
__________________________________________________________________________________________________

ประเมินแบบจำลอง

ประเมินแบบจำลองที่เราเพิ่งฝึกโดยใช้ข้อมูลการทดสอบและวัดค่าความสูญเสียและความแม่นยำ

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 7s 44ms/step - loss: 0.3724 - test_accuracy: 0.9106

หาโมเดล

ในแอปพลิเคชัน ML บนอุปกรณ์จำนวนมากขนาดของโมเดลเป็นปัจจัยสำคัญ ดังนั้นขอแนะนำให้คุณใช้แบบจำลองเชิงปริมาณเพื่อทำให้โมเดลมีขนาดเล็กลงและอาจทำงานได้เร็วขึ้น Model Maker จะใช้รูปแบบการหาปริมาณที่แนะนำสำหรับสถาปัตยกรรมแต่ละรุ่นโดยอัตโนมัติ แต่คุณสามารถปรับแต่งการกำหนดค่าเชิงปริมาณได้ดังต่อไปนี้

config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config.experimental_new_quantizer = True

ส่งออกเป็นแบบจำลอง TensorFlow Lite

แปลงโมเดลที่ผ่านการฝึกอบรมเป็นรูปแบบโมเดล TensorFlow Lite พร้อมด้วย ข้อมูลเมตา เพื่อให้คุณสามารถใช้ในแอปพลิเคชัน ML บนอุปกรณ์ได้ในภายหลัง ไฟล์เลเบลและไฟล์คำศัพท์ถูกฝังอยู่ในข้อมูลเมตา ชื่อไฟล์ TFLite เริ่มต้นคือ model.tflite

model.export(export_dir='mobilebert/', quantization_config=config)
Finished populating metadata and associated file to the model:
mobilebert/model.tflite
The metadata json file has been saved to:
mobilebert/model.json
The associated file that has been been packed to the model is:
['vocab.txt', 'labels.txt']

ไฟล์โมเดล TensorFlow Lite สามารถรวมเข้ากับแอพมือถือโดยใช้ BertNLClassifier API ใน ไลบรารีงาน TensorFlow Lite โปรดทราบว่าสิ่งนี้ แตกต่าง จาก NLClassifier API ที่ใช้ในการรวมการจัดประเภทข้อความที่ได้รับการฝึกฝนกับสถาปัตยกรรมแบบจำลองเวกเตอร์คำเฉลี่ย

รูปแบบการส่งออกอาจเป็นอย่างใดอย่างหนึ่งหรือรายการต่อไปนี้:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.VOCAB
  • ExportFormat.SAVED_MODEL

โดยค่าเริ่มต้นจะส่งออกเฉพาะไฟล์โมเดล TensorFlow Lite ที่มีข้อมูลเมตาของโมเดล คุณยังสามารถเลือกที่จะส่งออกไฟล์อื่น ๆ ที่เกี่ยวข้องกับโมเดลเพื่อการตรวจสอบที่ดีขึ้น ตัวอย่างเช่นการส่งออกเฉพาะไฟล์ label และไฟล์คำศัพท์ดังต่อไปนี้:

model.export(export_dir='mobilebert/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

คุณสามารถประเมินรูปแบบ TFLite กับ evaluate_tflite วิธีการวัดความถูกต้อง การแปลงโมเดล TensorFlow ที่ได้รับการฝึกอบรมเป็นรูปแบบ TFLite และใช้การวัดปริมาณอาจส่งผลต่อความแม่นยำดังนั้นจึงขอแนะนำให้ประเมินความถูกต้องของโมเดล TFLite ก่อนการปรับใช้

accuracy = model.evaluate_tflite('mobilebert/model.tflite', test_data)
print('TFLite model accuracy: ', accuracy)
TFLite model accuracy:  {'accuracy': 0.9105504587155964}

การใช้งานขั้นสูง

ฟังก์ชัน create เป็นฟังก์ชันไดรเวอร์ที่ไลบรารี Model Maker ใช้ในการสร้างโมเดล พารามิเตอร์ model_spec กำหนดข้อกำหนดของโมเดล AverageWordVecModelSpec และ BertClassifierModelSpec เรียนได้รับการสนับสนุน ฟังก์ชัน create ประกอบด้วยขั้นตอนต่อไปนี้:

  1. สร้างโมเดลสำหรับตัวจำแนกข้อความตาม model_spec
  2. ฝึกโมเดลลักษณนาม epochs เริ่มต้นและขนาดชุดเริ่มต้นจะถูกกำหนดโดย default_training_epochs และ default_batch_size ตัวแปรใน model_spec วัตถุ

ส่วนนี้ครอบคลุมหัวข้อการใช้งานขั้นสูงเช่นการปรับโมเดลและพารามิเตอร์การฝึกอบรม

ปรับแต่งไฮเปอร์พารามิเตอร์ของโมเดล MobileBERT

พารามิเตอร์ของโมเดลที่คุณสามารถปรับได้ ได้แก่ :

  • seq_len : ความยาวของลำดับที่จะป้อนลงในโมเดล
  • initializer_range : ค่าเบี่ยงเบนมาตรฐานของ truncated_normal_initializer สำหรับการเริ่มต้นเมทริกซ์น้ำหนักทั้งหมด
  • trainable : บูลีนที่ระบุว่าเลเยอร์ที่ฝึกไว้ล่วงหน้าสามารถฝึกได้หรือไม่

พารามิเตอร์ไปป์ไลน์การฝึกอบรมที่คุณสามารถปรับได้ ได้แก่ :

  • model_dir : ตำแหน่งของไฟล์จุดตรวจโมเดล หากไม่ได้ตั้งค่าไว้ระบบจะใช้ไดเร็กทอรีชั่วคราว
  • dropout_rate : อัตราการออกกลางคัน
  • learning_rate : อัตราการเรียนรู้เริ่มต้นสำหรับ Adam Optimizer
  • tpu : ที่อยู่ TPU เพื่อเชื่อมต่อ

ตัวอย่างเช่นคุณสามารถตั้งค่า seq_len=256 (ค่าเริ่มต้นคือ 128) สิ่งนี้ช่วยให้โมเดลสามารถจัดประเภทข้อความที่ยาวขึ้นได้

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256

ปรับแต่งพารามิเตอร์ไฮเปอร์พารามิเตอร์แบบจำลองการฝังคำโดยเฉลี่ย

คุณสามารถปรับโครงสร้างพื้นฐานของโมเดลเช่น wordvec_dim และตัวแปร seq_len ในคลาส AverageWordVecModelSpec

ตัวอย่างเช่นคุณสามารถฝึกโมเดลด้วย wordvec_dim มีค่ามากขึ้น โปรดทราบว่าคุณต้องสร้าง model_spec ใหม่หากคุณแก้ไขโมเดล

new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)

รับข้อมูลที่ประมวลผลล่วงหน้า

new_train_data = TextClassifierDataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      is_training=True)

ฝึกโมเดลใหม่.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)
Epoch 1/3
2104/2104 [==============================] - 8s 4ms/step - loss: 0.6846 - accuracy: 0.5581
Epoch 2/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.5710 - accuracy: 0.7086
Epoch 3/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4352 - accuracy: 0.8030

ปรับพารามิเตอร์การฝึกอบรม

คุณยังสามารถปรับค่าพารามิเตอร์การฝึกอบรมเช่น epochs และขนาดแบท batch_size ที่ส่งผลต่อความแม่นยำของโมเดล ตัวอย่างเช่น

  • epochs : epochs อื่น ๆ สามารถบรรลุความแม่นยำที่ดีขึ้น แต่อาจนำไปสู่การติดตั้งมากเกินไป
  • batch_size : จำนวนตัวอย่างที่จะใช้ในขั้นตอนการฝึกอบรมเดียว

ตัวอย่างเช่นคุณสามารถฝึกฝนกับยุคต่างๆได้มากขึ้น

model = text_classifier.create(new_train_data, model_spec=new_model_spec, epochs=20)
Epoch 1/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6847 - accuracy: 0.5557
Epoch 2/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.5682 - accuracy: 0.7170
Epoch 3/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4318 - accuracy: 0.8025
Epoch 4/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3855 - accuracy: 0.8306
Epoch 5/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3628 - accuracy: 0.8459
Epoch 6/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3494 - accuracy: 0.8529
Epoch 7/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3380 - accuracy: 0.8582
Epoch 8/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3311 - accuracy: 0.8628
Epoch 9/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3231 - accuracy: 0.8667
Epoch 10/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3184 - accuracy: 0.8688
Epoch 11/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3149 - accuracy: 0.8716
Epoch 12/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3111 - accuracy: 0.8732
Epoch 13/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3067 - accuracy: 0.8725
Epoch 14/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3028 - accuracy: 0.8753
Epoch 15/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3014 - accuracy: 0.8759
Epoch 16/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2984 - accuracy: 0.8776
Epoch 17/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.2968 - accuracy: 0.8793
Epoch 18/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.2936 - accuracy: 0.8803
Epoch 19/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2925 - accuracy: 0.8801
Epoch 20/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2913 - accuracy: 0.8828

ประเมินแบบจำลองที่ได้รับการฝึกอบรมใหม่ด้วย 20 ยุคการฝึกอบรม

new_test_data = TextClassifierDataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      is_training=False)

loss, accuracy = model.evaluate(new_test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.4962 - accuracy: 0.8337

เปลี่ยนสถาปัตยกรรมแบบจำลอง

คุณสามารถเปลี่ยนโมเดลได้โดยเปลี่ยน model_spec ต่อไปนี้แสดงวิธีการเปลี่ยนเป็น BERT-Base model

เปลี่ยน model_spec เป็น BERT-Base model สำหรับตัวจำแนกข้อความ

spec = model_spec.get('bert_classifier')

ขั้นตอนที่เหลือเหมือนเดิม