Transfer learning with TensorFlow Hub


TensorFlow Hub is a repository of pre-trained TensorFlow models.

This tutorial demonstrates how to:

  1. Use models from TensorFlow Hub with tf.keras
  2. Use an image classification model from TensorFlow Hub
  3. Do simple transfer learning to fine-tune a model for your own image classes

Setup

import numpy as np
import time

import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

import datetime

%load_ext tensorboard

An ImageNet classifier

You'll start by using a classifier model pre-trained on the ImageNet benchmark dataset—no initial training required!

Download the classifier

Select a MobileNetV2 pre-trained model from TensorFlow Hub and wrap it as a Keras layer with hub.KerasLayer. Any compatible image classifier model from TensorFlow Hub will work here, including the two examples defined below.

mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"

classifier_model = mobilenet_v2
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])

Run it on a single image

Download a single image to try the model on:

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step


grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

Add a batch dimension (with np.newaxis) and pass the image to the model:

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

The result is a 1001-element vector of logits, one unnormalized score per class; a higher logit means the model considers that class more likely for the image.

The top class ID can be found with tf.math.argmax:

predicted_class = tf.math.argmax(result[0], axis=-1)
predicted_class
<tf.Tensor: shape=(), dtype=int64, numpy=653>
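
Because the model returns logits rather than probabilities, a softmax is needed to read the scores as a distribution. The following is a minimal NumPy-only sketch of that step; the four-element logit vector is made up for illustration and is not output from the classifier, and the helper name softmax is an assumption of this example.

```python
import numpy as np

def softmax(logits):
    """Convert a 1-D array of logits into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical logits for a 4-class problem (illustrative values only).
logits = np.array([2.0, 1.0, 0.1, -1.0])
probs = softmax(logits)

# Indices of the two highest-scoring classes, best first.
top2 = np.argsort(probs)[::-1][:2]

print(top2)  # [0 1]
# Softmax preserves the ranking, so argmax gives the same class either way.
print(int(np.argmax(logits)) == int(np.argmax(probs)))  # True
```

Since the ranking is unchanged, taking the argmax of the raw logits, as done above, is sufficient when only the top class is needed.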

Decode the predictions

Take the predicted_class ID (such as 653) and fetch the ImageNet dataset labels to decode the predictions:

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
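
The decoding step relies on the labels file having one label per line, so that a class ID is simply a line index. Here is a small standalone sketch of that idea; the three-line file is a hypothetical stand-in written to a temporary directory, not the real ImageNetLabels.txt, and it uses a with block so the file is closed after reading.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Write a tiny stand-in labels file (hypothetical contents, one label per line).
    path = os.path.join(tmp, "labels.txt")
    with open(path, "w") as f:
        f.write("background\ntench\ngoldfish\n")

    # Read it back; the with block closes the file automatically.
    with open(path) as f:
        labels = f.read().splitlines()

class_id = 2  # hypothetical predicted class ID
print(labels[class_id])  # goldfish
```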


Simple transfer learning

But what if you want to create a custom classifier for your own dataset, with classes that aren't included in the original ImageNet dataset on which the pre-trained model was trained?

To do that, you can:

  1. Select a pre-trained model from TensorFlow Hub; and
  2. Retrain the top (last) layer to recognize the classes from your custom dataset.

Dataset

In this example, you will use the TensorFlow flowers dataset:

data_root = tf.keras.utils.get_file(
  'flower_photos',
  'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step

First, load the images from disk with tf.keras.preprocessing.image_dataset_from_directory, which generates a tf.data.Dataset:

batch_size = 32
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
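
The file counts reported above can be reproduced with a little arithmetic. The sketch below assumes the validation count is the split fraction times the number of files, rounded down, and that the last partial batch is kept; both are assumptions of this example rather than statements about the loader's internals.

```python
import math

total_files = 3670        # "Found 3670 files" in the output above
validation_split = 0.2
batch_size = 32

num_val = int(validation_split * total_files)         # files held out for validation
num_train = total_files - num_val                     # files used for training
steps_per_epoch = math.ceil(num_train / batch_size)   # the last batch may be smaller

print(num_val, num_train, steps_per_epoch)  # 734 2936 92
```

The resulting 92 batches per epoch agree with the 92/92 step counter in the training progress output later in this tutorial.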

The flowers dataset has five classes:

class_names = np.array(train_ds.class_names)
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']

Second, because TensorFlow Hub's convention for image models is to expect float inputs in the [0, 1] range, use the tf.keras.layers.experimental.preprocessing.Rescaling layer to achieve this.

normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # x: images, y: labels.
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # x: images, y: labels.
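
To make the effect of the rescaling concrete, here is a NumPy-only sketch of the same arithmetic on a made-up 2x2 RGB image. The pixel values are hypothetical, and dividing by 255 is used as a stand-in for the layer, which multiplies by 1/255 (the two agree up to floating-point rounding).

```python
import numpy as np

# Hypothetical 2x2 RGB image with pixel values in [0, 255].
pixels = np.array([[[0, 128, 255], [64, 64, 64]],
                   [[255, 255, 255], [10, 20, 30]]], dtype=np.float32)

# Map [0, 255] to [0, 1]; the shape is unchanged.
rescaled = pixels / 255.0

print(rescaled.shape)                              # (2, 2, 3)
print(float(rescaled.min()), float(rescaled.max()))  # 0.0 1.0
```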

Third, finish the input pipeline by caching the data with Dataset.cache and using buffered prefetching with Dataset.prefetch, so you can yield the data from disk without I/O blocking issues.

These are some of the most important tf.data methods you should use when loading data. To learn more about them, as well as how to cache data to disk and other techniques, see the Better performance with the tf.data API guide.

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 224, 224, 3)
(32,)
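
The idea behind buffered prefetching can be illustrated without TensorFlow. The following pure-Python sketch is only a conceptual analogy, not how tf.data is implemented: a background thread loads items into a bounded queue while the consumer works on earlier ones. The function name prefetch and its buffer_size parameter are hypothetical names chosen for this example.

```python
import queue
import threading

def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable`, loading up to `buffer_size` items ahead
    in a background thread."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def producer():
        for item in iterable:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()  # blocks until the producer has an item ready
        if item is sentinel:
            return
        yield item

# Items arrive in their original order, just loaded ahead of time.
batches = list(prefetch(range(5), buffer_size=2))
print(batches)  # [0, 1, 2, 3, 4]
```

In the tutorial's pipeline, tf.data.AUTOTUNE leaves the choice of buffer size to the runtime instead of fixing it by hand.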

Run the classifier on a batch of images

Now, run the classifier on an image batch:

result_batch = classifier.predict(train_ds)
predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'rapeseed', ..., 'daisy', 'daisy',
       'birdhouse'], dtype='<U30')

Check how these predictions line up with the images:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")


The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except for "daisy").

Download the headless model

TensorFlow Hub also distributes models without the top classification layer. These can be used to easily perform transfer learning.

Select a MobileNetV2 pre-trained model from TensorFlow Hub. Any compatible image feature vector model from TensorFlow Hub will work here, including the two examples defined below.

mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

feature_extractor_model = mobilenet_v2

Create the feature extractor by wrapping the pre-trained model as a Keras layer with hub.KerasLayer. Use the trainable=False argument to freeze the variables, so that the training only modifies the new classifier layer:

feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(224, 224, 3),
    trainable=False)

The feature extractor returns a 1280-element vector for each image (the image batch size remains at 32 in this example):

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

Attach a classification head

To complete the model, wrap the feature extractor layer in a tf.keras.Sequential model and add a fully-connected layer for classification:

num_classes = len(class_names)

model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(num_classes)
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])
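
The shapes and the trainable parameter count in the summary can be checked by hand. Below is a NumPy sketch of what a fully connected layer without an activation computes; the kernel, bias, and feature values are random placeholders, not the trained weights or real features.

```python
import numpy as np

feature_dim = 1280   # length of the feature vector from the extractor
num_classes = 5      # number of flower classes
batch = 32

# Placeholder stand-ins for the layer's kernel and bias.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(feature_dim, num_classes))
bias = np.zeros(num_classes)

features = rng.normal(size=(batch, feature_dim))  # placeholder feature batch
logits = features @ kernel + bias                 # one row of class scores per image

print(logits.shape)             # (32, 5)
print(kernel.size + bias.size)  # 6405
```

The 6,405 parameters (1280 x 5 weights plus 5 biases) are the only trainable ones because the feature extractor is frozen.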

Train the model

Use Model.compile to configure the training process and add a tf.keras.callbacks.TensorBoard callback to create and store logs:

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1) # Enable histogram computation for every epoch.
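
To show what a sparse categorical cross-entropy computed from logits measures, here is a NumPy-only sketch. It is an illustrative reimplementation under the usual definition (mean negative log-probability of the true class), not the Keras source, and the function name is an assumption of this example.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and unnormalized logits."""
    # Stable log-softmax: subtract the row-wise max before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    picked = log_probs[np.arange(len(labels)), labels]
    return -picked.mean()

labels = np.array([0, 3])          # integer class IDs, as in the flowers dataset
uniform_logits = np.zeros((2, 5))  # a model that rates all five classes equally
loss = sparse_categorical_crossentropy_from_logits(labels, uniform_logits)
print(round(float(loss), 4))  # 1.6094, i.e. ln(5)
```

A five-class model that guesses uniformly therefore has a loss of about 1.61, which is a useful reference point when reading the loss values from the first training steps.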

Now use the Model.fit method to train the model.

To keep this example short, you'll train for just 10 epochs. To visualize the training progress in TensorBoard later, create and store logs using the TensorBoard callback.

NUM_EPOCHS = 10

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=NUM_EPOCHS,
                    callbacks=[tensorboard_callback])
Epoch 1/10
92/92 [==============================] - 6s 36ms/step - loss: 0.7512 - acc: 0.7207 - val_loss: 0.4438 - val_acc: 0.8569
Epoch 2/10
92/92 [==============================] - 2s 21ms/step - loss: 0.3692 - acc: 0.8774 - val_loss: 0.3603 - val_acc: 0.8869
Epoch 3/10
92/92 [==============================] - 2s 20ms/step - loss: 0.2902 - acc: 0.9084 - val_loss: 0.3300 - val_acc: 0.8924
Epoch 4/10
92/92 [==============================] - 2s 20ms/step - loss: 0.2424 - acc: 0.9268 - val_loss: 0.3146 - val_acc: 0.9005
Epoch 5/10
92/92 [==============================] - 2s 21ms/step - loss: 0.2083 - acc: 0.9441 - val_loss: 0.3056 - val_acc: 0.8992
Epoch 6/10
92/92 [==============================] - 2s 21ms/step - loss: 0.1820 - acc: 0.9523 - val_loss: 0.2996 - val_acc: 0.9074
Epoch 7/10
92/92 [==============================] - 2s 21ms/step - loss: 0.1608 - acc: 0.9591 - val_loss: 0.2952 - val_acc: 0.9074
Epoch 8/10
92/92 [==============================] - 2s 21ms/step - loss: 0.1433 - acc: 0.9639 - val_loss: 0.2919 - val_acc: 0.9087
Epoch 9/10
92/92 [==============================] - 2s 20ms/step - loss: 0.1285 - acc: 0.9700 - val_loss: 0.2892 - val_acc: 0.9114
Epoch 10/10
92/92 [==============================] - 2s 20ms/step - loss: 0.1158 - acc: 0.9765 - val_loss: 0.2871 - val_acc: 0.9101
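
Model.fit returns a History object whose history attribute holds one list per metric, with one entry per epoch. As a small sketch of how such a list might be inspected, the example below copies the val_acc values from the log above into a plain Python list (in a live session they would come from history.history['val_acc']) and finds the best epoch.

```python
# Validation accuracy per epoch, copied from the training log above.
val_acc = [0.8569, 0.8869, 0.8924, 0.9005, 0.8992,
           0.9074, 0.9074, 0.9087, 0.9114, 0.9101]

# Index of the highest value, converted to a 1-based epoch number.
best_epoch = max(range(len(val_acc)), key=lambda i: val_acc[i]) + 1

print(best_epoch, val_acc[best_epoch - 1])  # 9 0.9114
```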

Start TensorBoard to view how the metrics change with each epoch and to track other scalar values:

%tensorboard --logdir logs/fit

Check the predictions

Obtain the ordered list of class names from the model predictions:

predicted_batch = model.predict(image_batch)
predicted_id = tf.math.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
print(predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion'
 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips'
 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion'
 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses'
 'roses' 'sunflowers' 'tulips' 'sunflowers']
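
The mapping from logits to label strings is an argmax followed by integer-array indexing. The NumPy sketch below uses made-up logits for three images to show the two steps in isolation.

```python
import numpy as np

class_names = np.array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])

# Hypothetical logits for three images (rows) over the five classes (columns).
logits = np.array([[0.1, 2.3, -1.0, 0.0, 0.5],
                   [1.5, 0.2, 0.3, 0.1, 3.0],
                   [0.0, 0.0, 4.2, 1.0, 0.9]])

predicted_id = np.argmax(logits, axis=-1)     # index of the largest logit in each row
predicted_labels = class_names[predicted_id]  # integer-array indexing maps IDs to names

print(predicted_id)      # [1 4 2]
print(predicted_labels)  # ['dandelion' 'tulips' 'roses']
```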

Plot the model predictions:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)

for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")


Export and reload your model

Now that you've trained the model, export it as a SavedModel so you can reuse it later.

t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)

export_path
INFO:tensorflow:Assets written to: /tmp/saved_models/1627953843/assets
'/tmp/saved_models/1627953843'

Confirm that you can reload the SavedModel and that the model is able to output the same results:

reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0
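
The difference of exactly 0.0 is what this run produced. In other settings, for example on different hardware or after converting the model to another format, small numerical differences may appear, so a tolerance-based comparison can be more robust. The sketch below uses made-up arrays and an arbitrarily chosen tolerance of 1e-5; both are assumptions of this example.

```python
import numpy as np

# Hypothetical outputs from an original and a reloaded model.
original = np.array([[0.12, -1.50, 2.30], [0.70, 0.01, -0.40]])
reloaded_out = original + 1e-7  # simulate a tiny numerical discrepancy

max_abs_diff = np.abs(reloaded_out - original).max()
exactly_equal = bool(max_abs_diff == 0.0)
close_enough = bool(np.allclose(reloaded_out, original, atol=1e-5))

print(exactly_equal)  # False: exact equality is too strict here
print(close_enough)   # True: equal within the chosen tolerance
```
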
reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)
reloaded_predicted_label_batch = class_names[reloaded_predicted_id]
print(reloaded_predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion'
 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips'
 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion'
 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses'
 'roses' 'sunflowers' 'tulips' 'sunflowers']
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(reloaded_predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")


Next steps

You can load the SavedModel for inference or convert it to a TensorFlow Lite model (for on-device machine learning) or a TensorFlow.js model (for machine learning in JavaScript).

Discover more tutorials to learn how to use pre-trained models from TensorFlow Hub on image, text, audio, and video tasks.