Передача обучения с помощью TensorFlow Hub

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть на GitHub Скачать блокнот См. Модель TF Hub

TensorFlow Hub - это репозиторий предварительно обученных моделей TensorFlow.

В этом руководстве показано, как:

  1. Используйте модели из TensorFlow Hub с tf.keras
  2. Используйте модель классификации изображений из TensorFlow Hub
  3. Выполните простое переносное обучение, чтобы точно настроить модель для ваших собственных классов изображений.

Настраивать

import numpy as np
import time

import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

Классификатор ImageNet

Вы начнете с использования предварительно обученной модели классификатора, чтобы взять изображение и предсказать, что это за изображение - никакого обучения не требуется!

Скачать классификатор

Используйте hub.KerasLayer чтобы загрузить модель MobileNetV2 из TensorFlow Hub. Здесь будет работать любая совместимая модель классификатора изображений от TensorFlow Hub.

IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])

Запустите его на одном изображении

Загрузите одно изображение, чтобы примерить модель.

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step

PNG

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

Добавьте размер партии и передайте изображение модели.

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

Результатом является 1001-элементный вектор логитов, оценивающий вероятность каждого класса для изображения.

Итак, идентификатор высшего класса можно найти с помощью argmax:

predicted_class = np.argmax(result[0], axis=-1)
predicted_class
653

Расшифруйте прогнозы

Возьмите прогнозируемый идентификатор класса и ImageNet метки ImageNet для декодирования прогнозов.

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

PNG

Простое трансферное обучение

Но что, если вы хотите обучить классификатор для набора данных с разными классами? Вы также можете использовать модель из TFHub для обучения пользовательского классификатора изображений, повторно обучив верхний уровень модели распознавать классы в нашем наборе данных.

Набор данных

В этом примере вы будете использовать набор данных цветов TensorFlow:

data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 7s 0us/step

Давайте загрузим эти данные в нашу модель, используя изображения с диска, используя image_dataset_from_directory.

batch_size = 32
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.

Набор данных цветов состоит из пяти классов.

class_names = np.array(train_ds.class_names)
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']

Соглашения TensorFlow Hub для моделей изображений - ожидать входные данные с плавающей запятой в диапазоне [0, 1] . Для этого используйте слой Rescaling .

normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))

Обязательно используйте буферизованную предварительную выборку, чтобы мы могли передавать данные с диска без блокировки ввода-вывода. Это два важных метода, которые вы должны использовать при загрузке данных.

Заинтересованные читатели могут узнать больше об обоих методах, а также о том, как кэшировать данные на диск, в руководстве по производительности данных .

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 224, 224, 3)
(32,)

Запустить классификатор для пакета изображений

Теперь запустите классификатор для пакета изображений.

result_batch = classifier.predict(train_ds)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'rapeseed', ..., 'daisy', 'daisy',
       'birdhouse'], dtype='<U30')

Теперь проверьте, как эти прогнозы совпадают с изображениями:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

PNG

См. LICENSE.txt для атрибуции изображений.

Результаты далеки от идеальных, но разумные, учитывая, что это не те классы, для которых модель была обучена (кроме «ромашки»).

Скачать модель без головы

TensorFlow Hub также распространяет модели без верхнего уровня классификации. Их можно использовать для простого переноса обучения.

Здесь будет работать любая совместимая векторная модель функции изображения из TensorFlow Hub.

feature_extractor_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"

Создайте экстрактор признаков. Используйте trainable=False чтобы заморозить переменные в слое экстрактора признаков, чтобы при обучении изменялся только новый слой классификатора.

feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model, input_shape=(224, 224, 3), trainable=False)

Он возвращает вектор длиной 1280 для каждого изображения:

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

Прикрепите классификационную головку

Теперь оберните слой хаба в модель tf.keras.Sequential и добавьте новый слой классификации.

num_classes = len(class_names)

model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(num_classes)
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])

Обучите модель

Используйте компиляцию для настройки тренировочного процесса:

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

Теперь используйте метод .fit для обучения модели.

Чтобы сохранить этот пример коротким, поездом всего 2 эпохи. Чтобы визуализировать прогресс обучения, используйте настраиваемый обратный вызов, чтобы регистрировать потери и точность каждого пакета индивидуально, а не среднее значение за эпоху.

class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()

batch_stats_callback = CollectBatchStats()

history = model.fit(train_ds, epochs=2,
                    callbacks=[batch_stats_callback])
Epoch 1/2
92/92 [==============================] - 4s 18ms/step - loss: 0.6088 - acc: 0.7917
Epoch 2/2
92/92 [==============================] - 2s 18ms/step - loss: 0.3655 - acc: 0.8333

Теперь, даже после нескольких итераций обучения, мы уже можем видеть, что модель успешно справляется с задачей.

plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7ff00cc5c750>]

PNG

plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7ff00cb830d0>]

PNG

Проверить прогнозы

Чтобы повторить сюжет из предыдущего, сначала получите упорядоченный список имен классов:

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

Постройте результат

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")

PNG

Экспортируйте вашу модель

Теперь, когда вы обучили модель, экспортируйте ее как SavedModel для использования в дальнейшем.

t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)

export_path
INFO:tensorflow:Assets written to: /tmp/saved_models/1624325018/assets
INFO:tensorflow:Assets written to: /tmp/saved_models/1624325018/assets
'/tmp/saved_models/1624325018'

Теперь подтвердите, что мы можем перезагрузить его, и он по-прежнему дает те же результаты:

reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0

Эта SavedModel может быть загружена для последующего вывода или преобразована в TFLite или TFjs .

Учить больше

Ознакомьтесь с дополнительными руководствами по использованию моделей изображений в TensorFlow Hub.