CropNet: Cassava Disease Detection

This notebook shows how to use the CropNet cassava disease classifier model from TensorFlow Hub. The model classifies images of cassava leaves into one of 6 classes: bacterial blight, brown streak disease, green mite, mosaic disease, healthy, or unknown.

This Colab demonstrates how to:

  • Load the https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 model from TensorFlow Hub
  • Load the cassava dataset from TensorFlow Datasets (TFDS)
  • Classify images of cassava leaves into 4 distinct cassava disease categories, healthy, or unknown
  • Evaluate the accuracy of the classifier and see how robust the model is when applied to out-of-domain images

Imports and setup

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

def plot(examples, predictions=None):
  # Get the images, labels, and optionally predictions
  images = examples['image']
  labels = examples['label']
  batch_size = len(images)
  if predictions is None:
    predictions = batch_size * [None]

  # Configure the layout of the grid
  # Cast to int: add_subplot expects integer grid dimensions
  x = int(np.ceil(np.sqrt(batch_size)))
  y = int(np.ceil(batch_size / x))
  fig = plt.figure(figsize=(x * 6, y * 7))

  for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
    # Render the image
    ax = fig.add_subplot(x, y, i+1)
    ax.imshow(image, aspect='auto')
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Display the label and optionally prediction
    x_label = 'Label: ' + name_map[class_names[label]]
    if prediction is not None:
      x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax.xaxis.label.set_color('green' if label == prediction else 'red')
    ax.set_xlabel(x_label)

  plt.show()
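
The grid arithmetic inside plot can be checked on its own. The following is a minimal sketch, assuming a hypothetical helper named grid_shape (it is not part of the notebook) that mirrors the ceil/sqrt computation:

```python
import math

def grid_shape(batch_size):
    """Hypothetical helper: mirrors the layout arithmetic used in plot()."""
    x = math.ceil(math.sqrt(batch_size))
    y = math.ceil(batch_size / x)
    return x, y

for n in (1, 6, 25, 32):
    x, y = grid_shape(n)
    # The grid always has at least as many cells as there are images.
    assert x * y >= n
    print(n, '->', (x, y))
```

A batch of 25 images, as used later in this notebook, fills a 5 x 5 grid exactly; other batch sizes may leave some cells empty.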

Dataset

Let's load the cassava dataset from TFDS

dataset, info = tfds.load('cassava', with_info=True)
Downloading and preparing dataset cassava/0.1.0 (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /home/kbuilder/tensorflow_datasets/cassava/0.1.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-test.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-validation.tfrecord
Dataset cassava downloaded and prepared to /home/kbuilder/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.

Let's take a look at the dataset info to learn more about it, such as the description and citation, and how many examples are available.

info
tfds.core.DatasetInfo(
    name='cassava',
    version=0.1.0,
    description='Cassava consists of leaf images for the cassava plant depicting healthy and
four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
Dataset consists of a total of 9430 labelled images.
The 9430 labelled images are split into a training set (5656), a test set(1885)
and a validation set (1889). The number of images per class are unbalanced with
the two disease classes CMD and CBSD having 72% of the images.',
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    total_num_examples=9430,
    splits={
        'test': 1885,
        'train': 5656,
        'validation': 1889,
    },
    supervised_keys=('image', 'label'),
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
    redistribution_info=,
)

The cassava dataset has images of cassava leaves with 4 distinct diseases as well as healthy cassava leaves. The model can predict all of these classes, plus a sixth class for "unknown" when the model is not confident in its prediction.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
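
To make the index-to-name lookup concrete, here is a small sketch that uses the class order from the printed output above. The helper readable_name is hypothetical and only illustrates the two-step lookup (index to short name, short name to display name):

```python
# Class order copied from the printed output above; 'unknown' is appended last.
class_names = ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']

name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

def readable_name(index):
    """Hypothetical helper: integer class index -> display name."""
    return name_map[class_names[index]]

# Because 'unknown' is appended after the 5 dataset classes, it gets index 5.
unknown_index = class_names.index('unknown')
print(unknown_index, readable_name(unknown_index))
print(readable_name(3))
```

Since the dataset itself only defines 5 labels, index 5 never appears as a ground-truth label in the cassava data; it only shows up as a model prediction.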

Before we can feed the data to the model, we need to do a bit of preprocessing. The model expects 224 x 224 images with RGB channel values in [0, 1]. Let's normalize and resize the images.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data
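
The normalization step can be sanity-checked without TensorFlow. This is a minimal NumPy sketch on a synthetic image (random pixels, not real dataset content); it covers only the scaling to [0, 1] and leaves resizing to tf.image.resize:

```python
import numpy as np

# Synthetic stand-in for a decoded uint8 RGB image (not real dataset content).
rng = np.random.default_rng(0)
image_uint8 = rng.integers(0, 256, size=(300, 400, 3), dtype=np.uint8)

# Same normalization as preprocess_fn: cast to float32, then divide by 255.
image = image_uint8.astype(np.float32) / 255.0

# All channel values should now lie in the closed interval [0, 1].
print(image.dtype, image.min() >= 0.0, image.max() <= 1.0)
```

Note that tf.image.resize with a fixed (224, 224) target does not preserve aspect ratio by default, so non-square images are stretched to fit.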

Let's take a look at a few examples from the dataset

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)


Model

Let's load the classifier from TF-Hub, get some predictions, and look at what the model predicts on a few examples

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
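
The argmax step can be illustrated with NumPy alone. The sketch below assumes the classifier returns one row of six scores per image, which is consistent with the argmax over the last axis above; the numbers are made up for illustration and are not model output:

```python
import numpy as np

class_names = ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']

# Made-up scores for two images, for illustration only (not model output).
probabilities = np.array([
    [0.05, 0.10, 0.05, 0.70, 0.05, 0.05],
    [0.10, 0.10, 0.10, 0.10, 0.10, 0.50],
])

# Index of the highest score in each row, and the score itself.
predictions = np.argmax(probabilities, axis=-1)
confidences = np.max(probabilities, axis=-1)

for p, c in zip(predictions, confidences):
    print(class_names[p], round(float(c), 2))
```

Keeping the top score next to the predicted index can help when inspecting borderline examples, since argmax alone discards how close the runner-up was.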


Evaluation & robustness

Let's measure the accuracy of our classifier on a split of the dataset. We can also look at the robustness of the model by evaluating its performance on a non-cassava dataset. For images from other plant datasets such as iNaturalist or beans, the model should always return unknown.



DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
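
The accuracy metric here is the fraction of predictions that exactly match their labels, accumulated over batches. A minimal NumPy sketch of that computation follows; the accuracy function is a hypothetical stand-in, and all labels and predictions are made up for illustration:

```python
import numpy as np

UNKNOWN = 5  # index of 'unknown' in class_names

def accuracy(batches):
    """Fraction of exact matches, accumulated over (labels, predictions) batches."""
    correct = 0
    total = 0
    for labels, predictions in batches:
        labels = np.asarray(labels)
        predictions = np.asarray(predictions)
        correct += int(np.sum(labels == predictions))
        total += labels.size
    return correct / total

# Made-up in-domain data: two batches, one wrong prediction out of six.
in_domain = [([3, 3, 1, 4], [3, 3, 1, 0]), ([0, 2], [0, 2])]
print(accuracy(in_domain))

# For a non-cassava dataset every label is overridden to UNKNOWN, so the
# accuracy becomes the fraction of images the model flags as unknown.
out_predictions = [5, 5, 3, 5]
out_of_domain = [([UNKNOWN] * len(out_predictions), out_predictions)]
print(accuracy(out_of_domain))
```

Read this way, the out-of-domain score measures how often the model declines to assign a cassava class to an image that is not a cassava leaf.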

Learn more

  • Learn more about the model on TensorFlow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
  • Learn how to build a custom image classifier that runs on a mobile phone with ML Kit, using the TensorFlow Lite version of this model.