Image classification with Model Garden

This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.

Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

This tutorial uses ResNet, a state-of-the-art image-classification architecture; specifically, the ResNet-18 variant, a convolutional neural network with 18 layers.

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Fine-tune a pre-built ResNet for image classification.
  3. Export the tuned ResNet model.

Setup

Install and import the necessary modules. This tutorial uses the tf-models-nightly version of Model Garden.

pip uninstall -y opencv-python
pip install -q tf-models-nightly
WARNING: Skipping opencv-python as it is not installed.

Import TensorFlow, TensorFlow Datasets, and a few helper libraries.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the function to save and export the tuned model.

import tensorflow_models as tfm

# Not in the tfm public API for v2.9. Will be available as `vision.serving` in v2.10
from official.vision.serving import export_saved_model_lib
/usr/local/lib/python3.7/dist-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.10.0-dev20220427). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Configure the ResNet-18 model for the CIFAR-10 dataset

The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images per class.

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.

Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to converge on ImageNet.

exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds_info = tfds.builder(tfds_name).info
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    version=3.0.2,
    description='The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.',
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=60000,
    splits={
        'test': 10000,
        'train': 50000,
    },
    supervised_keys=('image', 'label'),
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
    redistribution_info=,
)

Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
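
Setting tfds_name while leaving input_path empty directs the Model Garden data loader to read from TensorFlow Datasets rather than from TFRecord files. As a quick, optional sanity check that the overrides took effect:

# Optional sanity check: the data configs should now point at TFDS cifar10.
print(exp_config.task.train_data.tfds_name,
      exp_config.task.train_data.tfds_split,
      exp_config.task.train_data.global_batch_size)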

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train it for a few steps.')
  device = 'CPU'

if device=='CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps=5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train it for a few steps.
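
These settings pair a 100-step linear warmup with a cosine decay over train_steps. The following standalone sketch approximates the resulting schedule; it is illustrative only, not the exact Model Garden implementation:

import numpy as np

def sketch_lr(step, initial_lr=0.1, warmup_steps=100, decay_steps=train_steps):
  # Linear ramp from 0 during warmup, then cosine decay toward 0.
  if step < warmup_steps:
    return initial_lr * step / warmup_steps
  return initial_lr * 0.5 * (1.0 + np.cos(np.pi * min(step / decay_steps, 1.0)))

print([round(sketch_lr(s), 4) for s in range(0, train_steps + 1, max(1, train_steps // 4))])

Note that a short CPU run (train_steps=20) never leaves the warmup phase, which is why the training logs below report a learning rate near zero.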

Print the modified configuration.

pprint.pprint(exp_config.as_dict())

display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None,
             'batchnorm_spatial_persistent': False,
             'dataset_num_private_threads': None,
             'default_shard_dim': -1,
             'distribution_strategy': 'mirrored',
             'enable_xla': True,
             'gpu_thread_mode': None,
             'loss_scale': None,
             'mixed_precision_dtype': None,
             'num_cores_per_replica': 1,
             'num_gpus': 0,
             'num_packs': 1,
             'per_gpu_thread_count': 0,
             'run_eagerly': False,
             'task_index': -1,
             'tpu': None,
             'tpu_enable_xla_dynamic_padder': None,
             'worker_hosts': None},
 'task': {'differential_privacy_config': None,
          'evaluation': {'top_k': 5},
          'init_checkpoint': None,
          'init_checkpoint_modules': 'all',
          'losses': {'l2_weight_decay': 0.0001,
                     'label_smoothing': 0.0,
                     'loss_weight': 1.0,
                     'one_hot': True,
                     'soft_labels': False},
          'model': {'add_head_batch_norm': False,
                    'backbone': {'resnet': {'bn_trainable': True,
                                            'depth_multiplier': 1.0,
                                            'model_id': 18,
                                            'replace_stem_max_pool': False,
                                            'resnetd_shortcut': False,
                                            'scale_stem': True,
                                            'se_ratio': 0.0,
                                            'stem_type': 'v0',
                                            'stochastic_depth_drop_rate': 0.0},
                                 'type': 'resnet'},
                    'dropout_rate': 0.0,
                    'input_size': [32, 32, 3],
                    'kernel_initializer': 'random_uniform',
                    'norm_activation': {'activation': 'relu',
                                        'norm_epsilon': 1e-05,
                                        'norm_momentum': 0.9,
                                        'use_sync_bn': False},
                    'num_classes': 10},
          'model_output_keys': [],
          'name': None,
          'train_data': {'aug_policy': None,
                         'aug_rand_hflip': True,
                         'aug_type': None,
                         'block_length': 1,
                         'cache': False,
                         'color_jitter': 0.0,
                         'cycle_length': 10,
                         'decode_jpeg_only': True,
                         'decoder': {'simple_decoder': {'mask_binarize_threshold': None,
                                                        'regenerate_source_id': False},
                                     'type': 'simple_decoder'},
                         'deterministic': None,
                         'drop_remainder': True,
                         'dtype': 'float32',
                         'enable_tf_data_service': False,
                         'file_type': 'tfrecord',
                         'global_batch_size': 128,
                         'image_field_key': 'image/encoded',
                         'input_path': '',
                         'is_multilabel': False,
                         'is_training': True,
                         'label_field_key': 'image/class/label',
                         'mixup_and_cutmix': None,
                         'prefetch_buffer_size': None,
                         'randaug_magnitude': 10,
                         'random_erasing': None,
                         'seed': None,
                         'sharding': True,
                         'shuffle_buffer_size': 10000,
                         'tf_data_service_address': None,
                         'tf_data_service_job_name': None,
                         'tfds_as_supervised': False,
                         'tfds_data_dir': '',
                         'tfds_name': 'cifar10',
                         'tfds_skip_decoding_feature': '',
                         'tfds_split': 'train'},
          'validation_data': {'aug_policy': None,
                              'aug_rand_hflip': True,
                              'aug_type': None,
                              'block_length': 1,
                              'cache': False,
                              'color_jitter': 0.0,
                              'cycle_length': 10,
                              'decode_jpeg_only': True,
                              'decoder': {'simple_decoder': {'mask_binarize_threshold': None,
                                                             'regenerate_source_id': False},
                                          'type': 'simple_decoder'},
                              'deterministic': None,
                              'drop_remainder': True,
                              'dtype': 'float32',
                              'enable_tf_data_service': False,
                              'file_type': 'tfrecord',
                              'global_batch_size': 128,
                              'image_field_key': 'image/encoded',
                              'input_path': '',
                              'is_multilabel': False,
                              'is_training': False,
                              'label_field_key': 'image/class/label',
                              'mixup_and_cutmix': None,
                              'prefetch_buffer_size': None,
                              'randaug_magnitude': 10,
                              'random_erasing': None,
                              'seed': None,
                              'sharding': True,
                              'shuffle_buffer_size': 10000,
                              'tf_data_service_address': None,
                              'tf_data_service_job_name': None,
                              'tfds_as_supervised': False,
                              'tfds_data_dir': '',
                              'tfds_name': 'cifar10',
                              'tfds_skip_decoding_feature': '',
                              'tfds_split': 'test'}},
 'trainer': {'allow_tpu_summary': False,
             'best_checkpoint_eval_metric': '',
             'best_checkpoint_export_subdir': '',
             'best_checkpoint_metric_comp': 'higher',
             'checkpoint_interval': 20,
             'continuous_eval_timeout': 3600,
             'eval_tf_function': True,
             'eval_tf_while_loop': False,
             'loss_upper_bound': 1000000.0,
             'max_to_keep': 5,
             'optimizer_config': {'ema': None,
                                  'learning_rate': {'cosine': {'alpha': 0.0,
                                                               'decay_steps': 20,
                                                               'initial_learning_rate': 0.1,
                                                               'name': 'CosineDecay',
                                                               'offset': 0},
                                                    'type': 'cosine'},
                                  'optimizer': {'sgd': {'clipnorm': None,
                                                        'clipvalue': None,
                                                        'decay': 0.0,
                                                        'global_clipnorm': None,
                                                        'momentum': 0.9,
                                                        'name': 'SGD',
                                                        'nesterov': False},
                                                'type': 'sgd'},
                                  'warmup': {'linear': {'name': 'linear',
                                                        'warmup_learning_rate': 0,
                                                        'warmup_steps': 100},
                                             'type': 'linear'}},
             'recovery_begin_steps': 0,
             'recovery_max_trials': 0,
             'steps_per_loop': 5,
             'summary_interval': 100,
             'train_steps': 20,
             'train_tf_function': True,
             'train_tf_while_loop': True,
             'validation_interval': 1000,
             'validation_steps': 78,
             'validation_summary_subdir': 'validation'}}
<IPython.core.display.Javascript object>

Set up the distribution strategy.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.
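
Whichever branch was taken, num_replicas_in_sync (part of the standard tf.distribute.Strategy API) reports how many replicas the strategy will synchronize gradients across:

# 1 on a single CPU or GPU; more under MirroredStrategy or TPUStrategy.
print('Number of replicas:', distribution_strategy.num_replicas_in_sync)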

Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.

The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

tf.keras.utils.plot_model(task.build_model(), show_shapes=True)

[Figure: diagram of the ResNet-18 model produced by plot_model]

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3)  images.dtype: tf.float32
labels.shape: (128,)            labels.dtype: tf.int32

Visualize the training data

The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.

plt.hist(images.numpy().flatten());

[Figure: histogram of the normalized pixel values]
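
To recover approximate original pixel values rather than a per-batch rescale, invert the normalization with the same constants the loader used. This is a minimal sketch, assuming the MEAN_RGB and STDDEV_RGB constants are exposed by official.vision.ops.preprocess_ops as in current Model Garden releases:

from official.vision.ops import preprocess_ops

# Undo the z-score normalization to get back roughly [0, 255] pixel values.
raw = images * preprocess_ops.STDDEV_RGB + preprocess_ops.MEAN_RGB
plt.imshow(raw[0].numpy().astype('uint8'))
plt.axis('off')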

Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text descriptions of each class ID.

label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
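
The ClassLabel feature also exposes the full list of class names through its names attribute:

# All ten CIFAR-10 class names, in label-index order.
print(label_info.names)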

Visualize a batch of the data.

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale the normalized images into [0, 1] for display.
  min_value = images.numpy().min()
  max_value = images.numpy().max()
  delta = max_value - min_value

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - min_value) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Color the title green for a correct prediction, red for an incorrect one.
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")
plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
<Figure size 720x720 with 0 Axes>

[Figure: a batch of training images labeled with their classes]

Visualize the testing data

Visualize a batch of images from the validation dataset.

plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
<Figure size 720x720 with 0 Axes>

[Figure: a batch of test images labeled with their classes]

Train and evaluate

model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
restoring or initializing model...
initialized model.
train | step:      0 | training until step 20...
train | step:      5 | steps/sec:    0.2 | output: 
    {'accuracy': 0.1125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4953125,
     'training_loss': 2.8116703}
saved checkpoint to /tmp/tmpov3_cdii/ckpt-5.
train | step:     10 | steps/sec:    0.2 | output: 
    {'accuracy': 0.0890625,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.45625,
     'training_loss': 2.902588}
train | step:     15 | steps/sec:    0.2 | output: 
    {'accuracy': 0.090625,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.525,
     'training_loss': 2.8092024}
train | step:     20 | steps/sec:    0.2 | output: 
    {'accuracy': 0.0828125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.496875,
     'training_loss': 2.827601}
 eval | step:     20 | running 78 steps of evaluation...
 eval | step:     20 | eval time:   21.8 sec | output: 
    {'accuracy': 0.09304888,
     'top_5_accuracy': 0.48607773,
     'validation_loss': 2.6170642}
saved checkpoint to /tmp/tmpov3_cdii/ckpt-20.
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/nn_ops.py:5219: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
tf.keras.utils.plot_model(model, show_shapes=True)

[Figure: diagram of the trained model produced by plot_model]

Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.

for key, value in eval_logs.items():
    print(f'{key:20}: {value.numpy():.3f}')
accuracy            : 0.093
top_5_accuracy      : 0.486
validation_loss     : 2.617

Run a batch of the processed training data through the model, and view the results

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
4/4 [==============================] - 1s 122ms/step

[Figure: training images titled with the model's predictions; green titles mark correct predictions, red titles mark incorrect ones]

Export a SavedModel

The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader, using the same mean and variance statistics as preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). This export function handles those details, so you can pass tf.uint8 images and get correct results.

# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
WARNING:absl:Found untraced functions such as inference_for_tflite, inference_from_image_bytes, inference_from_tf_example, _jit_compiled_convolution_op, conv2d_layer_call_fn while saving (showing 5 of 64). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./export/assets
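
The export directory now contains a standard SavedModel layout: a saved_model.pb graph plus variables/ (and, here, assets/). A quick check:

import os

# The SavedModel layout written by the export step above.
print(sorted(os.listdir('./export/')))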

Test the exported model.

# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
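
Before feeding images in, you can inspect what the serving signature expects and returns; structured_input_signature and structured_outputs are standard attributes of a TensorFlow ConcreteFunction:

# Expected inputs (a uint8 image tensor) and outputs (logits) of the signature.
print(model_fn.structured_input_signature)
print(model_fn.structured_outputs)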

Visualize the predictions.

plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
<Figure size 720x720 with 0 Axes>

[Figure: test images titled with the model's predictions]