今日のローカルTensorFlowEverywhereイベントの出欠確認!

CropNet: Cassava Disease Detection

TensorFlow.orgで表示{ Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TFハブモデルを参照してください

このノートブックでは、TensorFlow Hub の CropNet キャッサバの病気の分類モデルの使用方法を説明します。このモデルはキャッサバの葉の画像を 6 つのクラスのいずれかに分類します。クラスは斑点細菌病、褐色条斑病、緑ダニ、モザイク病、健康、不明です。

この Colab では、以下の方法を実演します。

  • TensorFlow Hub からモデル https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 を読み込む。
  • TensorFlow Datasets (TFDS) からキャッサバデータセットを読み込む。
  • キャッサバの葉の画像を、4 つの異なるキャッサバの病気のカテゴリ、あるいは健康または不明として分類する。
  • 分類器の精度を評価し、ドメイン外の画像を適用した際のモデルのロバスト性を検査する。

インポートとセットアップ

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Helper function for displaying examples

データセット

TFDS からキャッサバデータセットを読み込みます。

dataset, info = tfds.load('cassava', with_info=True)

データセットの情報を見て、説明や引用、例の数などの詳細情報を確認しましょう。

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

キャッサバデータセットには、健康なキャッサバの葉とそれぞれ異なる病気を持つ 4 枚のキャッサバの葉の画像があります。モデルはこれらの全てのクラスの予測ができ、予測に自信がない場合、モデルは 6 番目の "unknown(不明)" のクラスを予測します。

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

データをモデルに送る前に、少し前処理をする必要があります。モデルは RGB チャンネル値が [0, 1] の 224 x 224 の画像を想定しています。画像を正規化してサイズを変更しましょう。

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

データセットからいくつかの例を見てみましょう。

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
/home/kbuilder/.local/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

モデル

TF-Hub から分類器を読み込んで予測値をいくつか取得し、複数の例のモデルの予測値を見てみましょう。

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
/home/kbuilder/.local/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

評価とロバスト性

データセットを分割した場合の分類器の精度を測定してみましょう。また、キャッサバ以外のデータセットで性能を評価して、モデルのロバスト性を調べることもできます。iNaturalist や豆など、他の植物のデータセット画像を使用すると、モデルは殆ど全ての画像に対して不明を返すはずです。

Parameters

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

詳細情報

  • TensorFlow Hub のモデルに関する詳細情報: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
  • このモデルの TensorFlow Lite 版を使い、ML Kit を使用して携帯電話で動作するカスタム画像分類器の構築方法を学びましょう。