Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

CropNet: rilevamento della malattia della manioca

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica notebook

Questo taccuino mostra come utilizzare il modello di classificazione della malattia della manioca CropNet da TensorFlow Hub . Il modello classifica le immagini delle foglie di manioca in una delle 6 classi: batterio batterico, malattia della striscia marrone, acaro verde, malattia del mosaico, sano o sconosciuto .

Questa colab mostra come:

  • Carica il modello https: //tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 da TensorFlow Hub
  • Carica il set di dati della manioca da TensorFlow Datasets (TFDS)
  • Classificare le immagini delle foglie di manioca in 4 categorie distinte di malattie della manioca o come sane o sconosciute.
  • Valuta l' accuratezza del classificatore e osserva quanto è robusto il modello quando applicato a immagini esterne al dominio.

Importazioni e configurazione

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

def plot(examples, predictions=None):
  # Get the images, labels, and optionally predictions
  images = examples['image']
  labels = examples['label']
  batch_size = len(images)
  if predictions is None:
    predictions = batch_size * [None]

  # Configure the layout of the grid
  x = np.ceil(np.sqrt(batch_size))
  y = np.ceil(batch_size / x)
  fig = plt.figure(figsize=(x * 6, y * 7))

  for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
    # Render the image
    ax = fig.add_subplot(x, y, i+1)
    ax.imshow(image, aspect='auto')
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Display the label and optionally prediction
    x_label = 'Label: ' + name_map[class_names[label]]
    if prediction is not None:
      x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax.xaxis.label.set_color('green' if label == prediction else 'red')
    ax.set_xlabel(x_label)

  plt.show()

dataset

Carichiamo il set di dati della manioca da TFDS

dataset, info = tfds.load('cassava', with_info=True)
Downloading and preparing dataset cassava/0.1.0 (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /home/kbuilder/tensorflow_datasets/cassava/0.1.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-test.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete2AVNJC/cassava-validation.tfrecord
Dataset cassava downloaded and prepared to /home/kbuilder/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.

Diamo un'occhiata alle informazioni sul set di dati per saperne di più, come la descrizione e la citazione e le informazioni su quanti esempi sono disponibili

info
tfds.core.DatasetInfo(
    name='cassava',
    version=0.1.0,
    description='Cassava consists of leaf images for the cassava plant depicting healthy and
four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
Dataset consists of a total of 9430 labelled images.
The 9430 labelled images are split into a training set (5656), a test set(1885)
and a validation set (1889). The number of images per class are unbalanced with
the two disease classes CMD and CBSD having 72% of the images.',
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    total_num_examples=9430,
    splits={
        'test': 1885,
        'train': 5656,
        'validation': 1889,
    },
    supervised_keys=('image', 'label'),
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
    redistribution_info=,
)

Il set di dati della manioca contiene immagini di foglie di manioca con 4 malattie distinte e foglie di manioca sane. Il modello può prevedere tutte queste classi così come la sesta classe per "sconosciuto" quando il modello non è sicuro della sua previsione.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Prima di poter fornire i dati al modello, è necessario eseguire un po 'di pre-elaborazione. Il modello prevede immagini 224 x 224 con valori di canale RGB in [0, 1]. Normalizziamo e ridimensioniamo le immagini.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Diamo un'occhiata ad alcuni esempi dal set di dati

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Modello

Carichiamo il classificatore da TF-Hub e otteniamo alcune previsioni e vediamo le previsioni del modello su alcuni esempi

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Valutazione e robustezza

Misuriamo l' accuratezza del nostro classificatore su una divisione del set di dati. Possiamo anche esaminare la robustezza del modello valutandone le prestazioni su un set di dati non di manioca. Per l'immagine di altri set di dati di piante come iNaturalist o bean, il modello dovrebbe quasi sempre restituire sconosciuto .



DATASET = 'cassava'  
DATASET_SPLIT = 'test' 
BATCH_SIZE =  32 
MAX_EXAMPLES = 1000 

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Per saperne di più

  • Ulteriori informazioni sul modello su TensorFlow Hub: https: //tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
  • Scopri come creare un classificatore di immagini personalizzato in esecuzione su un telefono cellulare con ML Kit con la versione TensorFlow Lite di questo modello .