עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

סיווג טקסט בעזרת TensorFlow Lite Model Maker

צפה ב- TensorFlow.org הפעל ב- Google Colab צפה במקור ב- GitHub הורד מחברת

ספריית Maker דגם לייט TensorFlow מפשטת את תהליך התאמת המרת מודל TensorFlow נתון קלט מסוימים כאשר פריסת המודל הזה עבור יישומים על-תקן ML.

מחברת זו מציגה דוגמה מקצה לקצה שמשתמשת בספריית Model Maker כדי להמחיש את ההתאמה וההמרה של מודל סיווג טקסט נפוץ לסיווג ביקורות סרטים במכשיר נייד. מודל סיווג הטקסט מסווג טקסט לקטגוריות מוגדרות מראש. התשומות צריכות להיות טקסט מעובד מראש והתפוקות הן ההסתברות של הקטגוריות. מערך הנתונים המשמש במדריך זה הוא ביקורות סרטים חיוביות ושליליות.

תנאים מוקדמים

התקן את החבילות הנדרשות

כדי להפעיל בדוגמה זו, להתקין את החבילות הנדרשות, כולל חבילת Maker דגם מן ריפו GitHub .

pip install -q tflite-model-maker

ייבא את החבילות הנדרשות.

import numpy as np
import os

from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)

הורד את נתוני ההדרכה לדוגמא.

במדריך זה, נשתמש SST-2 (סנטימנט סטנפורד Treebank) שהיא אחת המשימות של דבק benchmark. הוא מכיל 67,349 סקירות סרטים להכשרה ו 872 סקירות סרטים לבדיקה. למערך הנתונים שני סוגים: ביקורות סרטים חיוביות ושליליות.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip
7446528/7439277 [==============================] - 0s 0us/step
7454720/7439277 [==============================] - 0s 0us/step

מערך הנתונים SST-2 מאוחסן בפורמט TSV. ההבדל היחיד בין TSV ו- CSV הוא כי TSV משתמשת כרטיסיה \t אופי כתוחם שלה במקום פסיק , בפורמט CSV.

להלן 5 השורות הראשונות של מערך ההדרכה. תווית = 0 פירושה שלילי, תווית = 1 פירושה חיובי.

משפט תווית
להסתיר הפרשות חדשות מיחידות ההורים 0
אינו מכיל שנינות, רק גאגים מאומצים 0
שאוהב את הדמויות שלה ומעביר משהו יפה למדי בטבע האדם 1
נשאר מרוצה לחלוטין להישאר אותו דבר לאורך כל הדרך 0
בקלישאות הנקמה של החנונים הגרועות ביותר שיוצרי הסרטים יכולים לחפור 0

הבא, נוכל לטעון את הנתונים לתוך dataframe פנדה ולשנות את שמות התווית הנוכחית ( 0 ו 1 ) על יותר מאלה הקריא ( negative ו positive ) ולהשתמש בהם לצורך אימונים מודל.

import pandas as pd

def replace_label(original_file, new_file):
  # Load the original file to pandas. We need to specify the separator as
  # '\t' as the training data is stored in TSV format
  df = pd.read_csv(original_file, sep='\t')

  # Define how we want to change the label name
  label_map = {0: 'negative', 1: 'positive'}

  # Excute the label change
  df.replace({'label': label_map}, inplace=True)

  # Write the updated dataset to a new file
  df.to_csv(new_file)

# Replace the label name for both the training and test dataset. Then write the
# updated CSV dataset to the current folder.
replace_label(os.path.join(os.path.join(data_dir, 'train.tsv')), 'train.csv')
replace_label(os.path.join(os.path.join(data_dir, 'dev.tsv')), 'dev.csv')

התחלה מהירה

ישנם חמישה שלבים להכשרת מודל סיווג טקסט:

שלב 1. בחר ארכיטקטורת מודל סיווג טקסט.

כאן אנו משתמשים במילה ממוצעת הטמעה של ארכיטקטורת מודלים, שתייצר מודל קטן ומהיר עם דיוק הגון.

spec = model_spec.get('average_word_vec')

מכונת הדגם תומכת גם ארכיטקטורות מודל אחרות כגון ברט . אם אתם מעוניינים ללמוד על ארכיטקטורה אחרת, לראות את בחר אדריכלות מודל מסווג טקסט המופיע בהמשך.

שלב 2. לטעון את הנתונים הכשרה המבחן, אז preprocess אותם על פי המבוקש model_spec .

Model Maker יכול לקחת נתוני קלט בפורמט CSV. אנו נטען את מערך ההדרכה והבדיקה בשם התווית הקריאה לאדם שנוצרו קודם לכן.

כל ארכיטקטורת מודל דורשת עיבוד נתוני קלט בצורה מסוימת. DataLoader קורא את הדרישה מן model_spec ומבצעת החלו בעיבוד הנדרשים באופן אוטומטי.

train_data = DataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=True)
test_data = DataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      is_training=False)

שלב 3. הכשיר את מודל TensorFlow עם נתוני האימון.

השימוש במודל ממוצע מילת הטבעת batch_size = 32 כברירת מחדל. לכן תראה שצריך 2104 צעדים כדי לעבור על 67,349 המשפטים במערך האימונים. נלמד את המודל ל -10 עידנים, כלומר לעבור על מערך האימונים 10 פעמים.

model = text_classifier.create(train_data, model_spec=spec, epochs=10)
Epoch 1/10
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6777 - accuracy: 0.5657
Epoch 2/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.5666 - accuracy: 0.7200
Epoch 3/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4484 - accuracy: 0.7955
Epoch 4/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3982 - accuracy: 0.8262
Epoch 5/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3735 - accuracy: 0.8387
Epoch 6/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3595 - accuracy: 0.8476
Epoch 7/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3476 - accuracy: 0.8541
Epoch 8/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3379 - accuracy: 0.8593
Epoch 9/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3329 - accuracy: 0.8628
Epoch 10/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3258 - accuracy: 0.8646

שלב 4. הערך את המודל עם נתוני הבדיקה.

לאחר אימון מודל סיווג הטקסט באמצעות המשפטים במערך האימונים, נשתמש בשאר 872 המשפטים במערך הבדיקה כדי להעריך את ביצועי המודל מול נתונים חדשים שמעולם לא ראה.

מכיוון שגודל האצווה המוגדר כברירת מחדל הוא 32, יידרשו 28 צעדים לעבור את 872 המשפטים במערך הנתונים של הבדיקה.

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.5149 - accuracy: 0.8268

שלב 5. ייצא כדגם TensorFlow Lite.

בואו לייצא את סיווג הטקסט שהכשרנו בפורמט TensorFlow Lite. נפרט איזו תיקיה לייצא את הדגם. כברירת מחדל, מודל ה- TFLite הצף מיוצא עבור ארכיטקטורת המודל הטמעה ממוצעת של מילים.

model.export(export_dir='average_word_vec')
2021-11-02 12:53:52.418162: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
2021-11-02 12:53:52.887037: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 12:53:52.887073: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.

ניתן להוריד את קובץ הדגם של TensorFlow Lite באמצעות סרגל הצד השמאלי של Colab. לך אל average_word_vec התיקייה כפי שאנו המפורטים export_dir פרמטר מעל, לחצו לחיצה ימנית על model.tflite הקובץ ובחר Download כדי להוריד אותו למחשב המקומי שלך.

ניתן לשלב מודל זה לתוך אנדרואיד או אפליקצית iOS באמצעות API NLClassifier של הספרייה המשימה לייט TensorFlow .

עיין האפליקציה לדוגמא סיווג TFLite טקסט לפרטים נוספים על איך המודל משמש באפליקציה פועלת.

הערה 1: כריכת דגמי Android Studio אינה תומכת עדיין בסיווג טקסט, לכן אנא השתמש בספריית המשימות של TensorFlow Lite.

הערה 2: ישנו model.json הקובץ באותה תיקייה עם המודל TFLite. הוא מכיל את הייצוג JSON של metadata ארוזות בתוך המודל לייט TensorFlow. מטא נתונים של מודלים עוזרים לספריית המשימות של TFLite לדעת מה המודל עושה וכיצד לעבד/לעבד נתונים מראש עבור המודל. אתה לא צריך להוריד את model.json הקובץ כפי שהוא רק לצורך הסברה התוכן שלה הוא כבר בתוך קובץ TFLite.

הערה 3: אם אתה מתאמן מודל סיווג טקסט באמצעות MobileBERT או ברט-בסיס ארכיטקטורה, תצטרך להשתמש BertNLClassifier API במקום לשלב את המודל מודרך לתוך יישום נייד.

החלקים הבאים עוברים על הדוגמה צעד אחר צעד כדי להציג פרטים נוספים.

בחר ארכיטקטורת מודל עבור מסווג טקסט

כל model_spec אובייקט מייצג מודל ספציפי עבור מסווג הטקסט. מכונת TensorFlow דגם לייט תומכת כיום MobileBERT , שיבוצי מילת מיצוע ואת ברט-בסיס מודלים.

דגם נתמך שם model_spec תיאור דגם גודל דגם
הטמעה ממוצעת של מילים 'word_vec_ ממוצע' ממוצע הטמעת מילת טקסט עם הפעלת RELU. <1MB
MobileBERT 'mobilebert_classifier' קטן פי 4.3 ו -5.5 פעמים מהר יותר מבסיס BERT תוך השגת תוצאות תחרותיות, המתאימות ליישומים במכשיר. 25MB w/ quantization
100MB ללא כימות
בסיס BERT 'bert_classifier' מודל BERT סטנדרטי הנמצא בשימוש נרחב במשימות NLP. 300MB

בהתחלה המהירה השתמשנו במודל הטמעה ממוצע של מילים. מתג ב"בואו MobileBERT לאמן מודל עם דיוק גבוה.

mb_spec = model_spec.get('mobilebert_classifier')

טען נתוני אימון

תוכל להעלות מערך נתונים משלך לעבודה באמצעות הדרכה זו. העלה את מערך הנתונים שלך באמצעות סרגל הצד השמאלי ב- Colab.

העלה קובץ

אם אתה מעדיף שלא להעלות את מערך נתון לענן, אתה גם יכול מקומית להפעיל את הספרייה על ידי ביצוע המדריך .

כדי לשמור על הפשטות, נשתמש מחדש במערך הנתונים של SST-2 שהורד קודם לכן. בואו להשתמש DataLoader.from_csv שיטת לטעון את הנתונים.

שים לב שככל ששינינו את ארכיטקטורת המודל, יהיה עלינו לטעון מחדש את מערך ההכשרה והבדיקה כדי ליישם את ההיגיון החדש לעיבוד מראש.

train_data = DataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=mb_spec,
      is_training=True)
test_data = DataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=mb_spec,
      is_training=False)

ספריית Maker הדגם תומכת גם from_folder() שיטת נתון עומס. הוא מניח כי נתוני הטקסט של אותה מחלקה נמצאים באותו ספריית המשנה וכי שם תיקיית המשנה הוא שם המחלקה. כל קובץ טקסט מכיל דוגמא אחת לסקירת סרטים. class_labels פרמטר משמש כדי לציין אילו תיקיות המשנה.

הכשרת דגם TensorFlow

לאמן מודל סיווג טקסט באמצעות נתוני האימון.

model = text_classifier.create(train_data, model_spec=mb_spec, epochs=3)
Epoch 1/3
1403/1403 [==============================] - 322s 193ms/step - loss: 0.3882 - test_accuracy: 0.8464
Epoch 2/3
1403/1403 [==============================] - 257s 183ms/step - loss: 0.1303 - test_accuracy: 0.9534
Epoch 3/3
1403/1403 [==============================] - 257s 183ms/step - loss: 0.0759 - test_accuracy: 0.9753

בחן את מבנה המודל המפורט.

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_word_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
hub_keras_layer_v1v2 (HubKerasL (None, 512)          24581888    input_word_ids[0][0]             
                                                                 input_mask[0][0]                 
                                                                 input_type_ids[0][0]             
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512)          0           hub_keras_layer_v1v2[0][0]       
__________________________________________________________________________________________________
output (Dense)                  (None, 2)            1026        dropout_1[0][0]                  
==================================================================================================
Total params: 24,582,914
Trainable params: 24,582,914
Non-trainable params: 0
__________________________________________________________________________________________________

להעריך את המודל

הערך את המודל שאותו הכשרנו באמצעות נתוני הבדיקה ומדוד את ערך ההפסד והדיוק.

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 7s 48ms/step - loss: 0.3859 - test_accuracy: 0.9071

ייצא כדגם TensorFlow Lite

המרת המודל מאומן לפורמט מודל TensorFlow לייט עם מטה , כך שאתה יכול להשתמש מאוחר יותר ביישום ML על-התקן. קובץ התווית וקובץ ה- vocab מוטמעים במטא נתונים. שם הקובץ ברירת המחדל TFLite הוא model.tflite .

ביישומי ML רבים במכשיר, גודל הדגם הוא גורם חשוב. לכן, מומלץ ליישם את כמות המודל בכדי להקטין אותו ולרצות לפעול מהר יותר. ברירת המחדל של טכניקת הכימות שלאחר האימון היא כימות טווח דינמי עבור דגמי BERT ו- MobileBERT.

model.export(export_dir='mobilebert/')
2021-11-02 13:09:19.810703: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 13:09:19.810749: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
2021-11-02 13:09:19.810755: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored change_concat_input_ranges.

קובץ מודל לייט TensorFlow ניתן לשלב באפליקציה לנייד באמצעות API BertNLClassifier ב Library המשימות לייט TensorFlow . שים לב כי זה שונה NLClassifier API המשמש לשלב את סיווג הטקסט מאומן עם ארכיטקטורת מודל וקטור מילה הממוצעת.

תבניות הייצוא יכולות להיות אחת או רשימה של הדברים הבאים:

כברירת מחדל, הוא מייצא רק את קובץ הדגם של TensorFlow Lite המכיל את מטא הנתונים של הדגם. תוכל גם לבחור לייצא קבצים אחרים הקשורים למודל לבחינה טובה יותר. לדוגמה, ייצוא רק של קובץ התוויות וקובץ ה- vocab כדלקמן:

model.export(export_dir='mobilebert/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

אתה יכול להעריך את מודל TFLite עם evaluate_tflite שיטה למדידת הדיוק שלה. המרת מודל TensorFlow המאומן לפורמט TFLite והחלת כימות יכולה להשפיע על דיוקו ולכן מומלץ להעריך את דיוק דגם TFLite לפני הפריסה.

accuracy = model.evaluate_tflite('mobilebert/model.tflite', test_data)
print('TFLite model accuracy: ', accuracy)
TFLite model accuracy:  {'accuracy': 0.9002293577981652}

שימוש מתקדם

create פונקציה היא פונקצית הנהג כי שימושי ספריית Maker הדגם ליצור מודלים. model_spec פרמטר המגדיר את המפרט מודל. AverageWordVecSpec ו BertClassifierSpec הכיתות נתמכות כרגע. create פונקציה מורכבת מהשלבים הבאים:

  1. יצירת מודל עבור מסווג הטקסט על פי model_spec .
  2. מאמן את דגם המסווג. התקופות ברירת המחדל ואת גודל המנה ברירת המחדל נקבעים על ידי default_training_epochs ו default_batch_size המשתנים model_spec האובייקט.

חלק זה עוסק בנושאי שימוש מתקדמים כמו התאמת המודל והפרפרמטרים של האימון.

התאם אישית את היפרפרמטרים של מודל MobileBERT

פרמטרי המודל שאתה יכול להתאים הם:

  • seq_len : אורך של רצף להזנה לתוך המודל.
  • initializer_range : סטיית התקן של truncated_normal_initializer מאתחל את כל מטריצות משקל.
  • trainable : בוליאני המציין אם שכבת מאומן מראש היא שאפשר לאלף.

הפרמטרים של צינור ההדרכה שאתה יכול להתאים הם:

  • model_dir : המיקום של הקבצים במחסום מודל. אם לא הוגדר, ישמש ספרייה זמנית.
  • dropout_rate : שיעור הנשירה.
  • learning_rate : שיעור הלמידה הראשוני עבור ייעול האדם.
  • tpu : כתובת TPU להתחבר.

למשל, אתה יכול להגדיר את seq_len=256 (ברירת המחדל היא 128). זה מאפשר לדגם לסווג טקסט ארוך יותר.

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256

התאם אישית את היפרפרמטרים של מודל הטמעת המילה הממוצעת

ניתן להתאים את התשתית למודל כמו wordvec_dim ואת seq_len המשתנים AverageWordVecSpec בכיתה.

לדוגמה, אתה יכול לאמן את המודל עם ערך גדול יותר של wordvec_dim . שים לב, אתה חייב לבנות חדש model_spec אם תשנה את המודל.

new_model_spec = AverageWordVecSpec(wordvec_dim=32)

קבל את הנתונים המעובדים מראש.

new_train_data = DataLoader.from_csv(
      filename='train.csv',
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      is_training=True)

לאמן את הדגם החדש.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)
Epoch 1/3
2104/2104 [==============================] - 8s 4ms/step - loss: 0.6545 - accuracy: 0.6055
Epoch 2/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4822 - accuracy: 0.7712
Epoch 3/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3993 - accuracy: 0.8217

כוונן את היפרפרמטרים של האימון

ניתן גם לכוונן את hyperparameters אימונים כמו epochs ו batch_size המשפיעים על דיוק המודל. לדוגמה,

  • epochs : יותר מתקופות יכולות להשיג דיוק טוב יותר, אבל עלולה להוביל overfitting.
  • batch_size : מספר דגימות להשתמש בצעד אחד האימונים.

לדוגמה, אתה יכול להתאמן עם עידנים נוספים.

model = text_classifier.create(new_train_data, model_spec=new_model_spec, epochs=20)
Epoch 1/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6595 - accuracy: 0.5994
Epoch 2/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4904 - accuracy: 0.7655
Epoch 3/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4036 - accuracy: 0.8177
Epoch 4/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3690 - accuracy: 0.8392
Epoch 5/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3516 - accuracy: 0.8500
Epoch 6/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3391 - accuracy: 0.8569
Epoch 7/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3301 - accuracy: 0.8608
Epoch 8/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3237 - accuracy: 0.8654
Epoch 9/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3190 - accuracy: 0.8674
Epoch 10/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3151 - accuracy: 0.8703
Epoch 11/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3103 - accuracy: 0.8725
Epoch 12/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3074 - accuracy: 0.8737
Epoch 13/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3036 - accuracy: 0.8757
Epoch 14/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3015 - accuracy: 0.8772
Epoch 15/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3001 - accuracy: 0.8773
Epoch 16/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2981 - accuracy: 0.8779
Epoch 17/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2954 - accuracy: 0.8797
Epoch 18/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2939 - accuracy: 0.8798
Epoch 19/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2924 - accuracy: 0.8809
Epoch 20/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2903 - accuracy: 0.8822

העריכו את המודל שעברו הכשרה חדשה עם 20 תקופות אימון.

new_test_data = DataLoader.from_csv(
      filename='dev.csv',
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      is_training=False)

loss, accuracy = model.evaluate(new_test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.5021 - accuracy: 0.8291

שנה את ארכיטקטורת המודל

אתה יכול לשנות את המודל על ידי שינוי model_spec . להלן מראה כיצד לשנות למודל BERT-Base.

שנו את model_spec למודל ברט-הבסיס מסווג הטקסט.

spec = model_spec.get('bert_classifier')

השלבים הנותרים זהים.

התאם אישית את הכמות שלאחר האימון במודל TensorFlow Lite

קוונטיזציה פוסט-ההכשרה הוא טכניקת מרה שיכולה להקטין את גודל מודל ואת חביון היסק, תוך שיפור מעבד מאיץ חומרת היקש מהיר, עם שפלה קטנה ב דיוק מודל. לפיכך, הוא נמצא בשימוש נרחב לייעל את המודל.

ספריית Model Maker מיישמת טכנולוגיית כימות ברירת מחדל לאחר אימון בעת ​​ייצוא המודל. אם ברצונך להתאים אישית קוונטיזציה שלאחר אימון, מכונת הדגם תומכת אפשרויות קוונטיזציה שלאחר אימון מרובות באמצעות QuantizationConfig גם כן. ניקח לדוגמה את הכמות של float16. ראשית, הגדר את תצורת הכמות.

config = QuantizationConfig.for_float16()

לאחר מכן אנו מייצאים את דגם TensorFlow Lite עם תצורה כזו.

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)

קרא עוד

אתה יכול לקרוא שלנו סיווג טקסט למשל ללמוד פרטים טכניים. למידע נוסף, עיין ב: