Ajuda a proteger a Grande Barreira de Corais com TensorFlow em Kaggle Junte Desafio

Classificação de imagem com TensorFlow Lite Model Maker

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno Veja o modelo TF Hub

A biblioteca Criador TensorFlow Lite Modelo simplifica o processo de adaptação e conversão de um modelo neural-network TensorFlow a determinados dados de entrada ao implantar este modelo para aplicações ML no dispositivo.

Este bloco de notas mostra um exemplo de ponta a ponta que utiliza esta biblioteca Model Maker para ilustrar a adaptação e conversão de um modelo de classificação de imagem comumente usado para classificar flores em um dispositivo móvel.

Pré-requisitos

Para executar esse exemplo, primeiro é necessário instalar vários pacotes necessários, incluindo pacote Fabricante Modelo que no GitHub repo .

pip install -q tflite-model-maker

Importe os pacotes necessários.

import os

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

import matplotlib.pyplot as plt
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)

Exemplo Simples de Ponta a Ponta

Obtenha o caminho dos dados

Vamos pegar algumas imagens para brincar com este exemplo simples de ponta a ponta. Centenas de imagens é um bom começo para o Model Maker, enquanto mais dados podem alcançar melhor precisão.

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step
228827136/228813984 [==============================] - 1s 0us/step

Você poderia substituir image_path com suas próprias pastas de imagem. Quanto ao upload de dados para o colab, você pode encontrar o botão de upload na barra lateral esquerda mostrado na imagem abaixo com o retângulo vermelho. Tente fazer o upload de um arquivo zip e descompacte-o. O caminho do arquivo raiz é o caminho atual.

Subir arquivo

Se você preferir não fazer o upload de suas imagens para a nuvem, você pode tentar executar a biblioteca local seguindo o guia no GitHub.

Execute o exemplo

O exemplo consiste apenas em 4 linhas de código conforme mostrado abaixo, cada uma representando uma etapa do processo geral.

Etapa 1. Carregar dados de entrada específicos para um aplicativo ML no dispositivo. Divida-o em dados de treinamento e dados de teste.

data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

Etapa 2. Personalize o modelo do TensorFlow.

