TensorFlow Hub is a repository of pre-trained TensorFlow models.
This tutorial demonstrates how to:
- Use models from TensorFlow Hub with tf.keras
- Use an image classification model from TensorFlow Hub
- Do simple transfer learning to fine-tune a model for your own image classes
Setup
import numpy as np
import time
import PIL.Image as Image
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
An ImageNet classifier
You'll start by using a pretrained classifier model to take an image and predict what it's an image of, with no training required!
Download the classifier
Use hub.KerasLayer to load a MobileNetV2 model from TensorFlow Hub. Any compatible image classifier model from tfhub.dev will work here.
classifier_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
IMAGE_SHAPE = (224, 224)
classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])
Run it on a single image
Download a single image to try the model on.
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg 65536/61306 [================================] - 0s 0us/step
grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)
Add a batch dimension, and pass the image to the model.
result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)
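The batch dimension is needed because the model expects input shaped (batch, height, width, channels). As a minimal NumPy-only sketch, using a zero-filled array as a stand-in for the real image (an assumption made purely for illustration), here are two equivalent ways to add that leading axis:

```python
import numpy as np

# Stand-in for the preprocessed image; zeros are placeholder data, not real pixels.
image = np.zeros((224, 224, 3), dtype=np.float32)

# Indexing with np.newaxis prepends an axis of length 1.
batched_a = image[np.newaxis, ...]

# np.expand_dims does the same with an explicit axis argument.
batched_b = np.expand_dims(image, axis=0)

print(batched_a.shape)  # (1, 224, 224, 3)
print(batched_b.shape)  # (1, 224, 224, 3)
```

Both forms return a view with one extra axis, so either can be passed to predict.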
The result is a 1001-element vector of logits, one per class, where a higher logit means the model considers that class more likely for the image.
So the top class ID can be found with argmax:
predicted_class = np.argmax(result[0], axis=-1)
predicted_class
653
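Because the outputs are logits rather than probabilities, it can help to apply a softmax before inspecting them. The following is a sketch only: the helper names softmax and top_k are hypothetical, and the random logits stand in for result[0] rather than real model output.

```python
import numpy as np

def softmax(logits):
    """Convert a 1-D array of logits to probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def top_k(logits, k=5):
    """Return (class_ids, probabilities) for the k highest-scoring classes, best first."""
    probs = softmax(logits)
    ids = np.argsort(probs)[::-1][:k]
    return ids, probs[ids]

# Hypothetical logits standing in for result[0]; random values, not model output.
rng = np.random.default_rng(0)
fake_logits = rng.normal(size=1001)

ids, probs = top_k(fake_logits, k=5)
print(ids, probs)
```

Softmax preserves the ordering of the logits, so the first ID returned by top_k is the same class that argmax selects.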
Decode the predictions
Take the predicted class ID and fetch the ImageNet labels to decode the prediction:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt 16384/10484 [==============================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
Simple transfer learning
But what if you want to train a classifier for a dataset with different classes? You can also use a model from TFHub to train a custom image classifier by retraining the top layer of the model to recognize the classes in your dataset.
Dataset
For this example you will use the TensorFlow flowers dataset:
data_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 7s 0us/step
Load the images from disk into a tf.data.Dataset using image_dataset_from_directory.
batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3670 files belonging to 5 classes. Using 2936 files for training.
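The counts in this output follow from the split fraction and the batch size. The sketch below reproduces the arithmetic under the assumption that the validation count is the split fraction truncated to an integer; the variable names are illustrative only.

```python
import math

total_files = 3670        # "Found 3670 files" in the output above
validation_split = 0.2
batch_size = 32

# Assumption: the validation count is the split fraction truncated to an integer.
num_validation = int(validation_split * total_files)
num_training = total_files - num_validation

# The last batch may be partial, so round up.
steps_per_epoch = math.ceil(num_training / batch_size)

print(num_training, num_validation, steps_per_epoch)  # 2936 734 92
```

The 92 steps per epoch match the 92/92 progress counter in the training log later in this tutorial.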
The flowers dataset has five classes.
class_names = np.array(train_ds.class_names)
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
TensorFlow Hub's convention for image models is to expect float inputs in the [0, 1] range. Use the Rescaling layer to achieve this.
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
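To see what the rescaling does numerically, here is a NumPy-only sketch on a handful of placeholder pixel values (not a real image). It divides by 255, which is equivalent to the layer's multiplication by 1/255 up to floating-point rounding.

```python
import numpy as np

# Placeholder pixel data spanning the full uint8 range; not a real image.
pixels = np.array([[0, 64, 128, 255]], dtype=np.uint8)

# Convert to float first, then scale so that 0 maps to 0.0 and 255 maps to 1.0.
scaled = pixels.astype(np.float32) / 255.0

print(scaled.min(), scaled.max())
```

Converting to a float type before scaling matters: the values need to be fractional, which an integer array cannot hold.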
Finally, cache the dataset and use buffered prefetching so that data can be read from disk without I/O becoming a bottleneck. These are two important methods you should use when loading data.
Interested readers can learn more about both methods, as well as how to cache data to disk, in the data performance guide.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 224, 224, 3)
(32,)
Run the classifier on a batch of images
Now run the classifier on the image batch.
result_batch = classifier.predict(train_ds)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'rapeseed', ..., 'daisy', 'daisy', 'birdhouse'], dtype='<U30')
Now check how these predictions line up with the images:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
See the LICENSE.txt file for image attributions.
The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except "daisy").
Download the headless model
TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.
Any compatible image feature vector model from tfhub.dev will work here.
feature_extractor_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
Create the feature extractor. Use trainable=False to freeze the variables in the feature extractor layer, so that the training only modifies the new classifier layer.
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model, input_shape=(224, 224, 3), trainable=False)
It returns a 1280-length vector for each image:
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)
Attach a classification head
Now wrap the hub layer in a tf.keras.Sequential model, and add a new classification layer.
num_classes = len(class_names)
model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(num_classes)
])
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984
_________________________________________________________________
dense (Dense)                (None, 5)                 6405
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
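The trainable parameter count in the summary can be checked by hand: a Dense layer has one weight per input-output pair plus one bias per output. A short sketch of that arithmetic, with the frozen parameter count copied from the summary:

```python
feature_dim = 1280   # length of the feature vector from the headless model
num_classes = 5      # daisy, dandelion, roses, sunflowers, tulips

# A Dense layer has one weight per (input, output) pair plus one bias per output.
dense_params = feature_dim * num_classes + num_classes

frozen_params = 2_257_984  # KerasLayer parameter count copied from the summary
total_params = frozen_params + dense_params

print(dense_params, total_params)  # 6405 2264389
```

Only the 6,405 Dense parameters are updated during training, because the feature extractor was created with trainable=False.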
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])
Train the model
Use compile to configure the training process:
model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])
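The from_logits=True argument is needed because the Dense layer above has no softmax activation, so the model outputs raw logits. As a rough NumPy sketch of what such a loss computes (an illustrative re-implementation, not the Keras source, and assuming the loss is averaged over the batch), consider:

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and unnormalized logits."""
    # Log-softmax, computed stably by subtracting the per-row maximum.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    true_log_probs = log_probs[np.arange(len(labels)), labels]
    return -true_log_probs.mean()

# Hypothetical batch: 2 examples, 5 classes. All-zero logits mean a uniform prediction.
labels = np.array([0, 3])
logits = np.zeros((2, 5))

loss = sparse_categorical_crossentropy_from_logits(labels, logits)
print(loss)  # ln(5), about 1.609
```

A uniform prediction over five classes gives a loss of ln(5), which is a useful reference point: an untrained five-class head should start near that value.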
Now use the .fit method to train the model.
To keep this example short, train for just 2 epochs. To visualize the training progress, use a custom callback to log the loss and accuracy of each batch individually, instead of the epoch average.
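class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    super().__init__()  # initialize the base Callback state
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    # Reset metrics so each logged value reflects a single batch, not a running average.
    self.model.reset_metrics()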
batch_stats_callback = CollectBatchStats()
history = model.fit(train_ds, epochs=2,
                    callbacks=[batch_stats_callback])
Epoch 1/2
92/92 [==============================] - 5s 18ms/step - loss: 0.7497 - acc: 0.7288
Epoch 2/2
92/92 [==============================] - 2s 18ms/step - loss: 0.3691 - acc: 0.8730
Even after just a few training iterations, the model is already making progress on the task.
plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7fc6f4ff3748>]
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7fc6c8351da0>]
Check the predictions
To redo the plot from before, first get the ordered list of class names:
predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
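The last line relies on NumPy integer-array indexing: indexing the class_names array with an array of class IDs returns the matching labels in one step. A small sketch with made-up logits for three images (hypothetical values, not model output):

```python
import numpy as np

class_names = np.array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])

# Hypothetical logits for 3 images (rows) over the 5 classes (columns).
fake_predictions = np.array([
    [2.0, 0.1, 0.3, 0.0, 0.5],
    [0.2, 0.1, 0.3, 3.0, 0.5],
    [0.2, 0.1, 0.3, 0.0, 1.5],
])

predicted_id = np.argmax(fake_predictions, axis=-1)  # one class index per row
predicted_labels = class_names[predicted_id]         # index array -> label array

print(predicted_id)      # [0 3 4]
print(predicted_labels)  # ['daisy' 'sunflowers' 'tulips']
```

This works only because class_names was converted to a NumPy array earlier; a plain Python list cannot be indexed with an array of IDs.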
Plot the result
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")
Export your model
Now that you've trained the model, export it as a SavedModel for use later on.
t = time.time()
export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)
export_path
INFO:tensorflow:Assets written to: /tmp/saved_models/1611196143/assets INFO:tensorflow:Assets written to: /tmp/saved_models/1611196143/assets '/tmp/saved_models/1611196143'
Now confirm that we can reload it, and it still gives the same results:
reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0
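Here the difference is exactly 0.0, but an exact match is not guaranteed in every environment, so a tolerance-based comparison can be a more robust check. The helper below is a hypothetical sketch: the name outputs_match and the tolerance of 1e-6 are assumptions, and the arrays are stand-ins for result_batch and reloaded_result_batch.

```python
import numpy as np

def outputs_match(a, b, atol=1e-6):
    """Return (max_abs_diff, within_tolerance) for two prediction arrays."""
    max_abs_diff = float(np.abs(a - b).max())
    return max_abs_diff, bool(np.allclose(a, b, rtol=0.0, atol=atol))

# Hypothetical stand-ins for result_batch and reloaded_result_batch.
original = np.array([[0.5, -1.25, 2.0]])
reloaded_same = original.copy()
reloaded_off = original + 1e-3

print(outputs_match(original, reloaded_same))
print(outputs_match(original, reloaded_off))
```

Choose the tolerance to suit the application; the value used here is only a placeholder.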
This SavedModel can be loaded for inference later, or converted to TensorFlow Lite or TensorFlow.js.
Learn more
Check out more tutorials for using image models from TensorFlow Hub.