![]() | ![]() | ![]() | ![]() | ![]() |
La biblioteca fabricante TensorFlow Lite modelo simplifica el proceso de adaptación y conversión de un modelo de red neuronal TensorFlow para introducir datos particulares al implementar este modelo para aplicaciones de LD en el dispositivo.
Este cuaderno muestra un ejemplo de un extremo a otro que utiliza esta biblioteca de Model Maker para ilustrar la adaptación y conversión de un modelo de clasificación de imágenes de uso común para clasificar flores en un dispositivo móvil.
Prerrequisitos
Para ejecutar este ejemplo, primero es necesario instalar varios paquetes requeridos, incluyendo Modelo Fabricante de paquete que, en GitHub repo .
pip install -q tflite-model-maker
Importe los paquetes necesarios.
import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release PkgResourcesDeprecationWarning, /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9 warnings.warn(msg)
Ejemplo simple de extremo a extremo
Obtener la ruta de datos
Consigamos algunas imágenes para jugar con este sencillo ejemplo de principio a fin. Cientos de imágenes es un buen comienzo para Model Maker, mientras que más datos podrían lograr una mayor precisión.
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 1s 0us/step 228827136/228813984 [==============================] - 1s 0us/step
Se podría sustituir image_path
con sus propias carpetas de imágenes. En cuanto a la carga de datos a colab, puede encontrar el botón de carga en la barra lateral izquierda que se muestra en la imagen de abajo con el rectángulo rojo. Intente cargar un archivo zip y descomprímalo. La ruta del archivo raíz es la ruta actual.
Si usted prefiere no subir sus imágenes a la nube, usted podría tratar de ejecutar la biblioteca local después de la guía en GitHub.
Ejecuta el ejemplo
El ejemplo solo consta de 4 líneas de código como se muestra a continuación, cada una de las cuales representa un paso del proceso general.
Paso 1. Cargue los datos de entrada específicos de una aplicación de aprendizaje automático en el dispositivo. Divídalo en datos de entrenamiento y datos de prueba.
data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
Paso 2. Personaliza el modelo de TensorFlow.
model = image_classifier.create(train_data)
INFO:tensorflow:Retraining the models... Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2 (HubKer (None, 1280) 3413024 _________________________________________________________________ dropout (Dropout) (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 103/103 [==============================] - 7s 35ms/step - loss: 0.8551 - accuracy: 0.7718 Epoch 2/5 103/103 [==============================] - 4s 35ms/step - loss: 0.6503 - accuracy: 0.8956 Epoch 3/5 103/103 [==============================] - 4s 34ms/step - loss: 0.6157 - accuracy: 0.9196 Epoch 4/5 103/103 [==============================] - 3s 33ms/step - loss: 0.6036 - accuracy: 0.9293 Epoch 5/5 103/103 [==============================] - 4s 34ms/step - loss: 0.5929 - accuracy: 0.9317
Paso 3. Evalúe el modelo.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 40ms/step - loss: 0.6282 - accuracy: 0.9019
Paso 4. Exportar al modelo de TensorFlow Lite.
Aquí, exportamos modelo TensorFlow Lite con metadatos que proporciona un estándar para descripciones de los modelos. El archivo de etiqueta está incrustado en metadatos. La técnica de cuantificación posterior al entrenamiento predeterminada es la cuantificación entera completa para la tarea de clasificación de imágenes.
Puede descargarlo en la barra lateral izquierda al igual que la parte de carga para su propio uso.
model.export(export_dir='.')
2021-11-02 11:34:05.568024: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets 2021-11-02 11:34:09.488041: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:34:09.488090: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
Después de estos sencillos pasos 4, podríamos utilizar más TensorFlow archivo de modelo Lite en aplicaciones en el dispositivo como en la imagen de la clasificación aplicación de referencia.
Proceso detallado
Actualmente, admitimos varios modelos, como los modelos EfficientNet-Lite *, MobileNetV2, ResNet50 como modelos previamente entrenados para la clasificación de imágenes. Pero es muy flexible agregar nuevos modelos previamente entrenados a esta biblioteca con solo unas pocas líneas de código.
A continuación, se muestra este ejemplo de un extremo a otro paso a paso para mostrar más detalles.
Paso 1: Cargar datos de entrada específicos para una aplicación de aprendizaje automático en el dispositivo
El conjunto de datos de flores contiene 3670 imágenes que pertenecen a 5 clases. Descargue la versión de archivo del conjunto de datos y descomprímalo.
El conjunto de datos tiene la siguiente estructura de directorios:
flower_photos |__ daisy |______ 100080576_f52e8ee070_n.jpg |______ 14167534527_781ceb1b7a_n.jpg |______ ... |__ dandelion |______ 10043234166_e6dd915111_n.jpg |______ 1426682852_e62169221f_m.jpg |______ ... |__ roses |______ 102501987_3cdb8e5394_n.jpg |______ 14982802401_a3dfb22afb.jpg |______ ... |__ sunflowers |______ 12471791574_bb1be83df4.jpg |______ 15122112402_cafa41934f.jpg |______ ... |__ tulips |______ 13976522214_ccec508fe7.jpg |______ 14487943607_651e8062a1_m.jpg |______ ...
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
Uso DataLoader
clase para cargar datos.
En cuanto a from_folder()
método, se pudo cargar los datos de la carpeta. Se asume que los datos de imagen de la misma clase están en el mismo subdirectorio y que el nombre de la subcarpeta es el nombre de la clase. Actualmente, se admiten imágenes codificadas en JPEG e imágenes codificadas en PNG.
data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips. INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
Divídalo en datos de entrenamiento (80%), datos de validación (10%, opcional) y datos de prueba (10%).
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
Muestre 25 ejemplos de imágenes con etiquetas.
plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(image.numpy(), cmap=plt.cm.gray)
plt.xlabel(data.index_to_label[label.numpy()])
plt.show()
Paso 2: personaliza el modelo de TensorFlow
Cree un modelo de clasificador de imágenes personalizado basado en los datos cargados. El modelo predeterminado es EfficientNet-Lite0.
model = image_classifier.create(train_data, validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_1 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_1 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_1 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 91/91 [==============================] - 6s 54ms/step - loss: 0.8689 - accuracy: 0.7655 - val_loss: 0.6941 - val_accuracy: 0.8835 Epoch 2/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6596 - accuracy: 0.8949 - val_loss: 0.6668 - val_accuracy: 0.8807 Epoch 3/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6188 - accuracy: 0.9159 - val_loss: 0.6537 - val_accuracy: 0.8807 Epoch 4/5 91/91 [==============================] - 5s 52ms/step - loss: 0.6050 - accuracy: 0.9210 - val_loss: 0.6432 - val_accuracy: 0.8892 Epoch 5/5 91/91 [==============================] - 5s 52ms/step - loss: 0.5898 - accuracy: 0.9348 - val_loss: 0.6348 - val_accuracy: 0.8864
Eche un vistazo a la estructura detallada del modelo.
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_1 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_1 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_1 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________
Paso 3: evaluar el modelo personalizado
Evalúe el resultado del modelo, obtenga la pérdida y precisión del modelo.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6324 - accuracy: 0.8965
Podríamos trazar los resultados predichos en 100 imágenes de prueba. Las etiquetas pronosticadas con color rojo son los resultados predichos incorrectos, mientras que otras son correctas.
# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
if val1 == val2:
return 'black'
else:
return 'red'
# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
ax = plt.subplot(10, 10, i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(image.numpy(), cmap=plt.cm.gray)
predict_label = predicts[i][0][0]
color = get_label_color(predict_label,
test_data.index_to_label[label.numpy()])
ax.xaxis.label.set_color(color)
plt.xlabel('Predicted: %s' % predict_label)
plt.show()
Si la precisión no cumple con el requisito de aplicación, se podría hacer referencia a Uso avanzado para explorar alternativas tales como cambiar a un modelo más grande, el ajuste de los parámetros de re-formación, etc.
Paso 4: Exportar al modelo de TensorFlow Lite
Convertir el modelo entrenado a formato modelo TensorFlow Lite con metadatos para que pueda utilizar más tarde en una aplicación ML en el dispositivo. El archivo de etiqueta y el archivo de vocabulario están incrustados en metadatos. El nombre de archivo por defecto es TFLite model.tflite
.
En muchas aplicaciones de aprendizaje automático en el dispositivo, el tamaño del modelo es un factor importante. Por lo tanto, se recomienda que aplique cuantificar el modelo para hacerlo más pequeño y potencialmente ejecutar más rápido. La técnica de cuantificación posterior al entrenamiento predeterminada es la cuantificación entera completa para la tarea de clasificación de imágenes.
model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets 2021-11-02 11:35:40.254046: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:35:40.254099: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
Ver ejemplos de aplicaciones y guías de clasificación de imágenes para más detalles acerca de cómo integrar el modelo TensorFlow Lite en aplicaciones para móviles.
Este modelo se puede integrar en un Android o una aplicación para iOS usando la API ImageClassifier de la biblioteca de tareas Lite TensorFlow .
Los formatos de exportación permitidos pueden ser uno o una lista de los siguientes:
De forma predeterminada, solo exporta el modelo de TensorFlow Lite con metadatos. También puede exportar diferentes archivos de forma selectiva. Por ejemplo, exportar solo el archivo de etiqueta de la siguiente manera:
model.export(export_dir='.', export_format=ExportFormat.LABEL)
INFO:tensorflow:Saving labels in ./labels.txt INFO:tensorflow:Saving labels in ./labels.txt
También se puede evaluar el modelo tflite con el evaluate_tflite
método.
model.evaluate_tflite('model.tflite', test_data)
{'accuracy': 0.9019073569482289}
Uso avanzado
El create
función es la parte crítica de esta biblioteca. Utiliza el aprendizaje de transferencia con un modelo pretrained similar al tutorial .
El create
función contiene los siguientes pasos:
- Dividir los datos en el entrenamiento, validación, pruebas de datos de acuerdo con los parámetros
validation_ratio
ytest_ratio
. El valor por defecto devalidation_ratio
ytest_ratio
son0.1
y0.1
. - Descargar un vector de características de imagen como el modelo de base de TensorFlow concentradores. El modelo preentrenado predeterminado es EfficientNet-Lite0.
- Añadir una cabeza clasificador con una capa de deserción con
dropout_rate
entre la capa de la cabeza y el modelo de pre-formados. El valor por defectodropout_rate
es el valor predeterminadodropout_rate
valor de make_image_classifier_lib por TensorFlow concentradores. - Procese previamente los datos de entrada sin procesar. Actualmente, los pasos de preprocesamiento incluyen normalizar el valor de cada píxel de la imagen para modelar la escala de entrada y cambiar su tamaño al tamaño de entrada del modelo. EfficientNet-Lite0 tiene la escala de entrada
[0, 1]
y el tamaño de la imagen de entrada[224, 224, 3]
. - Introduzca los datos en el modelo de clasificador. Por defecto, los parámetros de entrenamiento tales como épocas de formación, tamaño del lote, la tasa de aprendizaje, el impulso son los valores por defecto de make_image_classifier_lib por TensorFlow concentradores. Solo se entrena la cabeza del clasificador.
En esta sección, describimos varios temas avanzados, incluido el cambio a un modelo de clasificación de imágenes diferente, el cambio de los hiperparámetros de entrenamiento, etc.
Personalice la cuantificación posterior al entrenamiento en el modelo TensorFLow Lite
Después de la formación de cuantificación es una técnica de conversión que puede reducir el tamaño del modelo y la latencia de la inferencia, además de mejorar la velocidad de la CPU y el acelerador de hardware inferencia, con un poco de degradación en la precisión del modelo. Por tanto, se utiliza mucho para optimizar el modelo.
La biblioteca Model Maker aplica una técnica de cuantificación posterior al entrenamiento predeterminada al exportar el modelo. Si desea personalizar la cuantificación posterior al entrenamiento, Modelo Maker soporta múltiples opciones después de la formación de cuantificación utilizando QuantizationConfig también. Tomemos como ejemplo la cuantificación de float16. Primero, defina la configuración de cuantificación.
config = QuantizationConfig.for_float16()
Luego exportamos el modelo de TensorFlow Lite con dicha configuración.
model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets INFO:tensorflow:Label file is inside the TFLite model with metadata. 2021-11-02 11:43:43.724165: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:43:43.724219: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite
En Colab, se puede descargar el modelo llamado model_fp16.tflite
de la barra lateral izquierda, al igual que la parte de la carga mencionada anteriormente.
Cambiar el modelo
Cambie al modelo compatible con esta biblioteca.
Esta biblioteca es compatible con los modelos EfficientNet-Lite, MobileNetV2, ResNet50 por ahora. EfficientNet-Lite son una familia de modelos de clasificación de imágenes que podrían lograr exactitud el estado de técnica y adecuados para los dispositivos de borde. El modelo predeterminado es EfficientNet-Lite0.
Podríamos cambiar el modelo de MobileNetV2 por simplemente ajustando el parámetro model_spec
a la especificación del modelo MobileNetV2 en create
método.
model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_2 (HubK (None, 1280) 2257984 _________________________________________________________________ dropout_2 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_2 (Dense) (None, 5) 6405 ================================================================= Total params: 2,264,389 Trainable params: 6,405 Non-trainable params: 2,257,984 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 91/91 [==============================] - 8s 53ms/step - loss: 0.9163 - accuracy: 0.7634 - val_loss: 0.7789 - val_accuracy: 0.8267 Epoch 2/5 91/91 [==============================] - 4s 50ms/step - loss: 0.6836 - accuracy: 0.8822 - val_loss: 0.7223 - val_accuracy: 0.8551 Epoch 3/5 91/91 [==============================] - 4s 50ms/step - loss: 0.6506 - accuracy: 0.9045 - val_loss: 0.7086 - val_accuracy: 0.8580 Epoch 4/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6218 - accuracy: 0.9227 - val_loss: 0.7049 - val_accuracy: 0.8636 Epoch 5/5 91/91 [==============================] - 5s 52ms/step - loss: 0.6092 - accuracy: 0.9279 - val_loss: 0.7181 - val_accuracy: 0.8580
Evalúe el modelo MobileNetV2 recientemente reentrenado para ver la precisión y la pérdida en los datos de prueba.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 35ms/step - loss: 0.6866 - accuracy: 0.8747
Cambiar al modelo en TensorFlow Hub
Además, también podríamos cambiar a otros modelos nuevos que ingresen una imagen y generen un vector de características con el formato TensorFlow Hub.
Como Inception V3 modelo como un ejemplo, podríamos definir inception_v3_spec
que es un objeto de image_classifier.ModelSpec y contiene la especificación del modelo Inception V3.
Tenemos que especificar el nombre del modelo name
, la dirección URL del modelo TensorFlow Hub uri
. Mientras tanto, el valor por defecto de input_image_shape
es [224, 224]
. Tenemos que cambiarlo a [299, 299]
para el modelo de entrada en vigor V3.
inception_v3_spec = image_classifier.ModelSpec(
uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
A continuación, mediante el ajuste de parámetros model_spec
a inception_v3_spec
en create
método, podríamos entrenar el modelo Inception V3.
Los pasos restantes son exactamente los mismos y al final podríamos obtener un modelo InceptionV3 TensorFlow Lite personalizado.
Cambia tu propio modelo personalizado
Si nos gustaría utilizar el modelo personalizado que no está en TensorFlow Hub, debemos crear y exportar ModelSpec en TensorFlow concentradores.
A continuación, empezar a definir ModelSpec
objeto como el proceso anterior.
Cambiar los hiperparámetros de entrenamiento
También podríamos cambiar los hiperparámetros formación como epochs
, dropout_rate
y batch_size
que podrían afectar a la precisión del modelo. Los parámetros del modelo que puede ajustar son:
-
epochs
: más épocas podrían lograr una mayor precisión hasta que converge pero el entrenamiento para demasiadas épocas puede dar lugar a un ajuste por exceso. -
dropout_rate
: La tasa de deserción, sobreajuste evitar. Ninguno por defecto. -
batch_size
: número de muestras a utilizar en un solo paso de formación. Ninguno por defecto. -
validation_data
datos de validación:. Si es Ninguno, omite el proceso de validación. Ninguno por defecto. -
train_whole_model
: Si es verdad, el módulo de concentradores se entrenó junto con la capa de clasificación en la parte superior. De lo contrario, entrene solo la capa de clasificación superior. Ninguno por defecto. -
learning_rate
: Base tasa de aprendizaje. Ninguno por defecto. -
momentum
: un flotador Python remitido al optimizador. Sólo se utiliza cuandouse_hub_library
es cierto. Ninguno por defecto. -
shuffle
: booleano, si los datos deben ser mezcladas. Falso por defecto. -
use_augmentation
: Boolean, aumento de uso de los datos para el preprocesamiento. Falso por defecto. -
use_hub_library
: Boolean, el usomake_image_classifier_lib
desde el cubo tensorflow volver a entrenar el modelo. Esta canalización de capacitación podría lograr un mejor rendimiento para conjuntos de datos complicados con muchas categorías. Verdadero por defecto. -
warmup_steps
: Número de pasos de calentamiento para la programación de calentamiento en la tasa de aprendizaje. Si es Ninguno, se usa el warmup_steps predeterminado, que es el total de pasos de entrenamiento en dos épocas. Sólo se utiliza cuandouse_hub_library
es falso. Ninguno por defecto. -
model_dir
: Opcional, la ubicación de los archivos de modelo de punto de control. Sólo se utiliza cuandouse_hub_library
es falso. Ninguno por defecto.
Los parámetros que hay ninguno por defecto como epochs
obtendrán los parámetros por defecto de hormigón en make_image_classifier_lib de la biblioteca TensorFlow Hub o train_image_classifier_lib .
Por ejemplo, podríamos entrenar con más épocas.
model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_3 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_3 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_3 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/10 91/91 [==============================] - 6s 53ms/step - loss: 0.8735 - accuracy: 0.7644 - val_loss: 0.6701 - val_accuracy: 0.8892 Epoch 2/10 91/91 [==============================] - 4s 49ms/step - loss: 0.6502 - accuracy: 0.8984 - val_loss: 0.6442 - val_accuracy: 0.8864 Epoch 3/10 91/91 [==============================] - 4s 49ms/step - loss: 0.6215 - accuracy: 0.9107 - val_loss: 0.6306 - val_accuracy: 0.8920 Epoch 4/10 91/91 [==============================] - 4s 49ms/step - loss: 0.5962 - accuracy: 0.9299 - val_loss: 0.6253 - val_accuracy: 0.8977 Epoch 5/10 91/91 [==============================] - 5s 52ms/step - loss: 0.5845 - accuracy: 0.9334 - val_loss: 0.6206 - val_accuracy: 0.9062 Epoch 6/10 91/91 [==============================] - 5s 50ms/step - loss: 0.5743 - accuracy: 0.9451 - val_loss: 0.6159 - val_accuracy: 0.9062 Epoch 7/10 91/91 [==============================] - 4s 48ms/step - loss: 0.5682 - accuracy: 0.9444 - val_loss: 0.6192 - val_accuracy: 0.9006 Epoch 8/10 91/91 [==============================] - 4s 49ms/step - loss: 0.5595 - accuracy: 0.9557 - val_loss: 0.6153 - val_accuracy: 0.9091 Epoch 9/10 91/91 [==============================] - 4s 47ms/step - loss: 0.5560 - accuracy: 0.9523 - val_loss: 0.6213 - val_accuracy: 0.9062 Epoch 10/10 91/91 [==============================] - 4s 45ms/step - loss: 0.5520 - accuracy: 0.9595 - val_loss: 0.6220 - val_accuracy: 0.8977
Evalúe el modelo recién reentrenado con 10 épocas de entrenamiento.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6417 - accuracy: 0.8883
Lee mas
Usted puede leer nuestra clasificación de imágenes ejemplo para aprender detalles técnicos. Para obtener más información, consulte:
- TensorFlow Lite Modelo Fabricante de guía y referencia de la API .
- Biblioteca de tareas: ImageClassifier para el despliegue.
- La referencia de extremo a extremo Aplicaciones: Android , iOS , y frambuesa PI .