![]() | ![]() | ![]() | ![]() | ![]() |
В этой записной книжке показано, как использовать модель классификатора болезней маниоки CropNet от TensorFlow Hub . Модель классифицирует изображения листьев маниоки по одному из 6 классов: бактериальный ожог, болезнь коричневых полос, зеленый клещ, мозаичная болезнь, здоровые или неизвестные .
Этот колаб демонстрирует, как:
- Загрузите модель https: //tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 из TensorFlow Hub
- Загрузите набор данных маниоки из наборов данных TensorFlow (TFDS)
- Классифицируйте изображения листьев маниоки по 4 различным категориям болезней маниоки или как здоровые или неизвестные.
- Оцените точность классификатора и посмотрите, насколько надежна модель при применении к изображениям вне домена.
Импорт и настройка
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
Вспомогательная функция для отображения примеров
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
Набор данных
Загрузим набор данных маниоки из TFDS
dataset, info = tfds.load('cassava', with_info=True)
Downloading and preparing dataset cassava/0.1.0 (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /home/kbuilder/tensorflow_datasets/cassava/0.1.0... Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete47F7RH/cassava-train.tfrecord Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete47F7RH/cassava-test.tfrecord Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cassava/0.1.0.incomplete47F7RH/cassava-validation.tfrecord Dataset cassava downloaded and prepared to /home/kbuilder/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.
Давайте посмотрим на информацию о наборе данных, чтобы узнать о ней больше, например, на описание и цитату, а также на информацию о том, сколько примеров доступно.
info
tfds.core.DatasetInfo( name='cassava', version=0.1.0, description='Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images.', homepage='https://www.kaggle.com/c/cassava-disease/overview', features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), total_num_examples=9430, splits={ 'test': 1885, 'train': 5656, 'validation': 1889, }, supervised_keys=('image', 'label'), citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", redistribution_info=, )
В наборе данных маниоки есть изображения листьев маниоки с 4 различными заболеваниями, а также здоровые листья маниоки. Модель может предсказать все эти классы, а также шестой класс для «неизвестного», когда модель не уверена в своем предсказании.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
Прежде чем мы сможем передать данные в модель, нам нужно провести небольшую предварительную обработку. Модель ожидает изображения 224 x 224 со значениями канала RGB в [0, 1]. Давайте нормализуем и изменим размер изображений.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
Давайте посмотрим на несколько примеров из набора данных
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.
Модель
Давайте загрузим классификатор из TF-Hub, получим некоторые прогнозы и посмотрим, что прогнозы модели есть на нескольких примерах.
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/ipykernel_launcher.py:17: MatplotlibDeprecationWarning: Passing non-integers as three-element position specification is deprecated since 3.3 and will be removed two minor releases later.
Оценка и надежность
Давайте измерим точность нашего классификатора на части набора данных. Мы также можем посмотреть на надежность модели, оценив ее производительность на наборе данных, отличном от маниока. Для изображений других наборов данных растений, таких как iNaturalist или beans, модель почти всегда должна возвращать значение unknown .
Параметры
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
Узнать больше
- Узнайте больше о модели на TensorFlow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- Узнайте, как создать собственный классификатор изображений, работающий на мобильном телефоне с помощью ML Kit с версией этой модели TensorFlow Lite .