דף זה תורגם על ידי Cloud Translation API.
Switch to English

הכשרה מחודשת של מסווג תמונות

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה ב- GitHub הורד מחברת ראה מודל TF Hub

מבוא

למודלים של סיווג תמונות יש מיליוני פרמטרים. אימון אותם מאפס דורש הרבה נתוני אימונים שכותרתו והרבה כוח מחשוב. למידת העברה היא טכניקה שקוצרת דרך הרבה מכך על ידי לקיחת פיסת מודל שכבר הוכשרה למשימה קשורה ושימוש חוזר במודל חדש.

Colab זה מדגים כיצד לבנות מודל Keras למיון חמישה מינים של פרחים על ידי שימוש ב- TF2 SavedModel שהוכשר מראש מ- TensorFlow Hub לצורך מיצוי תכונות תמונה, מאומן במערך ImageNet הכללי הרבה יותר וכללי. לחלופין, ניתן להכשיר את מחלץ התכונות ("מכוונן") לצד המסווג החדש.

מחפש כלי במקום זאת?

זהו מדריך קידוד של TensorFlow. אם אתה רוצה כלי פשוט בונה את מודל לייט TensorFlow או TF עבור, תסתכל על make_image_classifier כלי שורת הפקודה שמקבל מותקן על ידי חבילת PIP tensorflow-hub[make_image_classifier] , או על זה colab TF לייט.

להכין

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
TF version: 2.3.1
Hub version: 0.10.0
WARNING:tensorflow:From <ipython-input-1-0831fa394ed3>:12: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU is available

בחר במודול TF2 SavedModel לשימוש

בתור התחלה, השתמש ב- https: //tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 . ניתן להשתמש באותה URL בקוד לזיהוי ה- SavedModel ובדפדפן שלך כדי להציג את התיעוד שלו. (שים לב כי דגמים בפורמט TF1 Hub לא יפעלו כאן.)

תוכלו למצוא דגמי TF2 נוספים המייצרים וקטורי תכונות תמונה כאן .

module_selection = ("mobilenet_v2_100_224", 224)
handle_base, pixels = module_selection
MODULE_HANDLE ="https://tfhub.dev/google/imagenet/{}/feature_vector/4".format(handle_base)
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

BATCH_SIZE = 32
Using https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 with input size (224, 224)

הגדר את מערך הנתונים של פרחים

גודל הקלטים מתאים למודול שנבחר. הגדלת מערך נתונים (כלומר, עיוותים אקראיים של תמונה בכל קריאה) משפרת את האימון, במיוחד. בעת כוונון עדין.

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step

datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
                   interpolation="bilinear")

valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    **datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset="validation", shuffle=False, **dataflow_kwargs)

do_data_augmentation = False
if do_data_augmentation:
  train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=40,
      horizontal_flip=True,
      width_shift_range=0.2, height_shift_range=0.2,
      shear_range=0.2, zoom_range=0.2,
      **datagen_kwargs)
else:
  train_datagen = valid_datagen
train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
Found 731 images belonging to 5 classes.
Found 2939 images belonging to 5 classes.

הגדרת המודל

כל מה שנדרש הוא לשים מסווג ליניארי על גבי feature_extractor_layer עם מודול הרכזת.

למהירות, אנו מתחילים עם feature_extractor_layer שאינה feature_extractor_layer לאימון, אך ניתן גם לאפשר כיוונון עדין לדיוק רב יותר.

do_fine_tuning = False
print("Building model with", MODULE_HANDLE)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
Building model with https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

הכשרת המודל

model.compile(
  optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
    train_generator,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps).history
Epoch 1/5
91/91 [==============================] - 14s 151ms/step - loss: 0.9491 - accuracy: 0.7372 - val_loss: 0.7738 - val_accuracy: 0.8239
Epoch 2/5
91/91 [==============================] - 13s 143ms/step - loss: 0.7182 - accuracy: 0.8662 - val_loss: 0.7188 - val_accuracy: 0.8551
Epoch 3/5
91/91 [==============================] - 13s 142ms/step - loss: 0.6436 - accuracy: 0.8985 - val_loss: 0.6873 - val_accuracy: 0.8679
Epoch 4/5
91/91 [==============================] - 13s 141ms/step - loss: 0.6175 - accuracy: 0.9209 - val_loss: 0.6964 - val_accuracy: 0.8594
Epoch 5/5
91/91 [==============================] - 13s 142ms/step - loss: 0.6252 - accuracy: 0.9116 - val_loss: 0.6893 - val_accuracy: 0.8693

plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
[<matplotlib.lines.Line2D at 0x7f6cd967ce10>]

png

png

נסה את המודל על תמונה מנתוני האימות:

def get_class_string_from_index(index):
   for class_string, class_index in valid_generator.class_indices.items():
      if class_index == index:
         return class_string

x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))

png

True label: daisy
Predicted label: daisy

לבסוף, ניתן לשמור את המודל המאומן לפריסה ל- TF Serving או ל- TF Lite (בנייד) כדלקמן.

saved_model_path = "/tmp/saved_flowers_model"
tf.saved_model.save(model, saved_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

אופציונלי: פריסה ל- TensorFlow Lite

TensorFlow Lite מאפשר לך לפרוס דגמי TensorFlow למכשירים ניידים ו- IoT. הקוד שלהלן מראה כיצד להמיר את המודל המאומן ל- TF Lite ולהחיל כלים לאחר אימון מתוך ערכת הכלים למיטוב המודל TensorFlow . לבסוף, הוא מריץ אותו במתורגמן TF Lite כדי לבחון את האיכות המתקבלת

  • המרה ללא אופטימיזציה מספקת את אותן התוצאות כמו בעבר (עד שגיאת עגול).
  • המרה באופטימיזציה ללא נתונים מכמתת את משקל המודל ל -8 סיביות, אך ההסקה עדיין משתמשת בחישוב נקודות צפות לצורך הפעלת רשת עצבית. זה מקטין את גודל הדגם כמעט בגורם 4 ומשפר את חביון המעבד במכשירים ניידים.
  • מלמעלה, ניתן לכמת חישוב של פעולות הרשת העצבית לכדי מספרים שלמים של 8 סיביות גם אם נקבע מערך הפניה קטן בכיול טווח הכימות. במכשיר נייד זה מאיץ את ההסקה עוד יותר ומאפשר לרוץ על מאיצים כמו EdgeTPU.

הגדרות אופטימיזציה

interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TF Lite interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
num_eval_examples = 50 
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_generator
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TF Lite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TF Lite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))