Image classification with Model Garden

This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.

Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

This tutorial uses ResNet-18, an 18-layer convolutional neural network from the ResNet family of image classifiers.

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Fine-tune a pre-built ResNet for image classification.
  3. Export the tuned ResNet model.


Install and import the necessary modules.

pip install -U -q "tf-models-official"

Import TensorFlow, TensorFlow Datasets, and a few helper libraries.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

The tensorflow_models package contains the ResNet vision model, and the export_saved_model_lib module contains the function to save and export the tuned model.

import tensorflow_models as tfm

# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib

Configure the ResNet-18 model for the CIFAR-10 dataset

The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
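
The factory idea can be pictured with a small, self-contained sketch. Everything in it (the register_config decorator, get_config, and the ModelConfig fields) is hypothetical and only mimics the pattern of looking up a named set of defaults and then overriding a few of them. It is not Model Garden's implementation.

```python
import dataclasses
from typing import Callable, Dict


@dataclasses.dataclass
class ModelConfig:
  """Hypothetical stand-in for a model config."""
  num_classes: int = 1000
  model_id: int = 50


# Registry that maps a name to a function building a fresh config.
_REGISTRY: Dict[str, Callable[[], ModelConfig]] = {}


def register_config(name: str):
  """Registers a config-building function under `name`."""
  def decorator(fn):
    _REGISTRY[name] = fn
    return fn
  return decorator


@register_config('toy_resnet')
def toy_resnet() -> ModelConfig:
  return ModelConfig()


def get_config(name: str) -> ModelConfig:
  """Looks up the factory by name and returns a new config object."""
  return _REGISTRY[name]()


config = get_config('toy_resnet')
config.num_classes = 10  # Override the defaults after creation.
config.model_id = 18
print(config)
```

Because each lookup builds a new object, overriding fields on one config does not change the registered defaults.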

Use the resnet_imagenet factory configuration. The configuration is set up to train ResNet to converge on ImageNet.

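exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
# with_info=True also returns the DatasetInfo object used later in the tutorial.
ds, ds_info = tfds.load(tfds_name, with_info=True)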

Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device == 'CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000  # Assumed value for accelerator runs; adjust as needed.
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps =  ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
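
To get a feel for the schedule these settings describe, here is a simplified, framework-free sketch of a cosine decay with a linear warmup. The function name learning_rate and the default decay_steps of 5000 are assumptions made for illustration, and the exact way Model Garden blends the warmup into the decay curve is also assumed, so treat the numbers as indicative only.

```python
import math


def learning_rate(step, initial_lr=0.1, decay_steps=5000, warmup_steps=100):
  """Cosine decay with a linear warmup (simplified illustration)."""

  def cosine(s):
    # Falls from initial_lr at step 0 to 0 at decay_steps.
    progress = min(s, decay_steps) / decay_steps
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

  if step < warmup_steps:
    # Ramp linearly from 0 up to the value the cosine curve has
    # at the end of the warmup (assumed blending rule).
    return cosine(warmup_steps) * step / warmup_steps
  return cosine(step)


for step in (0, 50, 100, 2500, 5000):
  print(f'step {step:5d}: lr = {learning_rate(step):.5f}')
```

The rate starts at zero, climbs to roughly the initial value by the end of the warmup, passes through half the initial value at the midpoint, and reaches zero at decay_steps.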

Print the modified configuration.

pprint.pprint(exp_config.as_dict())


Set up the distribution strategy.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])

Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.

The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

#  tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')

Visualize the training data

The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.
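
The two steps involved, z-score normalization and rescaling for display, can be sketched in plain Python on a handful of pixels. The MEAN_RGB and STDDEV_RGB values below are assumed to be the usual ImageNet statistics on the 0-255 scale, and normalize_pixel and rescale_for_display are hypothetical helpers written for this illustration.

```python
# Assumed per-channel statistics on the 0-255 scale (ImageNet mean and stddev).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(rgb):
  """Z-score normalizes one (R, G, B) pixel."""
  return tuple((v - m) / s for v, m, s in zip(rgb, MEAN_RGB, STDDEV_RGB))


def rescale_for_display(pixels):
  """Min-max rescales normalized pixels into [0, 1] for plotting."""
  values = [v for pixel in pixels for v in pixel]
  low, high = min(values), max(values)
  delta = high - low
  return [tuple((v - low) / delta for v in pixel) for pixel in pixels]


raw = [(0, 0, 0), (128, 64, 200), (255, 255, 255)]
normalized = [normalize_pixel(p) for p in raw]
displayable = rescale_for_display(normalized)

print(normalized[0])  # Contains negative values, which cannot be shown as-is.
print(displayable)    # Every value now lies between 0 and 1.
```

Min-max rescaling does not recover the original colors exactly, but it is enough to make the images recognizable when plotted.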


Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.

label_info = ds_info.features['label']
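
Conceptually, the label feature is a two-way mapping between integer class IDs and class names. The sketch below imitates that mapping with a plain list so that it runs without downloading the dataset. The CLASS_NAMES list and the two helper functions are stand-ins written for illustration, not the TensorFlow Datasets implementation.

```python
# CIFAR-10 class names in label-ID order.
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def int2str(class_id):
  """Returns the class name for an integer class ID."""
  return CLASS_NAMES[class_id]


def str2int(name):
  """Returns the integer class ID for a class name."""
  return CLASS_NAMES.index(name)


print(int2str(3))
print(str2int('ship'))
```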

Visualize a batch of the data.

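def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale pixel values into [0, 1] so that imshow can display them.
  min_value = images.numpy().min()
  max_value = images.numpy().max()
  delta = max_value - min_value

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - min_value) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green title for a correct prediction, red for an incorrect one.
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)

plt.figure(figsize=(10, 10))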
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)

Visualize the testing data

Visualize a batch of images from the validation dataset.

plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)

Train and evaluate

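model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
#  tf.keras.utils.plot_model(model, show_shapes=True)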

Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.

for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
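
As a reminder of what the two accuracy metrics measure, the following sketch computes top-k accuracy by hand on a tiny, made-up batch. The top_k_accuracy helper and the example logits are hypothetical; the metric implementation used during training may differ in details such as tie handling.

```python
def top_k_accuracy(logits_batch, labels, k):
  """Fraction of examples whose true label is among the k highest logits."""
  hits = 0
  for logits, label in zip(logits_batch, labels):
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    hits += label in ranked[:k]
  return hits / len(labels)


# Made-up logits for three examples over four classes.
logits_batch = [
    [0.1, 2.0, 0.3, 0.0],  # Highest score: class 1.
    [1.5, 0.2, 1.4, 0.1],  # Highest: class 0, runner-up: class 2.
    [0.0, 0.1, 0.2, 3.0],  # Highest score: class 3.
]
labels = [1, 2, 0]

print(f'top-1 accuracy: {top_k_accuracy(logits_batch, labels, k=1):.3f}')
print(f'top-2 accuracy: {top_k_accuracy(logits_batch, labels, k=2):.3f}')
```

Plain accuracy is the k=1 case, and top_5_accuracy counts a prediction as correct when the true class is among the five highest-scoring classes.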

Run a batch of the processed training data through the model, and view the results

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')

Export a SavedModel

The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader using the same mean and variance statistics as in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). The export function handles those details, so you can pass tf.uint8 images and get the correct results.

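# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')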

Test the exported model.

# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']

Visualize the predictions.

plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')