model = image_classifier.create(train_data)
INFO:tensorflow:Retraining the models...
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2 (HubKer (None, 1280)              3413024   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
103/103 [==============================] - 7s 35ms/step - loss: 0.8551 - accuracy: 0.7718
Epoch 2/5
103/103 [==============================] - 4s 35ms/step - loss: 0.6503 - accuracy: 0.8956
Epoch 3/5
103/103 [==============================] - 4s 34ms/step - loss: 0.6157 - accuracy: 0.9196
Epoch 4/5
103/103 [==============================] - 3s 33ms/step - loss: 0.6036 - accuracy: 0.9293
Epoch 5/5
103/103 [==============================] - 4s 34ms/step - loss: 0.5929 - accuracy: 0.9317

Etapa 3. Avalie o modelo.

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 40ms/step - loss: 0.6282 - accuracy: 0.9019

Etapa 4. Exportar para o modelo TensorFlow Lite.

Aqui, nós exportamos modelo TensorFlow Lite com metadados que fornece um padrão para descrição de modelo. O arquivo de etiqueta é embutido em metadados. A técnica de quantização pós-treinamento padrão é a quantização de número inteiro completo para a tarefa de classificação de imagens.

Você pode baixá-lo na barra lateral esquerda da mesma forma que a parte de upload para seu próprio uso.

model.export(export_dir='.')
2021-11-02 11:34:05.568024: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets
INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets
2021-11-02 11:34:09.488041: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 11:34:09.488090: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3
WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

Após estas simples 4 passos, poderíamos usar mais arquivo de modelo TensorFlow Lite em aplicações no dispositivo como na imagem classificação aplicativo de referência.

Processo Detalhado

Atualmente, oferecemos suporte a vários modelos como os modelos EfficientNet-Lite *, MobileNetV2, ResNet50 como modelos pré-treinados para classificação de imagens. Mas é muito flexível adicionar novos modelos pré-treinados a esta biblioteca com apenas algumas linhas de código.

A seguir, apresentamos este exemplo de ponta a ponta, passo a passo, para mostrar mais detalhes.

Etapa 1: carregar dados de entrada específicos para um aplicativo de ML no dispositivo

O conjunto de dados de flores contém 3670 imagens pertencentes a 5 classes. Baixe a versão do arquivo do conjunto de dados e descompacte-o.

O conjunto de dados tem a seguinte estrutura de diretório:

flower_photos
|__ daisy
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ dandelion
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ roses
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ sunflowers
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ tulips
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...
image_path = tf.keras.utils.get_file(
      'flower_photos.tgz',
      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

Use DataLoader classe para carregar dados.

Quanto from_folder() método, pode carregar os dados a partir da pasta. Ele assume que os dados de imagem da mesma classe estão no mesmo subdiretório e o nome da subpasta é o nome da classe. Atualmente, imagens codificadas em JPEG e imagens codificadas em PNG são suportadas.

data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

Divida em dados de treinamento (80%), dados de validação (10%, opcional) e dados de teste (10%).

train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

Mostra 25 exemplos de imagem com rótulos.

plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)
  plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

png

Etapa 2: personalizar o modelo do TensorFlow

Crie um modelo de classificador de imagem personalizado com base nos dados carregados. O modelo padrão é EfficientNet-Lite0.

model = image_classifier.create(train_data, validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_1 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
91/91 [==============================] - 6s 54ms/step - loss: 0.8689 - accuracy: 0.7655 - val_loss: 0.6941 - val_accuracy: 0.8835
Epoch 2/5
91/91 [==============================] - 5s 50ms/step - loss: 0.6596 - accuracy: 0.8949 - val_loss: 0.6668 - val_accuracy: 0.8807
Epoch 3/5
91/91 [==============================] - 5s 50ms/step - loss: 0.6188 - accuracy: 0.9159 - val_loss: 0.6537 - val_accuracy: 0.8807
Epoch 4/5
91/91 [==============================] - 5s 52ms/step - loss: 0.6050 - accuracy: 0.9210 - val_loss: 0.6432 - val_accuracy: 0.8892
Epoch 5/5
91/91 [==============================] - 5s 52ms/step - loss: 0.5898 - accuracy: 0.9348 - val_loss: 0.6348 - val_accuracy: 0.8864

Dê uma olhada na estrutura detalhada do modelo.

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_1 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________

Etapa 3: Avalie o modelo personalizado

Avalie o resultado do modelo, obtenha a perda e a precisão do modelo.

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6324 - accuracy: 0.8965

Podemos traçar os resultados previstos em 100 imagens de teste. Rótulos previstos com a cor vermelha são os resultados previstos incorretos, enquanto outros estão corretos.

# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
  if val1 == val2:
    return 'black'
  else:
    return 'red'

# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
  ax = plt.subplot(10, 10, i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)

  predict_label = predicts[i][0][0]
  color = get_label_color(predict_label,
                          test_data.index_to_label[label.numpy()])
  ax.xaxis.label.set_color(color)
  plt.xlabel('Predicted: %s' % predict_label)
plt.show()

png

Se a precisão não cumprir a exigência de aplicação, pode-se referir a Uso Avançado para explorar alternativas, tais como a mudança para um modelo maior, ajustando os parâmetros re-treinamento etc.

Etapa 4: exportar para o modelo TensorFlow Lite

Converter o modelo treinado para o formato modelo TensorFlow Lite com metadados de modo que você pode usar mais tarde em um aplicativo ML no dispositivo. O arquivo de etiqueta e o arquivo de vocabulário são incorporados aos metadados. O nome do arquivo TFLite padrão é model.tflite .

Em muitos aplicativos de ML no dispositivo, o tamanho do modelo é um fator importante. Portanto, é recomendado que você aplique quantizar o modelo para torná-lo menor e potencialmente rodar mais rápido. A técnica de quantização pós-treinamento padrão é a quantização de número inteiro completo para a tarefa de classificação de imagens.

model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets
INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets
2021-11-02 11:35:40.254046: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 11:35:40.254099: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3
WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

Ver aplicações e guias de classificação de imagens de exemplo para obter mais detalhes sobre como integrar o modelo TensorFlow Lite em aplicativos móveis.

Este modelo pode ser integrado em um Android ou um app iOS usando o API ImageClassifier da Biblioteca de Tarefas Lite TensorFlow .

Os formatos de exportação permitidos podem ser um ou uma lista dos seguintes:

Por padrão, ele apenas exporta o modelo TensorFlow Lite com metadados. Você também pode exportar arquivos diferentes seletivamente. Por exemplo, exportando apenas o arquivo de etiqueta da seguinte forma:

model.export(export_dir='.', export_format=ExportFormat.LABEL)
INFO:tensorflow:Saving labels in ./labels.txt
INFO:tensorflow:Saving labels in ./labels.txt

Você também pode avaliar o modelo tflite com o evaluate_tflite método.

model.evaluate_tflite('model.tflite', test_data)
{'accuracy': 0.9019073569482289}

Uso Avançado

A create função é a parte crítica desta biblioteca. Ele usa o aprendizado de transferência com um modelo pré-treinado semelhante ao tutorial .

A create função contém os seguintes passos:

  1. Dividir os dados em formação, validação, testes de dados de acordo com o parâmetro validation_ratio e test_ratio . O valor padrão de validation_ratio e test_ratio são 0.1 e 0.1 .
  2. Baixar um vetor de imagens características como o modelo base TensorFlow Hub. O modelo pré-treinado padrão é EfficientNet-Lite0.
  3. Adicionar uma cabeça classificador com uma camada Dropout com dropout_rate entre a camada de cabeça e modelo pré-treinados. O padrão dropout_rate é o padrão dropout_rate valor de make_image_classifier_lib por TensorFlow Hub.
  4. Pré-processe os dados de entrada brutos. Atualmente, as etapas de pré-processamento incluem normalizar o valor de cada pixel da imagem para modelar a escala de entrada e redimensioná-la para modelar o tamanho de entrada. EfficientNet-Lite0 tem a escala de entrada [0, 1] e o tamanho da imagem de entrada [224, 224, 3] .
  5. Alimente os dados no modelo do classificador. Por padrão, os parâmetros de formação, tais como épocas de treinamento, tamanho do lote, taxa de aprendizagem, momento são os valores padrão de make_image_classifier_lib por TensorFlow Hub. Apenas o cabeçote do classificador é treinado.

Nesta seção, descrevemos vários tópicos avançados, incluindo alternar para um modelo de classificação de imagem diferente, alterar os hiperparâmetros de treinamento etc.

Personalize a quantização pós-treinamento no modelo TensorFLow Lite

Quantização pós-treino é uma técnica de conversão que pode reduzir o tamanho do modelo e latência inferência, ao mesmo tempo melhorar a velocidade da CPU e acelerador de hardware inferência, com um pouco de degradação na precisão do modelo. Portanto, é amplamente utilizado para otimizar o modelo.

A biblioteca do Model Maker aplica uma técnica de quantização pós-treinamento padrão ao exportar o modelo. Se você quiser personalizar quantização pós-treino, Model Maker suporta múltiplas opções de pós-formação de quantização usando QuantizationConfig também. Vamos tomar a quantização float16 como uma instância. Primeiro, defina a configuração de quantização.

config = QuantizationConfig.for_float16()

Em seguida, exportamos o modelo TensorFlow Lite com essa configuração.

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets
INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets
INFO:tensorflow:Label file is inside the TFLite model with metadata.
2021-11-02 11:43:43.724165: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 11:43:43.724219: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite

Em Colab, você pode baixar o modelo chamado model_fp16.tflite da barra lateral esquerda, mesmo que a parte upload mencionado acima.

Mudar o modelo

Mude para o modelo que é compatível com esta biblioteca.

Esta biblioteca é compatível com os modelos EfficientNet-Lite, MobileNetV2, ResNet50 agora. EfficientNet-Lite são uma família de modelos de classificação de imagem que pode atingir precisão o estado-da-arte e adequado para dispositivos de borda. O modelo padrão é EfficientNet-Lite0.

Poderíamos mudar modelo para MobileNetV2 por apenas ajustando o parâmetro model_spec à especificação do modelo MobileNetV2 em create método.

model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_2 (HubK (None, 1280)              2257984   
_________________________________________________________________
dropout_2 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
91/91 [==============================] - 8s 53ms/step - loss: 0.9163 - accuracy: 0.7634 - val_loss: 0.7789 - val_accuracy: 0.8267
Epoch 2/5
91/91 [==============================] - 4s 50ms/step - loss: 0.6836 - accuracy: 0.8822 - val_loss: 0.7223 - val_accuracy: 0.8551
Epoch 3/5
91/91 [==============================] - 4s 50ms/step - loss: 0.6506 - accuracy: 0.9045 - val_loss: 0.7086 - val_accuracy: 0.8580
Epoch 4/5
91/91 [==============================] - 5s 50ms/step - loss: 0.6218 - accuracy: 0.9227 - val_loss: 0.7049 - val_accuracy: 0.8636
Epoch 5/5
91/91 [==============================] - 5s 52ms/step - loss: 0.6092 - accuracy: 0.9279 - val_loss: 0.7181 - val_accuracy: 0.8580

Avalie o modelo MobileNetV2 recém-treinado para ver a precisão e a perda nos dados de teste.

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 35ms/step - loss: 0.6866 - accuracy: 0.8747

Mudança para o modelo no TensorFlow Hub

Além disso, também poderíamos mudar para outros novos modelos que inserem uma imagem e geram um vetor de recursos com o formato TensorFlow Hub.

Como Iniciação V3 modelo como um exemplo, podemos definir inception_v3_spec que é um objecto do image_classifier.ModelSpec e contém a especificação do modelo de Iniciação V3.

Precisamos especificar o nome do modelo name , o URL do modelo TensorFlow Hub uri . Enquanto isso, o valor padrão de input_image_shape é [224, 224] . Precisamos mudá-lo para [299, 299] para o modelo Inception V3.

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]

Então, por parametrização model_spec para inception_v3_spec em create método, poderíamos voltar a treinar o modelo Inception V3.

As etapas restantes são exatamente as mesmas e poderíamos obter um modelo personalizado do InceptionV3 TensorFlow Lite no final.

Mude seu próprio modelo personalizado

Se nós gostaríamos de usar o modelo personalizado que não está em TensorFlow Hub, devemos criar e exportar ModelSpec em TensorFlow Hub.

Em seguida, começa a definir ModelSpec objecto parecido com o processo acima.

Alterar os hiperparâmetros de treinamento

Nós também poderia mudar os hiperparâmetros de treinamento como epochs , dropout_rate e batch_size que poderiam afetar a precisão do modelo. Os parâmetros do modelo que você pode ajustar são:

  • epochs : mais épocas poderia conseguir uma melhor precisão até que converge mas o treinamento para muitas épocas pode levar a overfitting.
  • dropout_rate : A taxa de abandono, overfitting evitar. Nenhum por padrão.
  • batch_size : número de amostras a utilizar em uma etapa de treinamento. Nenhum por padrão.
  • validation_data Validação de dados:. Se nenhum, ignora o processo de validação. Nenhum por padrão.
  • train_whole_model : Se for verdade, o módulo de Hub é treinado em conjunto com a camada de classificação no topo. Caso contrário, treine apenas a camada de classificação superior. Nenhum por padrão.
  • learning_rate : taxa de aprendizagem Base. Nenhum por padrão.
  • momentum : uma bóia Python encaminhado para o otimizador. Apenas usada quando use_hub_library é True. Nenhum por padrão.
  • shuffle : booleano, se os dados devem ser embaralhadas. False por padrão.
  • use_augmentation : Boolean, o aumento do uso de dados para pré-processamento. False por padrão.
  • use_hub_library : booleana, uso make_image_classifier_lib de cubo tensorflow para reconverter o modelo. Este pipeline de treinamento pode alcançar melhor desempenho para conjuntos de dados complicados com muitas categorias. Verdadeiro por padrão.
  • warmup_steps : Número de passos de aquecimento para programar configurações de aquecimento na taxa de aprendizagem. Se Nenhum, o warmup_steps padrão é usado, que é o total de etapas de treinamento em duas épocas. Apenas usada quando use_hub_library é False. Nenhum por padrão.
  • model_dir : Opcional, a localização dos arquivos de modelo de ponto de verificação. Apenas usada quando use_hub_library é False. Nenhum por padrão.

Parâmetros que são None por padrão como epochs irá obter os parâmetros padrão de concreto em make_image_classifier_lib da biblioteca TensorFlow Hub ou train_image_classifier_lib .

Por exemplo, podemos treinar com mais épocas.

model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_3 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_3 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/10
91/91 [==============================] - 6s 53ms/step - loss: 0.8735 - accuracy: 0.7644 - val_loss: 0.6701 - val_accuracy: 0.8892
Epoch 2/10
91/91 [==============================] - 4s 49ms/step - loss: 0.6502 - accuracy: 0.8984 - val_loss: 0.6442 - val_accuracy: 0.8864
Epoch 3/10
91/91 [==============================] - 4s 49ms/step - loss: 0.6215 - accuracy: 0.9107 - val_loss: 0.6306 - val_accuracy: 0.8920
Epoch 4/10
91/91 [==============================] - 4s 49ms/step - loss: 0.5962 - accuracy: 0.9299 - val_loss: 0.6253 - val_accuracy: 0.8977
Epoch 5/10
91/91 [==============================] - 5s 52ms/step - loss: 0.5845 - accuracy: 0.9334 - val_loss: 0.6206 - val_accuracy: 0.9062
Epoch 6/10
91/91 [==============================] - 5s 50ms/step - loss: 0.5743 - accuracy: 0.9451 - val_loss: 0.6159 - val_accuracy: 0.9062
Epoch 7/10
91/91 [==============================] - 4s 48ms/step - loss: 0.5682 - accuracy: 0.9444 - val_loss: 0.6192 - val_accuracy: 0.9006
Epoch 8/10
91/91 [==============================] - 4s 49ms/step - loss: 0.5595 - accuracy: 0.9557 - val_loss: 0.6153 - val_accuracy: 0.9091
Epoch 9/10
91/91 [==============================] - 4s 47ms/step - loss: 0.5560 - accuracy: 0.9523 - val_loss: 0.6213 - val_accuracy: 0.9062
Epoch 10/10
91/91 [==============================] - 4s 45ms/step - loss: 0.5520 - accuracy: 0.9595 - val_loss: 0.6220 - val_accuracy: 0.8977

Avalie o modelo recém-treinado com 10 épocas de treinamento.

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6417 - accuracy: 0.8883

Consulte Mais informação

Você pode ler a nossa imagem de classificação de exemplo para aprender detalhes técnicos. Para obter mais informações, consulte: