Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Transferir el aprendizaje con TensorFlow Hub

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno Ver modelo TF Hub

TensorFlow Hub es un repositorio de modelos de TensorFlow entrenados previamente.

Este tutorial demuestra cómo:

  1. Usa modelos de TensorFlow Hub con tf.keras
  2. Usa un modelo de clasificación de imágenes de TensorFlow Hub
  3. Realice un aprendizaje de transferencia simple para ajustar un modelo para sus propias clases de imágenes

Preparar

import numpy as np
import time

import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

Un clasificador de ImageNet

Comenzará utilizando un modelo de clasificador previamente entrenado para tomar una imagen y predecir de qué es una imagen, ¡no se requiere capacitación!

Descarga el clasificador

Use hub.KerasLayer para cargar un modelo MobileNetV2 desde TensorFlow Hub. Cualquier modelo de clasificador de imágenes compatible de tfhub.dev funcionará aquí.

classifier_model ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4" 
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])

Ejecútelo en una sola imagen

Descarga una sola imagen para probar el modelo.

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step

png

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

Agregue una dimensión de lote y pase la imagen al modelo.

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

El resultado es un vector de logits de 1001 elementos, que califica la probabilidad de cada clase para la imagen.

Entonces, la ID de clase superior se puede encontrar con argmax:

predicted_class = np.argmax(result[0], axis=-1)
predicted_class
653

Decodificar las predicciones

Tome el ID de clase predicho y ImageNet etiquetas de ImageNet para decodificar las predicciones

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step

plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

png

Aprendizaje de transferencia simple

Pero, ¿qué sucede si desea entrenar un clasificador para un conjunto de datos con diferentes clases? También puede usar un modelo de TFHub para entrenar a un clasificador de imágenes personalizado volviendo a entrenar la capa superior del modelo para reconocer las clases en nuestro conjunto de datos.

Conjunto de datos

Para este ejemplo, usará el conjunto de datos de flores de TensorFlow:

data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 5s 0us/step

La forma más sencilla de cargar estos datos en nuestro modelo es usando tf.keras.preprocessing.image.ImageDataGenerator ,

Las convenciones de TensorFlow Hub para modelos de imágenes es esperar entradas flotantes en el rango [0, 1] . Utilice el parámetro de cambio de rescale ImageDataGenerator para lograr esto.

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)
Found 3670 images belonging to 5 classes.

El objeto resultante es un iterador que devuelve pares image_batch, label_batch .

for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break
Image batch shape:  (32, 224, 224, 3)
Label batch shape:  (32, 5)

Ejecute el clasificador en un lote de imágenes

Ahora ejecute el clasificador en el lote de imágenes.

result_batch = classifier.predict(image_batch)
result_batch.shape
(32, 1001)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'bee', 'barn spider', 'daisy', 'balloon', 'daisy',
       'cardoon', 'daisy', 'daisy', 'diaper', 'quill', 'wreck', 'hip',
       'daisy', 'vase', 'daisy', 'daisy', 'daisy', 'cardoon', 'daisy',
       'sea urchin', 'picket fence', 'daisy', 'strawberry',
       'coral fungus', 'picket fence', 'quill', 'daisy', 'daisy', 'pot',
       'sarong', 'hair slide'], dtype='<U30')

Ahora compruebe cómo estas predicciones se alinean con las imágenes:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

png

Consulte el archivo LICENSE.txt para conocer las atribuciones de imágenes.

Los resultados están lejos de ser perfectos, pero razonables considerando que estas no son las clases para las que se entrenó el modelo (excepto "margarita").

Descarga el modelo sin cabeza

TensorFlow Hub también distribuye modelos sin la capa de clasificación superior. Estos se pueden utilizar para transferir fácilmente el aprendizaje.

Cualquier modelo de vector de función de imagen compatible de tfhub.dev funcionará aquí.

feature_extractor_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4" 

Cree el extractor de características. Use trainable=False para congelar las variables en la capa del extractor de entidades, de modo que el entrenamiento solo modifique la nueva capa del clasificador.

feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model, input_shape=(224, 224, 3), trainable=False)

Devuelve un vector de 1280 de longitud para cada imagen:

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

Adjuntar un cabezal de clasificación

Ahora envuelva la capa central en un modelo tf.keras.Sequential y agregue una nueva capa de clasificación.

model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(image_data.num_classes)
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])

Entrena el modelo

Utilice compilar para configurar el proceso de entrenamiento:

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

Ahora use el método .fit para entrenar el modelo.

Para mantener este ejemplo, entrene solo 2 épocas. Para visualizar el progreso del entrenamiento, use una devolución de llamada personalizada para registrar la pérdida y la precisión de cada lote individualmente, en lugar del promedio de la época.

class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()
steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

batch_stats_callback = CollectBatchStats()

history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[batch_stats_callback])
Epoch 1/2
115/115 [==============================] - 12s 100ms/step - loss: 0.4156 - acc: 0.8750
Epoch 2/2
115/115 [==============================] - 12s 101ms/step - loss: 0.1960 - acc: 0.9375

Ahora, después de unas pocas iteraciones de entrenamiento, ya podemos ver que el modelo está progresando en la tarea.

plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7f3ac9d00b38>]

png

plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7f3ac49b3e48>]

png

Ver las predicciones

Para rehacer la trama de antes, primero obtenga la lista ordenada de nombres de clases:

class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
class_names
array(['Daisy', 'Dandelion', 'Roses', 'Sunflowers', 'Tulips'],
      dtype='<U10')

Ejecute el lote de imágenes a través del modelo y convierta los índices en nombres de clases.

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

Grafica el resultado

label_id = np.argmax(label_batch, axis=-1)
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")

png

Exporta tu modelo

Ahora que ha entrenado el modelo, expórtelo como modelo guardado para usarlo más adelante.

t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)

export_path
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/saved_models/1601602065/assets

INFO:tensorflow:Assets written to: /tmp/saved_models/1601602065/assets

'/tmp/saved_models/1601602065'

Ahora confirme que podemos volver a cargarlo, y todavía da los mismos resultados:

reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0

Este modelo guardado se puede cargar para inferencia posterior o convertir a TFLite o TFjs .

Aprende más

Consulte más tutoriales sobre el uso de modelos de imágenes de TensorFlow Hub.