CropNet: Cassava Disease Detection

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드 TF Hub 모델보기

이 노트북은 TensorFlow Hub의 CropNet 카사바 질병 분류자 모델을 사용하는 방법을 보여줍니다. 이 모델은 카사바 잎의 이미지를 세균성 마름병, 갈색 줄무늬병, 녹색 진드기, 모자이크병, 건강함 또는 알 수 없음의 6가지 등급 중 하나로 분류합니다.

이 Colab에서는 다음 방법을 보여줍니다.

  • TensorFlow Hub에서 https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 모델을 로드합니다.
  • TFDS(TensorFlow 데이터세트)에서 cassava 데이터세트를 로드합니다.
  • 카사바 잎의 이미지를 4개의 특징적인 카사바 질병 범주 또는 건강하거나 알려지지 않은 상태로 분류합니다.
  • 분류자의 정확성을 평가하고 도메인 외부 이미지에 적용했을 때 모델이 얼마나 강력한지 확인합니다.

가져오기 및 설정

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
2022-12-14 22:20:32.937357: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:20:32.937456: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:20:32.937465: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Helper function for displaying examples

데이터세트

TFDS에서 cassava 데이터세트를 불러오겠습니다.

dataset, info = tfds.load('cassava', with_info=True)

설명과 인용, 사용 가능한 예제 수에 대한 정보 등 데이터세트에 대해 자세히 알아보겠습니다.

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    file_format=tfrecord,
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

cassava 데이터세트에는 4가지 질병이 있는 카사바 잎 이미지와 건강한 카사바 잎이 있습니다. 모델은 이러한 모든 등급을 예측할 수 있고, 예측에 확신이 없는 경우에는 여섯 번째 등급인 "알 수 없음"으로 분류합니다.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

모델에 데이터를 공급하기 전에 약간의 전처리가 필요합니다. 모델은 RGB 채널 값이 [0, 1] 범위인 224 x 224 이미지를 예상합니다. 이미지를 정규화하고 크기를 조정하겠습니다.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

데이터세트의 몇 가지 예를 살펴보겠습니다.

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

모델

TF-Hub에서 분류자를 로드하고 일부 예측값을 가져와서 몇 가지 예에서 모델의 예측값이 어떤지 확인하겠습니다.

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
plot(examples, predictions)

png

평가 및 견고성

일부 데이터세트에서 분류자의 정확성을 측정해 보겠습니다. 또한 카사바가 아닌 데이터세트에서 성능을 평가하여 모델의 견고성을 확인할 수 있습니다. iNaturalist 또는 콩과 같은 다른 식물 데이터세트의 이미지의 경우, 모델은 거의 항상 unknown을 반환해야 합니다.

Parameters

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

자세히 알아보기