Cette page a été traduite par l'API Cloud Translation.
Switch to English

CropNet: détection de la maladie du manioc

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher sur GitHub Télécharger le carnet Voir le modèle TF Hub

Ce cahier montre comment utiliser le modèle de classification de la maladie du manioc CropNet de TensorFlow Hub . Le modèle classe les images de feuilles de manioc dans l'une des 6 classes suivantes: brûlure bactérienne, maladie des stries brunes, acarien vert, maladie de la mosaïque, saine ou inconnue .

Ce colab montre comment:

  • Chargez le modèle https: //tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 à partir de TensorFlow Hub
  • Chargez l'ensemble de données sur le manioc à partir des ensembles de données TensorFlow (TFDS)
  • Classez les images de feuilles de manioc en 4 catégories distinctes de maladies du manioc ou comme saines ou inconnues.
  • Évaluez la précision du classificateur et regardez la robustesse du modèle lorsqu'il est appliqué à des images hors domaine.

Importations et configuration

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

def plot(examples, predictions=None):
  # Get the images, labels, and optionally predictions
  images = examples['image']
  labels = examples['label']
  batch_size = len(images)
  if predictions is None:
    predictions = batch_size * [None]

  # Configure the layout of the grid
  x = np.ceil(np.sqrt(batch_size))
  y = np.ceil(batch_size / x)
  fig = plt.figure(figsize=(x * 6, y * 7))

  for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
    # Render the image
    ax = fig.add_subplot(x, y, i+1)
    ax.imshow(image, aspect='auto')
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Display the label and optionally prediction
    x_label = 'Label: ' + name_map[class_names[label]]
    if prediction is not None:
      x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax.xaxis.label.set_color('green' if label == prediction else 'red')
    ax.set_xlabel(x_label)

  plt.show()

Base de données

Chargons l'ensemble de données sur le manioc de TFDS

dataset, info = tfds.load('cassava', with_info=True)
Downloading and preparing dataset cassava/0.1.0 (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /home/kbuilder/tensorflow_datasets/cassava/0.1.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete5MDKIM/cassava-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete5MDKIM/cassava-test.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete5MDKIM/cassava-validation.tfrecord
Dataset cassava downloaded and prepared to /home/kbuilder/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.

Jetons un coup d'œil aux informations sur l'ensemble de données pour en savoir plus, comme la description et la citation et des informations sur le nombre d'exemples disponibles

info
tfds.core.DatasetInfo(
    name='cassava',
    version=0.1.0,
    description='Cassava consists of leaf images for the cassava plant depicting healthy and
four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
Dataset consists of a total of 9430 labelled images.
The 9430 labelled images are split into a training set (5656), a test set(1885)
and a validation set (1889). The number of images per class are unbalanced with
the two disease classes CMD and CBSD having 72% of the images.',
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    total_num_examples=9430,
    splits={
        'test': 1885,
        'train': 5656,
        'validation': 1889,
    },
    supervised_keys=('image', 'label'),
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
    redistribution_info=,
)

L'ensemble de données sur le manioc contient des images de feuilles de manioc présentant 4 maladies distinctes ainsi que des feuilles de manioc saines. Le modèle peut prédire toutes ces classes ainsi que la sixième classe pour «inconnu» lorsque le modèle n'est pas sûr de sa prédiction.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Avant de pouvoir alimenter le modèle en données, nous devons effectuer un peu de prétraitement. Le modèle attend des images 224 x 224 avec des valeurs de canal RVB en [0, 1]. Normalisons et redimensionnons les images.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Jetons un coup d'œil à quelques exemples de l'ensemble de données

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

Modèle

Chargeons le classificateur de TF-Hub et obtenons quelques prédictions et voyons les prédictions du modèle sur quelques exemples

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.

png

Évaluation et robustesse

Mesurons la précision de notre classificateur sur une division de l'ensemble de données. Nous pouvons également examiner la robustesse du modèle en évaluant ses performances sur un ensemble de données non manioc. Pour l'image d'autres ensembles de données de plantes comme iNaturalist ou beans, le modèle doit presque toujours renvoyer unknown .



DATASET = 'cassava'  
DATASET_SPLIT = 'test' 
BATCH_SIZE =  32 
MAX_EXAMPLES = 1000 

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Apprendre encore plus

  • En savoir plus sur le modèle sur TensorFlow Hub: https: //tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
  • Apprenez à créer un classificateur d'images personnalisé s'exécutant sur un téléphone mobile avec ML Kit avec la version TensorFlow Lite de ce modèle .