
Transfer learning with TensorFlow Hub


TensorFlow Hub is a way to share pretrained model components. See the TensorFlow Module Hub for a searchable listing of pretrained models. This tutorial demonstrates:

  1. How to use TensorFlow Hub with tf.keras.
  2. How to do image classification using TensorFlow Hub.
  3. How to do simple transfer learning.

Setup

from __future__ import absolute_import, division, print_function, unicode_literals

import matplotlib.pylab as plt

import tensorflow as tf
!pip install -q -U tensorflow_hub
import tensorflow_hub as hub

from tensorflow.keras import layers

An ImageNet classifier

Download the classifier

Use hub.KerasLayer to load a MobileNet and wrap it as a Keras layer. Any TensorFlow 2 compatible image classifier URL from tfhub.dev will work here.

classifier_url ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))
])

Run it on a single image

Download a single image to try the model on.

import numpy as np
import PIL.Image as Image

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step

(Image: the downloaded Grace Hopper photo, resized to 224×224)

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

Add a batch dimension, and pass the image to the model.

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

The result is a 1001-element vector of logits, scoring the likelihood of each ImageNet class for the image.

So the top class ID can be found with argmax:

predicted_class = np.argmax(result[0], axis=-1)
predicted_class
653
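
The logits are unnormalized scores, so if you want probabilities or the top few candidates rather than a single argmax, you can run them through a softmax first. A minimal sketch reusing result from the cell above (the names probabilities and top_5_ids are just illustrative):

# Convert the logits to probabilities and look at the five highest-scoring class IDs.
probabilities = tf.nn.softmax(result[0]).numpy()
top_5_ids = np.argsort(probabilities)[-5:][::-1]
print(top_5_ids, probabilities[top_5_ids])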

Decode the predictions

Now that we have the predicted class ID, fetch the ImageNet labels and decode the prediction:

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

(Image: the Grace Hopper photo titled with its predicted ImageNet class)

Simple transfer learning

Using TF Hub, it is simple to retrain the top layer of the model to recognize the classes in our dataset.

Dataset

For this example you will use the TensorFlow flowers dataset:

data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 4s 0us/step

The simplest way to load this data into our model is using tf.keras.preprocessing.image.ImageDataGenerator.

All of TensorFlow Hub's image modules expect float inputs in the [0, 1] range. Use the ImageDataGenerator's rescale parameter to achieve this.

The image size will be handled later.

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)
Found 3670 images belonging to 5 classes.

The resulting object is an iterator that returns image_batch, label_batch pairs.

for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break
Image batch shape:  (32, 224, 224, 3)
Label batch shape:  (32, 5)
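
This tutorial trains on all of the images, but if you also want a held-out validation set, ImageDataGenerator can split the same directory for you via its validation_split argument. A hedged sketch, using an illustrative 20% split and hypothetical names (split_generator, train_data, val_data):

# Reserve 20% of the images for validation; the subset argument selects which split to yield.
split_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255, validation_split=0.2)
train_data = split_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE, subset='training')
val_data = split_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE, subset='validation')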

Run the classifier on a batch of images

Now run the classifier on the image batch.

result_batch = classifier.predict(image_batch)
result_batch.shape
(32, 1001)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'daisy', 'daisy', 'plastic bag', 'bee', 'cauliflower',
       'daisy', 'bee', 'cauliflower', 'vase', 'greenhouse', 'daisy',
       'lemon', 'daisy', 'daisy', 'picket fence', 'daisy', 'bee', 'nail',
       'daisy', 'bee', 'paper towel', 'hair slide', 'volcano', 'ant',
       'pot', 'daisy', 'daisy', 'daisy', 'hip', 'brassiere', 'hip'],
      dtype='<U30')

Now check how these predictions line up with the images:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

(Figure: 30 flower images from the batch, each titled with its ImageNet prediction)

See the LICENSE.txt file for image attributions.

The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except "daisy").

Download the headless model

TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.

Any TensorFlow 2 compatible image feature vector URL from tfhub.dev will work here.

feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2" #@param {type:"string"}

Create the feature extractor.

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=(224,224,3))

It returns a 1280-length vector for each image:

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

Freeze the variables in the feature extractor layer, so that the training only modifies the new classifier layer.

feature_extractor_layer.trainable = False
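
To confirm the freeze took effect, a quick sanity check is to count the layer's trainable variables; with trainable=False there should be none, while the underlying weights are still present:

print(len(feature_extractor_layer.trainable_variables))  # expected: 0
print(len(feature_extractor_layer.variables))            # the MobileNet weights are still there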

Attach a classification head

Now wrap the hub layer in a tf.keras.Sequential model, and add a new classification layer.

model = tf.keras.Sequential([
  feature_extractor_layer,
  layers.Dense(image_data.num_classes, activation='softmax')
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
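
The 6,405 trainable parameters all belong to the new dense head: 1280 inputs × 5 classes = 6,400 weights plus 5 biases, while the 2,257,984 MobileNet parameters remain frozen.
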
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])

Train the model

Use compile to configure the training process:

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss='categorical_crossentropy',
  metrics=['acc'])

Now use the .fit_generator method to train the model.

To keep this example short, train for just 2 epochs. To visualize the training progress, use a custom callback to log the loss and accuracy of each batch individually, instead of the epoch average.

class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    super().__init__()
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()

steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

batch_stats_callback = CollectBatchStats()

history = model.fit_generator(image_data, epochs=2,
                              steps_per_epoch=steps_per_epoch,
                              callbacks = [batch_stats_callback])
Epoch 1/2
115/115 [==============================] - 23s 201ms/step - loss: 0.6942 - acc: 0.9375
Epoch 2/2
115/115 [==============================] - 18s 155ms/step - loss: 0.3411 - acc: 0.8125

Now, after even just a few training iterations, we can already see that the model is making progress on the task.

plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7f4aa4fffa20>]

(Plot: per-batch training loss over training steps)

plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7f4ab1fd4d30>]

(Plot: per-batch training accuracy over training steps)

Check the predictions

To redo the plot from before, first get the ordered list of class names:

class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
class_names
array(['Daisy', 'Dandelion', 'Roses', 'Sunflowers', 'Tulips'],
      dtype='<U10')

Run the image batch through the model and convert the indices to class names.

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

Plot the result

label_id = np.argmax(label_batch, axis=-1)
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")

(Figure: 30 flower images titled with the model's predictions; green titles are correct, red incorrect)

Export your model

Now that you've trained the model, export it as a SavedModel:

import time
t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
tf.keras.experimental.export_saved_model(model, export_path)

export_path
WARNING:tensorflow:From <ipython-input-34-47b3eb0809fe>:5: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.

INFO:tensorflow:SavedModel written to: /tmp/saved_models/1571420363/saved_model.pb

'/tmp/saved_models/1571420363'

Now confirm that we can reload it, and it still gives the same results:

reloaded = tf.keras.experimental.load_from_saved_model(export_path, custom_objects={'KerasLayer':hub.KerasLayer})
WARNING:tensorflow:From <ipython-input-35-e6826f3d5aba>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been deprecated. Please switch to `tf.keras.models.load_model`.

result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0

This SavedModel can be loaded for inference later, or converted to TFLite or TensorFlow.js.
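
For example, here is a minimal sketch of a TFLite conversion using tf.lite.TFLiteConverter (the output path is just an illustration, and whether a given hub model converts cleanly depends on the ops it uses):

# Convert the trained Keras model to a TFLite flatbuffer and write it to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("/tmp/mobilenet_flowers.tflite", "wb") as f:
  f.write(tflite_model)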