CropNet: Phát hiện bệnh hại sắn

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép Xem mô hình TF Hub

Máy tính xách tay này cho thấy làm thế nào để sử dụng CropNet sắn phân loại bệnh mô hình từ TensorFlow Hub. Các phân loại mô hình hình ảnh của lá sắn vào một trong 6 lớp: bạc lá vi khuẩn, bệnh sọc nâu, mite xanh, bệnh khảm, khỏe mạnh, hoặc chưa biết.

Chuyên mục này trình bày cách:

  • Nạp https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 mô hình từ TensorFlow Hub
  • Nạp sắn bộ dữ liệu từ TensorFlow Datasets (TFDS)
  • Phân loại hình ảnh của lá sắn thành 4 loại bệnh hại sắn riêng biệt hoặc là khỏe mạnh hoặc không rõ.
  • Đánh giá tính chính xác của phân loại và xem xét cách mạnh mẽ mô hình này là khi áp dụng cho ra những hình ảnh miền.

Nhập và thiết lập

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Chức năng trợ giúp để hiển thị các ví dụ

Dataset

Hãy tải các tập dữ liệu sắn từ TFDS

dataset, info = tfds.load('cassava', with_info=True)

Hãy xem thông tin tập dữ liệu để tìm hiểu thêm về nó, chẳng hạn như mô tả và trích dẫn và thông tin về số lượng ví dụ có sẵn

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

Bộ dữ liệu sắn có hình ảnh của lá sắn với 4 bệnh riêng biệt cũng như lá sắn khỏe mạnh. Mô hình có thể dự đoán tất cả các lớp này cũng như lớp thứ sáu cho "ẩn số" khi mô hình không tự tin vào dự đoán của nó.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Trước khi có thể cung cấp dữ liệu vào mô hình, chúng ta cần thực hiện một chút tiền xử lý. Mô hình mong đợi hình ảnh 224 x 224 với giá trị kênh RGB trong [0, 1]. Hãy chuẩn hóa và thay đổi kích thước hình ảnh.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Hãy xem một vài ví dụ từ tập dữ liệu

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Mô hình

Hãy tải bộ phân loại từ TF Hub và nhận một số dự đoán và xem các dự đoán của mô hình là trên một số ví dụ

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Đánh giá & độ mạnh

Hãy đo chính xác của phân loại của chúng tôi về một sự chia rẽ của tập dữ liệu. Chúng tôi cũng có thể nhìn vào sự vững mạnh của mô hình bằng cách đánh giá hiệu quả của nó trên một tập dữ liệu phi sắn. Đối với hình ảnh của bộ dữ liệu thực vật khác như iNaturalist hoặc đậu, các mô hình nên hầu như luôn luôn trở lại chưa biết.

Thông số

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Tìm hiểu thêm