CropNet: Cassava Disease Detection

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드 TF Hub 모델보기

이 노트북은 TensorFlow Hub의 CropNet 카사바 질병 분류자 모델을 사용하는 방법을 보여줍니다. 이 모델은 카사바 잎의 이미지를 세균성 마름병, 갈색 줄무늬병, 녹색 진드기, 모자이크병, 건강함 또는 알 수 없음의 6가지 등급 중 하나로 분류합니다.

이 Colab에서는 다음 방법을 보여줍니다.

  • TensorFlow Hub에서 https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 모델을 로드합니다.
  • TFDS(TensorFlow 데이터세트)에서 cassava 데이터세트를 로드합니다.
  • 카사바 잎의 이미지를 4개의 특징적인 카사바 질병 범주 또는 건강하거나 알려지지 않은 상태로 분류합니다.
  • 분류자의 정확성을 평가하고 도메인 외부 이미지에 적용했을 때 모델이 얼마나 강력한지 확인합니다.

가져오기 및 설정

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Helper function for displaying examples

데이터세트

TFDS에서 cassava 데이터세트를 불러오겠습니다.

dataset, info = tfds.load('cassava', with_info=True)

설명과 인용, 사용 가능한 예제 수에 대한 정보 등 데이터세트에 대해 자세히 알아보겠습니다.

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

cassava 데이터세트에는 4가지 질병이 있는 카사바 잎 이미지와 건강한 카사바 잎이 있습니다. 모델은 이러한 모든 등급을 예측할 수 있고, 예측에 확신이 없는 경우에는 여섯 번째 등급인 "알 수 없음"으로 분류합니다.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

모델에 데이터를 공급하기 전에 약간의 전처리가 필요합니다. 모델은 RGB 채널 값이 [0, 1] 범위인 224 x 224 이미지를 예상합니다. 이미지를 정규화하고 크기를 조정하겠습니다.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

데이터세트의 몇 가지 예를 살펴보겠습니다.

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
/home/kbuilder/.local/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

모델

TF-Hub에서 분류자를 로드하고 일부 예측값을 가져와서 몇 가지 예에서 모델의 예측값이 어떤지 확인하겠습니다.

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
/home/kbuilder/.local/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

평가 및 견고성

일부 데이터세트에서 분류자의 정확성을 측정해 보겠습니다. 또한 카사바가 아닌 데이터세트에서 성능을 평가하여 모델의 견고성을 확인할 수 있습니다. iNaturalist 또는 콩과 같은 다른 식물 데이터세트의 이미지의 경우, 모델은 거의 항상 unknown을 반환해야 합니다.

Parameters

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

자세히 알아보기

  • TensorFlow Hub의 모델에 대해 자세히 알아보기: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
  • 이 모델의 TensorFlow Lite 버전을 사용하여 ML 키트를 통해 휴대전화에서 실행되는 사용자 정의 이미지 분류자를 빌드하는 방법을 알아보세요.