Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Reentrenamiento de un clasificador de imágenes

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno Ver modelo TF Hub

Introducción

Los modelos de clasificación de imágenes tienen millones de parámetros. Entrenarlos desde cero requiere una gran cantidad de datos de entrenamiento etiquetados y mucha potencia informática. El aprendizaje por transferencia es una técnica que ataja gran parte de esto tomando una parte de un modelo que ya ha sido entrenado en una tarea relacionada y reutilizándolo en un nuevo modelo.

Este Colab demuestra cómo construir un modelo de Keras para clasificar cinco especies de flores mediante el uso de un modelo guardado TF2 previamente entrenado de TensorFlow Hub para la extracción de características de imagen, entrenado en el conjunto de datos ImageNet mucho más grande y general. Opcionalmente, el extractor de características se puede entrenar ("ajustar") junto con el clasificador recién agregado.

¿Está buscando una herramienta en su lugar?

Este es un tutorial de codificación de TensorFlow. Si desea una herramienta que solo cree el modelo de TensorFlow o TF Lite, eche un vistazo a la herramienta de línea de comandos make_image_classifier que se instala mediante el paquete PIP tensorflow-hub[make_image_classifier] , o en este tensorflow-hub[make_image_classifier] TF Lite.

Preparar

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
TF version: 2.3.1
Hub version: 0.9.0
WARNING:tensorflow:From <ipython-input-1-0831fa394ed3>:12: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU is available

Seleccione el módulo TF2 SavedModel para usar

Para empezar, use https: //tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 . La misma URL se puede utilizar en el código para identificar el modelo guardado y en su navegador para mostrar su documentación. (Tenga en cuenta que los modelos en formato TF1 Hub no funcionarán aquí).

Puede encontrar más modelos TF2 que generan vectores de características de imagen aquí .

module_selection = ("mobilenet_v2_100_224", 224) 
handle_base, pixels = module_selection
MODULE_HANDLE ="https://tfhub.dev/google/imagenet/{}/feature_vector/4".format(handle_base)
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

BATCH_SIZE = 32 
Using https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 with input size (224, 224)

Configurar el conjunto de datos de Flowers

Las entradas se redimensionan adecuadamente para el módulo seleccionado. El aumento del conjunto de datos (es decir, distorsiones aleatorias de una imagen cada vez que se lee) mejora el entrenamiento, especialmente. al realizar ajustes finos.

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 3s 0us/step

datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
                   interpolation="bilinear")

valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    **datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset="validation", shuffle=False, **dataflow_kwargs)

do_data_augmentation = False 
if do_data_augmentation:
  train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=40,
      horizontal_flip=True,
      width_shift_range=0.2, height_shift_range=0.2,
      shear_range=0.2, zoom_range=0.2,
      **datagen_kwargs)
else:
  train_datagen = valid_datagen
train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
Found 731 images belonging to 5 classes.
Found 2939 images belonging to 5 classes.

Definiendo el modelo

Todo lo que se necesita es colocar un clasificador lineal encima de feature_extractor_layer con el módulo Hub.

Para la velocidad, comenzamos con un feature_extractor_layer no entrenable, pero también puede habilitar el ajuste fino para una mayor precisión.

do_fine_tuning = False 
print("Building model with", MODULE_HANDLE)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
Building model with https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

Entrenando el modelo

model.compile(
  optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
    train_generator,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps).history
Epoch 1/5
91/91 [==============================] - 16s 171ms/step - loss: 0.9733 - accuracy: 0.7286 - val_loss: 0.7463 - val_accuracy: 0.8395
Epoch 2/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6944 - accuracy: 0.8762 - val_loss: 0.7031 - val_accuracy: 0.8665
Epoch 3/5
91/91 [==============================] - 14s 159ms/step - loss: 0.6602 - accuracy: 0.8934 - val_loss: 0.7243 - val_accuracy: 0.8580
Epoch 4/5
91/91 [==============================] - 15s 162ms/step - loss: 0.6276 - accuracy: 0.9143 - val_loss: 0.7081 - val_accuracy: 0.8665
Epoch 5/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6115 - accuracy: 0.9254 - val_loss: 0.6772 - val_accuracy: 0.8864

plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
[<matplotlib.lines.Line2D at 0x7fb5f7ee4978>]

png

png

Pruebe el modelo en una imagen de los datos de validación:

def get_class_string_from_index(index):
   for class_string, class_index in valid_generator.class_indices.items():
      if class_index == index:
         return class_string

x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))

png

True label: daisy
Predicted label: daisy

Finalmente, el modelo entrenado se puede guardar para implementarlo en TF Serving o TF Lite (en dispositivos móviles) de la siguiente manera.

saved_model_path = "/tmp/saved_flowers_model"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

Opcional: implementación en TensorFlow Lite

TensorFlow Lite te permite implementar modelos de TensorFlow en dispositivos móviles y de IoT. El siguiente código muestra cómo convertir el modelo entrenado a TF Lite y aplicar las herramientas posteriores al entrenamiento del kit de herramientas de optimización de modelos de TensorFlow . Finalmente, lo ejecuta en TF Lite Interpreter para examinar la calidad resultante

  • La conversión sin optimización proporciona los mismos resultados que antes (hasta el error de redondeo).
  • La conversión con optimización sin ningún dato cuantifica los pesos del modelo a 8 bits, pero la inferencia aún utiliza el cálculo de punto flotante para las activaciones de la red neuronal. Esto reduce el tamaño del modelo casi en un factor de 4 y mejora la latencia de la CPU en los dispositivos móviles.
  • Además, el cálculo de las activaciones de la red neuronal también se puede cuantificar en números enteros de 8 bits si se proporciona un pequeño conjunto de datos de referencia para calibrar el rango de cuantificación. En un dispositivo móvil, esto acelera aún más la inferencia y hace posible la ejecución en aceleradores como EdgeTPU.

# TODO(b/156102192)
optimize_lite_model = False  

num_calibration_examples = 60  
representative_dataset = None
if optimize_lite_model and num_calibration_examples:
  # Use a bounded number of training examples without labels for calibration.
  # TFLiteConverter expects a list of input tensors, each with batch size 1.
  representative_dataset = lambda: itertools.islice(
      ([image[None, ...]] for batch, _ in train_generator for image in batch),
      num_calibration_examples)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
if optimize_lite_model:
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  if representative_dataset:  # This is optional, see above.
    converter.representative_dataset = representative_dataset
lite_model_content = converter.convert()

with open("/tmp/lite_flowers_model", "wb") as f:
  f.write(lite_model_content)
print("Wrote %sTFLite model of %d bytes." %
      ("optimized " if optimize_lite_model else "", len(lite_model_content)))
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TF Lite interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])

num_eval_examples = 50  
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_generator
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TF Lite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TF Lite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))