Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

TensorFlow Lite Model Maker ile metin sınıflandırması

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

TensorFlow Lite Model Maker kitaplığı, bu modeli cihaz üzerindeki makine öğrenimi uygulamaları için dağıtırken bir TensorFlow modelini özel giriş verilerine uyarlama ve dönüştürme sürecini basitleştirir.

Bu defter, bir mobil cihazdaki film incelemelerini sınıflandırmak için yaygın olarak kullanılan bir metin sınıflandırma modelinin uyarlanmasını ve dönüştürülmesini göstermek için Model Maker kitaplığını kullanan uçtan uca bir örnek göstermektedir. Metin sınıflandırma modeli, metni önceden tanımlanmış kategorilere ayırır. Girdiler önceden işlenmiş metin olmalıdır ve çıktılar, kategorilerin olasılıklarıdır. Bu öğreticide kullanılan veri kümesi olumlu ve olumsuz film incelemeleridir.

Önkoşullar

Gerekli paketleri kurun

Bu örneği çalıştırmak için GitHub deposundan Model Maker paketi de dahil olmak üzere gerekli paketleri yükleyin.

pip install tflite-model-maker
Collecting tflite-model-maker
[?25l  Downloading https://files.pythonhosted.org/packages/13/bc/4c23b9cb9ef612a1f48bac5543bd531665de5eab8f8231111aac067f8c30/tflite_model_maker-0.1.2-py3-none-any.whl (104kB)
[K     |████████████████████████████████| 112kB 8.2MB/s 
[?25hRequirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)
Collecting fire
[?25l  Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)
[K     |████████████████████████████████| 81kB 7.7MB/s 
[?25hCollecting flatbuffers==1.12
  Downloading https://files.pythonhosted.org/packages/eb/26/712e578c5f14e26ae3314c39a1bdc4eb2ec2f4ddc89b708cf8e0a0d20423/flatbuffers-1.12-py2.py3-none-any.whl
Collecting tf-models-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/d3/e9/c4e5a451c268a5a75a27949562364f6086f6bb33b226a065a8beceefa9ba/tf_models_nightly-2.3.0.dev20200914-py2.py3-none-any.whl (993kB)
[K     |████████████████████████████████| 1.0MB 17.6MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 31.6MB/s 
[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)
Collecting tf-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)
[K     |████████████████████████████████| 390.2MB 46kB/s 
[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.10.0)
Requirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)
Collecting tflite-support==0.1.0rc3.dev2
[?25l  Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 50.0MB/s 
[?25hRequirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (3.12.4)
Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (1.15.0)
Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->tflite-model-maker) (1.1.0)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)
Collecting tensorflow-model-optimization>=0.4.1
[?25l  Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
[K     |████████████████████████████████| 174kB 57.7MB/s 
[?25hCollecting tf-slim>=1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)
[K     |████████████████████████████████| 358kB 54.9MB/s 
[?25hRequirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)
Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)
Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)
Collecting seqeval
  Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz
Requirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)
Collecting pyyaml>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 55.1MB/s 
[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)
Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)
Collecting opencv-python-headless
[?25l  Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)
[K     |████████████████████████████████| 36.6MB 88kB/s 
[?25hRequirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)
Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)
Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)
Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)
Collecting py-cpuinfo>=3.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)
[K     |████████████████████████████████| 102kB 11.1MB/s 
[?25hRequirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)
Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)
Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)
Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)
Requirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)
Collecting tb-nightly<3.0.0a0,>=2.4.0a0
[?25l  Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)
[K     |████████████████████████████████| 10.1MB 46.1MB/s 
[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)
Collecting tf-estimator-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)
[K     |████████████████████████████████| 460kB 51.7MB/s 
[?25hRequirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)
Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.12.1)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)
Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)
Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)
Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)
Collecting pybind11>=2.4
[?25l  Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)
[K     |████████████████████████████████| 296kB 55.2MB/s 
[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow-hub>=0.8.0->tflite-model-maker) (50.3.0)
Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)
Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)
Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)
Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2.8.1)
Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)
Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)
Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)
Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.17.4)
Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)
Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)
Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (1.17.2)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)
Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)
Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (4.1.1)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Building wheels for collected packages: fire, seqeval, pyyaml, py-cpuinfo
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=9eaa2d36e17621d136f8ab1707a5a4e8994c53d5076a9edde21aab7696ba3e09
  Stored in directory: /root/.cache/pip/wheels/c1/61/df/768b03527bf006b546dce284eb4249b185669e65afc5fbb2ac
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-0.0.12-cp36-none-any.whl size=7423 sha256=1ce4604da2a395f0304db708bf2e2c1831033ed8b1f7c23927d70ed9ed7b7110
  Stored in directory: /root/.cache/pip/wheels/4f/32/0a/df3b340a82583566975377d65e724895b3fad101a3fb729f68
  Building wheel for pyyaml (setup.py) ... [?25l[?25hdone
  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44619 sha256=d51b6ef3e90de74d0c1cee8f7aafe0a6d8674348c8437cd89ad5c60a6c3dc726
  Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd
  Building wheel for py-cpuinfo (setup.py) ... [?25l[?25hdone
  Created wheel for py-cpuinfo: filename=py_cpuinfo-7.0.0-cp36-none-any.whl size=20071 sha256=096439bff3cb3e4cc21b86472c629017fd9c972d6e2ed231e1a91d2096fc687d
  Stored in directory: /root/.cache/pip/wheels/f1/93/7b/127daf0c3a5a49feb2fecd468d508067c733fba5192f726ad1
Successfully built fire seqeval pyyaml py-cpuinfo
Installing collected packages: fire, flatbuffers, tensorflow-model-optimization, tf-slim, seqeval, pyyaml, opencv-python-headless, sentencepiece, tb-nightly, tf-estimator-nightly, tf-nightly, py-cpuinfo, tf-models-nightly, pybind11, tflite-support, tflite-model-maker
  Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2

Gerekli paketleri içe aktarın.

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader

Veri yolunu alın

Bu eğitici için veri kümesini indirin.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
7446528/7439277 [==============================] - 0s 0us/step

Bu öğretici üzerinde çalışmak için kendi veri kümenizi de yükleyebilirsiniz. Colab'de sol kenar çubuğunu kullanarak veri kümenizi yükleyin.

Dosya yükleme

Veri kümenizi buluta yüklememeyi tercih ederseniz, kılavuzu izleyerek kütüphaneyi yerel olarak da çalıştırabilirsiniz.

Uçtan Uca İş Akışı

Bu iş akışı, aşağıda özetlendiği gibi beş adımdan oluşur:

Adım 1. Bir metin sınıflandırma modelini temsil eden bir model özelliği seçin.

Bu eğitimde örnek olarak MobileBERT kullanılmıştır.

spec = model_spec.get('mobilebert_classifier')

Adım 2. Cihazdaki bir ML uygulamasına özgü tren ve test verilerini yükleyin ve verileri belirli bir model_spec göre önceden model_spec .

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Adım 3. TensorFlow modelini özelleştirin.

model = text_classifier.create(train_data, model_spec=spec)

Adım 4. Modeli değerlendirin.

loss, acc = model.evaluate(test_data)

İle TensorFlow Lite modeli olarak 5. Dışa Adım meta veri .

MobileBERT, cihaz üzerindeki uygulamalar için çok büyük olduğundan, minimum performans düşüşüyle ​​neredeyse 4 kat sıkıştırmak için model üzerinde dinamik aralık nicelemesini kullanın.

config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config._experimental_new_quantizer = True
model.export(export_dir='mobilebert/', quantization_config=config)

Modeli, Colab'deki sol kenar çubuğunu kullanarak da indirebilirsiniz.

Yukarıdaki 5 adımı uyguladıktan sonra, TensorFlow Lite Görev Kitaplığı'ndaki BertNLClassifier API'yi kullanarak TensorFlow Lite model dosyasını cihaz üzerindeki uygulamalarda da kullanabilirsiniz.

Aşağıdaki bölümlerde, daha fazla ayrıntı göstermek için örnek adım adım açıklanmaktadır.

Metin Sınıflandırıcı için bir Modeli temsil eden bir model_spec seçin

Her model_spec nesnesi, metin sınıflandırıcı için belirli bir modeli temsil eder. TensorFlow Lite Model Maker şu anda MobileBERT'i , ortalama kelime düğünlerini ve BERT-Base modellerini desteklemektedir.

Desteklenen Model Model_spec adı Model Açıklaması
MobileBERT "mobilebert_classifier" Cihaz üstü uygulamalar için uygun rekabetçi sonuçlar elde ederken BERT-Base'den 4,3 kat daha küçük ve 5,5 kat daha hızlı.
BERT-Tabanı "bert_classifier" NLP görevlerinde yaygın olarak kullanılan standart BERT modeli.
ortalama kelime gömme "average_word_vec" RELU aktivasyonu ile metin kelime düğünlerinin ortalaması.

Bu eğitici, süreci göstermek için birden çok kez yeniden average_word_vec daha küçük bir model olan average_word_vec kullanır.

spec = model_spec.get('average_word_vec')

Cihazdaki ML Uygulamasına Özgü Giriş Verilerini Yükleme

SST-2 (Stanford Sentiment Treebank), GLUE kıyaslamasındaki görevlerden biridir. Eğitim için 67.349 film incelemesi ve doğrulama için 872 film incelemesi içerir. Veri kümesinin iki sınıfı vardır: olumlu ve olumsuz film incelemeleri.

Veri kümesinin arşivlenmiş sürümünü indirin ve çıkarın.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')

SST-2 veri kümesi vardır train.tsv eğitim ve için dev.tsv doğrulama için. Dosyalar aşağıdaki biçime sahiptir:

cümle etiket
büyüleyici ve çoğu zaman yolculuğu etkiliyor. 1
korkusuzca kasvetli ve çaresiz 0

Olumlu bir inceleme 1 olarak etiketlenir ve olumsuz bir yorum 0 olarak etiketlenir.

Verileri yüklemek için TestClassifierDataLoader.from_csv yöntemini kullanın.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Model Maker kitaplığı, verileri yüklemek için from_folder() yöntemini de destekler. Aynı sınıfın metin verilerinin aynı alt dizinde olduğunu ve alt klasör adının sınıf adı olduğunu varsayar. Her metin dosyası bir film inceleme örneği içerir. class_labels parametresi, hangi alt klasörlerin olduğunu belirtmek için kullanılır.

TensorFlow Modelini Özelleştirin

Yüklenen verilere göre özel bir metin sınıflandırıcı modeli oluşturun.

model = text_classifier.create(train_data, model_spec=spec, epochs=10)

Ayrıntılı model yapısını inceleyin.

model.summary()

Özelleştirilmiş Modeli Değerlendirin

Modeli test verileri ile değerlendirin ve kaybını ve doğruluğunu alın.

loss, acc = model.evaluate(test_data)

TensorFlow Lite Modeli olarak dışa aktar

Mevcut modeli, daha sonra cihaz üzerindeki bir makine öğrenimi uygulamasında kullanabileceğiniz meta verilerle TensorFlow Lite model formatına dönüştürün. Etiket dosyası ve sözcük dosyası meta verilere yerleştirilmiştir. Varsayılan TFLite dosya adı model.tflite .

model.export(export_dir='average_word_vec/')

TensorFlow Lite model dosyası, TensorFlow Lite Görev Kitaplığı'ndaki NLClassifier API kullanılarak metin sınıflandırma referans uygulamasında kullanılabilir.

İzin verilen dışa aktarma biçimleri aşağıdakilerden biri veya bir listesi olabilir:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.VOCAB
  • ExportFormat.SAVED_MODEL

Varsayılan olarak, yalnızca TensorFlow Lite modelini meta verilerle dışa aktarır. Ayrıca, farklı dosyaları seçmeli olarak dışa aktarabilirsiniz. Örneğin, yalnızca etiket dosyasını ve sözcük dosyasını aşağıdaki gibi dışa aktarın:

model.export(export_dir='average_word_vec/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

Doğruluğunu elde etmek için tflite modelini eval_tflite yöntemi ile evaluate_tflite .

accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)

Gelişmiş Kullanım

create fonksiyonu Modeli Üretici kütüphane kullandığı modelleri oluşturmak için o sürücü fonksiyonudur. model_spec parametresi, model spesifikasyonunu tanımlar. AverageWordVecModelSpec ve BertClassifierModelSpec sınıfları şu anda desteklenmektedir. create işlevi aşağıdaki adımlardan oluşur:

  1. model_spec göre metin sınıflandırıcı için model oluşturur.
  2. Sınıflandırıcı modelini eğitir. Varsayılan dönemler ve varsayılan parti boyutu, model_spec nesnesindeki default_training_epochs ve default_batch_size değişkenleri tarafından ayarlanır.

Bu bölüm, modeli ayarlama ve hiperparametrelerini eğitme gibi gelişmiş kullanım konularını kapsar.

Modeli ayarlayın

AverageWordVecModelSpec sınıfındaki wordvec_dim ve seq_len değişkenleri gibi model altyapısını ayarlayabilirsiniz.

Örneğin, modeli daha büyük bir wordvec_dim değeri ile wordvec_dim . Modeli değiştirirseniz yeni bir model_spec gerektiğini unutmayın.

new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)

Önceden işlenmiş verileri alın.

new_train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      delimiter='\t',
      is_training=True)

Yeni modeli eğitin.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)

Ayrıca MobileBERT modelini de ayarlayabilirsiniz.

Ayarlayabileceğiniz model parametreleri şunlardır:

  • seq_len : seq_len dizinin uzunluğu.
  • initializer_range : Tüm ağırlık matrislerini başlatmak için truncated_normal_initializer standart sapması.
  • trainable : Önceden eğitilmiş katmanın eğitilebilir olup olmadığını belirten Boole.

Ayarlayabileceğiniz eğitim hattı parametreleri şunlardır:

  • model_dir : Model kontrol noktası dosyalarının konumu. Ayarlanmazsa, geçici bir dizin kullanılacaktır.
  • dropout_rate : Bırakma oranı.
  • learning_rate : Adam optimize edicinin ilk öğrenme oranı.
  • tpu : Bağlanılacak TPU adresi.

Örneğin, seq_len=256 (varsayılan 128'dir) olarak ayarlayabilirsiniz. Bu, modelin daha uzun metni sınıflandırmasına izin verir.

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256

Eğitim hiperparametrelerini ayarlayın

Ayrıca, model doğruluğunu etkileyen epochs ve batch_size gibi eğitim hiperparametrelerini de ayarlayabilirsiniz. Örneğin,

  • epochs : daha fazla çağ daha iyi doğruluk sağlayabilir, ancak aşırı uyuma neden olabilir.
  • batch_size : bir eğitim adımında kullanılacak örnek sayısı.

Örneğin, daha fazla dönemle antrenman yapabilirsiniz.

model = text_classifier.create(train_data, model_spec=spec, epochs=20)

Yeni yeniden eğitilmiş modeli 20 eğitim dönemi ile değerlendirin.

loss, accuracy = model.evaluate(test_data)

Model Mimarisini Değiştirin

model_spec değiştirerek modeli değiştirebilirsiniz. Aşağıda BERT-Base modeline nasıl geçileceği gösterilmektedir.

Metin sınıflandırıcı için model_spec BERT-Base modeline değiştirin.

spec = model_spec.get('bert_classifier')

Kalan adımlar aynı.