Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Clasificación de imágenes con TensorFlow Lite Model Maker

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno Ver modelo TF Hub

La biblioteca Model Maker simplifica el proceso de adaptación y conversión de un modelo de red neuronal de TensorFlow en datos de entrada particulares al implementar este modelo para aplicaciones de AA en el dispositivo.

Este cuaderno muestra un ejemplo de extremo a extremo que utiliza esta biblioteca de Model Maker para ilustrar la adaptación y conversión de un modelo de clasificación de imágenes de uso común para clasificar flores en un dispositivo móvil.

Prerrequisitos

Para ejecutar este ejemplo, primero debemos instalar varios paquetes necesarios, incluido el paquete Model Maker que está en el repositorio de GitHub.

pip install tflite-model-maker
Collecting tflite-model-maker
[?25l  Downloading https://files.pythonhosted.org/packages/13/bc/4c23b9cb9ef612a1f48bac5543bd531665de5eab8f8231111aac067f8c30/tflite_model_maker-0.1.2-py3-none-any.whl (104kB)
[K     |████████████████████████████████| 112kB 3.0MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 8.9MB/s 
[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)
Collecting tflite-support==0.1.0rc3.dev2
[?25l  Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 15.4MB/s 
[?25hRequirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)
Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.10.0)
Collecting flatbuffers==1.12
  Downloading https://files.pythonhosted.org/packages/eb/26/712e578c5f14e26ae3314c39a1bdc4eb2ec2f4ddc89b708cf8e0a0d20423/flatbuffers-1.12-py2.py3-none-any.whl
Collecting tf-models-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/d3/e9/c4e5a451c268a5a75a27949562364f6086f6bb33b226a065a8beceefa9ba/tf_models_nightly-2.3.0.dev20200914-py2.py3-none-any.whl (993kB)
[K     |████████████████████████████████| 1.0MB 23.3MB/s 
[?25hCollecting fire
[?25l  Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)
[K     |████████████████████████████████| 81kB 9.1MB/s 
[?25hRequirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)
Collecting tf-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)
[K     |████████████████████████████████| 390.2MB 43kB/s 
[?25hCollecting pybind11>=2.4
[?25l  Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)
[K     |████████████████████████████████| 296kB 50.8MB/s 
[?25hRequirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (1.15.0)
Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (3.12.4)
Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)
Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)
Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)
Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)
Collecting py-cpuinfo>=3.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)
[K     |████████████████████████████████| 102kB 13.7MB/s 
[?25hCollecting tf-slim>=1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)
[K     |████████████████████████████████| 358kB 46.5MB/s 
[?25hCollecting opencv-python-headless
[?25l  Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)
[K     |████████████████████████████████| 36.6MB 92kB/s 
[?25hRequirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)
Collecting tensorflow-model-optimization>=0.4.1
[?25l  Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
[K     |████████████████████████████████| 174kB 54.9MB/s 
[?25hRequirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)
Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)
Collecting seqeval
  Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz
Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)
Collecting pyyaml>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 47.1MB/s 
[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)
Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)
Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->tflite-model-maker) (1.1.0)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)
Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)
Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)
Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)
Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (1.12.1)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)
Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)
Collecting tb-nightly<3.0.0a0,>=2.4.0a0
[?25l  Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)
[K     |████████████████████████████████| 10.1MB 45.9MB/s 
[?25hRequirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)
Requirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)
Collecting tf-estimator-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)
[K     |████████████████████████████████| 460kB 45.6MB/s 
[?25hRequirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)
Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)
Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)
Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow-hub>=0.8.0->tflite-model-maker) (50.3.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.8.1)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)
Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (1.17.2)
Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)
Requirement already satisfied: httplib2<1dev,>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.17.4)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)
Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)
Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)
Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)
Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)
Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)
Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)
Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)
Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (4.1.1)
Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Building wheels for collected packages: fire, py-cpuinfo, seqeval, pyyaml
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=8f09a5a04716eb30229b33f5a9031fa22413bd4f709aac5155f4f26c6b070f47
  Stored in directory: /root/.cache/pip/wheels/c1/61/df/768b03527bf006b546dce284eb4249b185669e65afc5fbb2ac
  Building wheel for py-cpuinfo (setup.py) ... [?25l[?25hdone
  Created wheel for py-cpuinfo: filename=py_cpuinfo-7.0.0-cp36-none-any.whl size=20071 sha256=574e1452bf1fb528837233653837cf69e38804b69190421918f8570a6f5f7c79
  Stored in directory: /root/.cache/pip/wheels/f1/93/7b/127daf0c3a5a49feb2fecd468d508067c733fba5192f726ad1
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-0.0.12-cp36-none-any.whl size=7423 sha256=788a558edd9264e4bbc86ed4a69b393b367e12e33a8c922f64289530f289f1c6
  Stored in directory: /root/.cache/pip/wheels/4f/32/0a/df3b340a82583566975377d65e724895b3fad101a3fb729f68
  Building wheel for pyyaml (setup.py) ... [?25l[?25hdone
  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44619 sha256=12850ae3031f2d470b0d073f988afc480e64941f0fdf179c25fb17e03a39d550
  Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd
Successfully built fire py-cpuinfo seqeval pyyaml
Installing collected packages: sentencepiece, pybind11, flatbuffers, tflite-support, py-cpuinfo, tf-slim, opencv-python-headless, tensorflow-model-optimization, tb-nightly, tf-estimator-nightly, tf-nightly, seqeval, pyyaml, tf-models-nightly, fire, tflite-model-maker
  Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2

Importe los paquetes necesarios.

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import image_classifier
from tflite_model_maker import ImageClassifierDataLoader
from tflite_model_maker import model_spec

import matplotlib.pyplot as plt

Ejemplo simple de extremo a extremo

Obtener la ruta de datos

Consigamos algunas imágenes para jugar con este sencillo ejemplo de principio a fin. Cientos de imágenes es un buen comienzo para Model Maker, mientras que más datos podrían lograr una mayor precisión.

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step

Puede reemplazar image_path con sus propias carpetas de imágenes. En cuanto a la carga de datos a colab, puede encontrar el botón de carga en la barra lateral izquierda que se muestra en la imagen de abajo con el rectángulo rojo. Solo intente cargar un archivo zip y descomprimirlo. La ruta del archivo raíz es la ruta actual.

Subir archivo

Si prefiere no cargar sus imágenes en la nube, puede intentar ejecutar la biblioteca localmente siguiendo la guía en GitHub.

Ejecute el ejemplo

El ejemplo solo consta de 4 líneas de código como se muestra a continuación, cada una de las cuales representa un paso del proceso general.

Paso 1. Cargue datos de entrada específicos para una aplicación de AA en el dispositivo. Divídalo en datos de entrenamiento y datos de prueba.

data = ImageClassifierDataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

Paso 2. Personaliza el modelo de TensorFlow.

model = image_classifier.create(train_data)

Paso 3. Evalúe el modelo.

loss, accuracy = model.evaluate(test_data)

Paso 4. Exportar al modelo de TensorFlow Lite.

Aquí, exportamos el modelo de TensorFlow Lite con metadatos que proporciona un estándar para las descripciones del modelo. El archivo de etiqueta está incrustado en metadatos.

Puede descargarlo en la barra lateral izquierda al igual que la parte de carga para su propio uso.

model.export(export_dir='.')

Después de estos simples 4 pasos, podríamos usar más el archivo de modelo de TensorFlow Lite en aplicaciones en el dispositivo, como en la aplicación de referencia de clasificación de imágenes .

Proceso detallado

Actualmente, admitimos varios modelos como los modelos EfficientNet-Lite *, MobileNetV2, ResNet50 como modelos previamente entrenados para la clasificación de imágenes. Pero es muy flexible agregar nuevos modelos previamente entrenados a esta biblioteca con solo unas pocas líneas de código.

A continuación, se muestra este ejemplo de extremo a extremo paso a paso para mostrar más detalles.

Paso 1: cargar datos de entrada específicos para una aplicación de aprendizaje automático en el dispositivo

El conjunto de datos de flores contiene 3670 imágenes que pertenecen a 5 clases. Descargue la versión de archivo del conjunto de datos y descomprímalo.

El conjunto de datos tiene la siguiente estructura de directorios:

flower_photos
|__ daisy
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ dandelion
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ roses
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ sunflowers
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ tulips
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...
image_path = tf.keras.utils.get_file(
      'flower_photos',
      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      untar=True)

Utilice la clase ImageClassifierDataLoader para cargar datos.

En cuanto from_folder() método from_folder() , podría cargar datos desde la carpeta. Supone que los datos de imagen de la misma clase están en el mismo subdirectorio y que el nombre de la subcarpeta es el nombre de la clase. Actualmente, se admiten imágenes codificadas en JPEG e imágenes codificadas en PNG.

data = ImageClassifierDataLoader.from_folder(image_path)

Divídalo en datos de entrenamiento (80%), datos de validación (10%, opcional) y datos de prueba (10%).

train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

Muestre 25 ejemplos de imágenes con etiquetas.

plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.dataset.take(25)):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)
  plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

Paso 2: personaliza el modelo de TensorFlow

Cree un modelo de clasificador de imágenes personalizado basado en los datos cargados. El modelo predeterminado es EfficientNet-Lite0.

model = image_classifier.create(train_data, validation_data=validation_data)

Eche un vistazo a la estructura detallada del modelo.

model.summary()

Paso 3: evaluar el modelo personalizado

Evalúe el resultado del modelo, obtenga la pérdida y precisión del modelo.

loss, accuracy = model.evaluate(test_data)

Podríamos trazar los resultados previstos en 100 imágenes de prueba. Las etiquetas pronosticadas con color rojo son los resultados predichos incorrectos, mientras que otras son correctas.

# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
  if val1 == val2:
    return 'black'
  else:
    return 'red'

# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.dataset.take(100)):
  ax = plt.subplot(10, 10, i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)

  predict_label = predicts[i][0][0]
  color = get_label_color(predict_label,
                          test_data.index_to_label[label.numpy()])
  ax.xaxis.label.set_color(color)
  plt.xlabel('Predicted: %s' % predict_label)
plt.show()

Si la precisión no cumple con los requisitos de la aplicación, se puede consultar el uso avanzado para explorar alternativas como cambiar a un modelo más grande, ajustar los parámetros de reentrenamiento, etc.

Paso 4: Exportar al modelo de TensorFlow Lite

Convierta el modelo existente al formato de modelo de TensorFlow Lite con metadatos . El nombre de archivo TFLite predeterminado es model.tflite .

model.export(export_dir='.')

Consulte las aplicaciones de ejemplo y las guías de clasificación de imágenes para obtener más detalles sobre cómo integrar el modelo TensorFlow Lite en aplicaciones móviles.

Los formatos de exportación permitidos pueden ser uno o una lista de los siguientes:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.SAVED_MODEL

De forma predeterminada, solo exporta el modelo de TensorFlow Lite con metadatos. También puede exportar diferentes archivos de forma selectiva. Por ejemplo, exportar solo el archivo de etiqueta de la siguiente manera:

model.export(export_dir='.', export_format=ExportFormat.LABEL)

También se puede evaluar el modelo tflite con el evaluate_tflite método.

model.evaluate_tflite('model.tflite', test_data)

Uso avanzado

La función de create es la parte crítica de esta biblioteca. Utiliza el aprendizaje por transferencia con un modelo previamente entrenado similar al tutorial .

La función de create contiene los siguientes pasos:

  1. Divida los datos en datos de entrenamiento, validación y prueba de acuerdo con el parámetro validation_ratio y test_ratio . El valor predeterminado de validation_ratio y test_ratio son 0.1 y 0.1 .
  2. Descarga un vector de características de imagen como modelo base de TensorFlow Hub. El modelo pre-entrenado predeterminado es EfficientNet-Lite0.
  3. Agregue un cabezal clasificador con una capa de abandono con dropout_rate entre la capa del cabezal y el modelo previamente entrenado. El dropout_rate predeterminado es el valor predeterminado de dropout_rate de make_image_classifier_lib de TensorFlow Hub.
  4. Procese previamente los datos de entrada sin procesar. Actualmente, los pasos de preprocesamiento incluyen normalizar el valor de cada píxel de imagen para modelar la escala de entrada y cambiar su tamaño al tamaño de entrada del modelo. EfficientNet-Lite0 tiene la escala de entrada [0, 1] y el tamaño de la imagen de entrada [224, 224, 3] .
  5. Introduzca los datos en el modelo de clasificador. De forma predeterminada, los parámetros de entrenamiento, como las épocas de entrenamiento, el tamaño del lote, la tasa de aprendizaje y el impulso, son los valores predeterminados de make_image_classifier_lib de TensorFlow Hub. Solo se entrena el cabezal clasificador.

En esta sección, describimos varios temas avanzados, incluido el cambio a un modelo de clasificación de imágenes diferente, el cambio de los hiperparámetros de entrenamiento, etc.

Cuantización posterior al entrenamiento en el modelo TensorFLow Lite

La cuantificación posterior al entrenamiento es una técnica de conversión que puede reducir el tamaño del modelo y la latencia de inferencia, al mismo tiempo que mejora la latencia del acelerador de CPU y hardware, con poca degradación en la precisión del modelo. Por tanto, se utiliza mucho para optimizar el modelo.

Model Maker admite múltiples opciones de cuantificación posteriores al entrenamiento. Tomemos como ejemplo la cuantificación completa de enteros. Primero, defina la configuración de cuantificación para hacer cumplir la cuantificación entera completa para todas las operaciones, incluida la entrada y la salida. El tipo de entrada y el tipo de salida son uint8 por defecto. También puede cambiarlos a otros tipos como int8 configurando inference_input_type e inference_output_type en config.

config = configs.QuantizationConfig.create_full_integer_quantization(representative_data=test_data, is_integer_only=True)

Luego exportamos el modelo de TensorFlow Lite con dicha configuración.

model.export(export_dir='.', tflite_filename='model_quant.tflite', quantization_config=config)

En Colab, puede descargar el modelo llamado model_quant.tflite de la barra lateral izquierda, igual que la parte de carga mencionada anteriormente.

Cambiar el modelo

Cambie al modelo compatible con esta biblioteca.

Esta biblioteca es compatible con los modelos EfficientNet-Lite, MobileNetV2, ResNet50 por ahora. EfficientNet-Lite es una familia de modelos de clasificación de imágenes que pueden lograr una precisión de vanguardia y son adecuados para dispositivos Edge. El modelo predeterminado es EfficientNet-Lite0.

Podríamos cambiar el modelo a MobileNetV2 simplemente estableciendo el parámetro model_spec en mobilenet_v2_spec en el método de create .

model = image_classifier.create(train_data, model_spec=model_spec.mobilenet_v2_spec, validation_data=validation_data)

Evalúe el modelo MobileNetV2 recientemente reentrenado para ver la precisión y la pérdida en los datos de prueba.

loss, accuracy = model.evaluate(test_data)

Cambiar al modelo en TensorFlow Hub

Además, también podríamos cambiar a otros modelos nuevos que ingresen una imagen y generen un vector de características con el formato TensorFlow Hub.

Como modelo Inception V3 como ejemplo, podríamos definir inception_v3_spec que es un objeto de ImageModelSpec y contiene la especificación del modelo Inception V3.

Tenemos que especificar el nombre del modelo name , la dirección URL del modelo TensorFlow Hub uri . Mientras tanto, el valor predeterminado de input_image_shape es [224, 224] . Necesitamos cambiarlo a [299, 299] para el modelo Inception V3.

inception_v3_spec = model_spec.ImageModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]

Luego, estableciendo el parámetro model_spec en inception_v3_spec en el método de create , podríamos model_spec a entrenar el modelo Inception V3.

Los pasos restantes son exactamente los mismos y al final podríamos obtener un modelo personalizado de InceptionV3 TensorFlow Lite.

Cambia tu propio modelo personalizado

Si nos gustaría utilizar el modelo personalizado que no está en TensorFlow Hub, deberíamos crear y exportar ModelSpec en TensorFlow Hub.

Luego comience a definir el objeto ImageModelSpec como el proceso anterior.

Cambiar los hiperparámetros de entrenamiento

También podríamos cambiar los hiperparámetros formación como epochs , dropout_rate y batch_size que podrían afectar a la precisión del modelo. Los parámetros del modelo que puede ajustar son:

  • epochs : más épocas podrían lograr una mejor precisión hasta que converja, pero el entrenamiento para demasiadas épocas puede conducir a un sobreajuste.
  • dropout_rate : la tasa de abandono, evitar el sobreajuste. Ninguno por defecto.
  • batch_size : número de muestras que se utilizarán en un paso de formación. Ninguno por defecto.
  • validation_data : datos de validación. Si es Ninguno, omite el proceso de validación. Ninguno por defecto.
  • train_whole_model : si es verdadero, el módulo Hub se entrena junto con la capa de clasificación en la parte superior. De lo contrario, entrene solo la capa de clasificación superior. Ninguno por defecto.
  • learning_rate : tasa de aprendizaje base. Ninguno por defecto.
  • momentum : un flotador de Python enviado al optimizador. Solo se usa cuando use_hub_library es True. Ninguno por defecto.
  • shuffle : booleano, si los datos deben mezclarse. Falso por defecto.
  • use_augmentation : booleano, use el aumento de datos para el preprocesamiento. Falso por defecto.
  • use_hub_library : booleano, use make_image_classifier_lib de tensorflow hub para volver a entrenar el modelo. Esta canalización de capacitación podría lograr un mejor rendimiento para conjuntos de datos complicados con muchas categorías. Verdadero por defecto.
  • warmup_steps : Número de pasos de calentamiento para el programa de calentamiento según la tasa de aprendizaje. Si es Ninguno, se usa el warmup_steps predeterminado, que es el total de pasos de entrenamiento en dos épocas. Solo se usa cuando use_hub_library es False. Ninguno por defecto.
  • model_dir : opcional, la ubicación de los archivos del punto de control del modelo. Solo se usa cuando use_hub_library es False. Ninguno por defecto.

Los parámetros que son Ninguno de forma predeterminada, como epochs , obtendrán los parámetros predeterminados concretos en make_image_classifier_lib de la biblioteca de TensorFlow Hub o train_image_classifier_lib .

Por ejemplo, podríamos entrenar con más épocas.

model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)

Evalúe el modelo recién reentrenado con 10 épocas de entrenamiento.

loss, accuracy = model.evaluate(test_data)