Cette page a été traduite par l'API Cloud Translation.
Switch to English

Recyclage d'un classificateur d'images

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher sur GitHub Télécharger le carnet Voir le modèle TF Hub

introduction

Les modèles de classification d'images ont des millions de paramètres. Les former à partir de zéro nécessite beaucoup de données de formation étiquetées et une grande puissance de calcul. L'apprentissage par transfert est une technique qui raccourcit une grande partie de cela en prenant un morceau d'un modèle qui a déjà été formé sur une tâche associée et en le réutilisant dans un nouveau modèle.

Ce Colab montre comment créer un modèle Keras pour classer cinq espèces de fleurs en utilisant un TF2 SavedModel pré-formé de TensorFlow Hub pour l'extraction d'entités d'image, formé sur l'ensemble de données ImageNet beaucoup plus grand et plus général. En option, l'extracteur de caractéristiques peut être entraîné ("affiné") à côté du classificateur nouvellement ajouté.

Vous cherchez plutôt un outil?

Ceci est un didacticiel de codage TensorFlow. Si vous voulez un outil qui construit simplement le modèle TensorFlow ou TF Lite pour, jetez un œil à l'outil de ligne de commande make_image_classifier qui est installé par le package PIP tensorflow-hub[make_image_classifier] , ou à ce colab TF Lite.

Installer

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
TF version: 2.3.1
Hub version: 0.9.0
WARNING:tensorflow:From <ipython-input-1-0831fa394ed3>:12: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU is available

Sélectionnez le module TF2 SavedModel à utiliser

Pour commencer, utilisez https: //tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 . La même URL peut être utilisée dans le code pour identifier le SavedModel et dans votre navigateur pour afficher sa documentation. (Notez que les modèles au format TF1 Hub ne fonctionneront pas ici.)

Vous pouvez trouver plus de modèles TF2 qui génèrent des vecteurs de caractéristiques d'image ici .

module_selection = ("mobilenet_v2_100_224", 224) 
handle_base, pixels = module_selection
MODULE_HANDLE ="https://tfhub.dev/google/imagenet/{}/feature_vector/4".format(handle_base)
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

BATCH_SIZE = 32 
Using https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 with input size (224, 224)

Configurer le jeu de données Flowers

Les entrées sont redimensionnées de manière appropriée pour le module sélectionné. L'augmentation de l'ensemble de données (c'est-à-dire les distorsions aléatoires d'une image à chaque fois qu'elle est lue) améliore la formation, en particulier. lors du réglage fin.

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 3s 0us/step

datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
                   interpolation="bilinear")

valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    **datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset="validation", shuffle=False, **dataflow_kwargs)

do_data_augmentation = False 
if do_data_augmentation:
  train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=40,
      horizontal_flip=True,
      width_shift_range=0.2, height_shift_range=0.2,
      shear_range=0.2, zoom_range=0.2,
      **datagen_kwargs)
else:
  train_datagen = valid_datagen
train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
Found 731 images belonging to 5 classes.
Found 2939 images belonging to 5 classes.

Définition du modèle

Il suffit de mettre un classificateur linéaire au-dessus de feature_extractor_layer avec le module Hub.

Pour la vitesse, nous commençons avec un feature_extractor_layer non entraînable, mais vous pouvez également activer le réglage fin pour une plus grande précision.

do_fine_tuning = False 
print("Building model with", MODULE_HANDLE)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
Building model with https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

Entraîner le modèle

model.compile(
  optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
    train_generator,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps).history
Epoch 1/5
91/91 [==============================] - 16s 171ms/step - loss: 0.9733 - accuracy: 0.7286 - val_loss: 0.7463 - val_accuracy: 0.8395
Epoch 2/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6944 - accuracy: 0.8762 - val_loss: 0.7031 - val_accuracy: 0.8665
Epoch 3/5
91/91 [==============================] - 14s 159ms/step - loss: 0.6602 - accuracy: 0.8934 - val_loss: 0.7243 - val_accuracy: 0.8580
Epoch 4/5
91/91 [==============================] - 15s 162ms/step - loss: 0.6276 - accuracy: 0.9143 - val_loss: 0.7081 - val_accuracy: 0.8665
Epoch 5/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6115 - accuracy: 0.9254 - val_loss: 0.6772 - val_accuracy: 0.8864

plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
[<matplotlib.lines.Line2D at 0x7fb5f7ee4978>]

png

png

Essayez le modèle sur une image à partir des données de validation:

def get_class_string_from_index(index):
   for class_string, class_index in valid_generator.class_indices.items():
      if class_index == index:
         return class_string

x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))

png

True label: daisy
Predicted label: daisy

Enfin, le modèle entraîné peut être enregistré pour un déploiement sur TF Serving ou TF Lite (sur mobile) comme suit.

saved_model_path = "/tmp/saved_flowers_model"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

Facultatif: déploiement sur TensorFlow Lite

TensorFlow Lite vous permet de déployer des modèles TensorFlow sur des appareils mobiles et IoT. Le code ci-dessous montre comment convertir le modèle entraîné en TF Lite et appliquer les outils de post-formation du kit d'outils d'optimisation de modèle TensorFlow . Enfin, il l'exécute dans l'interpréteur TF Lite pour examiner la qualité résultante

  • La conversion sans optimisation fournit les mêmes résultats qu'auparavant (jusqu'à l'erreur d'arrondi).
  • La conversion avec optimisation sans aucune donnée quantifie les poids du modèle à 8 bits, mais l'inférence utilise toujours le calcul en virgule flottante pour les activations du réseau neuronal. Cela réduit la taille du modèle presque par un facteur de 4 et améliore la latence du processeur sur les appareils mobiles.
  • De plus, le calcul des activations du réseau neuronal peut également être quantifié en nombres entiers de 8 bits si un petit ensemble de données de référence est fourni pour calibrer la plage de quantification. Sur un appareil mobile, cela accélère davantage l'inférence et permet de fonctionner sur des accélérateurs comme EdgeTPU.

# TODO(b/156102192)
optimize_lite_model = False  

num_calibration_examples = 60  
representative_dataset = None
if optimize_lite_model and num_calibration_examples:
  # Use a bounded number of training examples without labels for calibration.
  # TFLiteConverter expects a list of input tensors, each with batch size 1.
  representative_dataset = lambda: itertools.islice(
      ([image[None, ...]] for batch, _ in train_generator for image in batch),
      num_calibration_examples)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
if optimize_lite_model:
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  if representative_dataset:  # This is optional, see above.
    converter.representative_dataset = representative_dataset
lite_model_content = converter.convert()

with open("/tmp/lite_flowers_model", "wb") as f:
  f.write(lite_model_content)
print("Wrote %sTFLite model of %d bytes." %
      ("optimized " if optimize_lite_model else "", len(lite_model_content)))
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TF Lite interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])

num_eval_examples = 50  
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_generator
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TF Lite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TF Lite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))