In this colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a custom audio classification model.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model with a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.

It is part of the Codelab to Customize an Audio model and deploy on Android.

You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.JS model that can be used for inference in the browser, and also a SavedModel version that you can use for serving.
Install dependencies
pip install tflite-model-maker
Import TensorFlow, Model Maker and other libraries

Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside from those, the others are for audio manipulation, playing, and visualization.
import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import glob
import random
from IPython.display import Audio, Image
from scipy.io import wavfile
print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
TensorFlow Version: 2.6.1
Model Maker Version: 0.3.2
The Birds dataset

The Birds dataset is an educational collection of 5 types of bird songs:
- White-breasted Wood-Wren
- House Sparrow
- Red Crossbill
- Chestnut-crowned Antpitta
- Azara's Spinetail
The original audio came from Xeno-canto, a website dedicated to sharing bird sounds from all over the world.

Let's start by downloading the data.
birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                               'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                               cache_dir='./',
                                               cache_subdir='dataset',
                                               extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip
343687168/343680986 [==============================] - 2s 0us/step
343695360/343680986 [==============================] - 2s 0us/step
Explore the data

The audios are already split into train and test folders. Inside each split folder, there's one folder for each bird, using its bird_code as the name.

The audios are all mono and with a 16kHz sample rate.
For more information about each file, you can read the metadata.csv file. It contains all the file authors, licenses and some more information. You won't need to read it yourself in this tutorial.
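If you'd like a quick look at the data before training, the sketch below (not part of the original notebook) lists the split folders and checks the format of one file; the path matches the download cell above:

# A minimal sketch: list the split folders and verify one file is 16 kHz mono.
dataset_root = './dataset/small_birds_dataset'
for split in ['train', 'test']:
  print(split, '->', sorted(os.listdir(os.path.join(dataset_root, split))))

sample_path = glob.glob(os.path.join(dataset_root, 'train/*/*.wav'))[0]
rate, data = wavfile.read(sample_path)
print(f'{sample_path}: {rate} Hz, shape={data.shape}')  # expect 16000 Hz, mono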
# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg',  # Alejandro Bayer Tamayo from Armenia, Colombia
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg',  # Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg',  # Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg',  # Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg',  # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path

def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path, 'rb')
  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created
Playing some audio

To have a better understanding of the data, let's listen to random audio files from the test split.
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: Azara's Spinetail
Bird code: azaspi1
Training the model

When using Model Maker for audio, you have to start with a model spec. This is the base model that your new model will extract information from to learn about the new classes. It also affects how the dataset will be transformed to respect the model spec's parameters, like: sample rate, number of channels.
YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be at 16kHz and with 1 channel.

You don't need to do any resampling yourself. Model Maker takes care of that for you.
frame_length decides how long each training sample is. In this case, 6 * EXPECTED_WAVEFORM_LENGTH.

frame_step decides how far apart the training samples are. In this case, the i-th sample will start 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)-th sample.
The reason to set these values is to work around some limitations of real-world datasets.

For example, in the birds dataset, birds don't sing all the time. They sing, rest and sing again, with noises in between. Having a long frame would help capture the singing, but setting it too long will reduce the number of samples for training.
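To make these numbers concrete, here is a quick sketch (not part of the original notebook) that assumes YAMNet's 16 kHz input and its 15600-sample expected waveform length:

expected_len = audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH  # 15600 samples
sample_rate = 16000  # YAMNet's expected input sample rate
print(f'frame_length: {6 * expected_len} samples = {6 * expected_len / sample_rate:.2f}s')  # ~5.85s
print(f'frame_step:   {3 * expected_len} samples = {3 * expected_len / sample_rate:.2f}s')  # ~2.93s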
spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmp/tmp7180wsrw
Loading the data

Model Maker has the API to load the data from a folder and have it in the expected format for the model spec.

The train and test splits are based on the folders. The validation dataset will be created as 20% of the train split.
train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)
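As a quick sanity check (not part of the original notebook), you can print the size of each split; the Model Maker DataLoader exposes its size through len():

# A minimal sketch: verify the 80/20 train/validation split and the test size.
print(f'train: {len(train_data)}, validation: {len(validation_data)}, test: {len(test_data)}')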
Training the model

The audio_classifier has the create method that creates a model and already starts training it.

You can customize many parameters; for more information you can read more details in the documentation.

On this first try you'll use all the default configurations and train for 100 epochs.
batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= classification_head (Dense) (None, 5) 5125 ================================================================= Total params: 5,125 Trainable params: 5,125 Non-trainable params: 0 _________________________________________________________________ Epoch 1/100 23/23 [==============================] - 49s 2s/step - loss: 1.6125 - acc: 0.2433 - val_loss: 1.2951 - val_acc: 0.4908 Epoch 2/100 23/23 [==============================] - 1s 22ms/step - loss: 1.3413 - acc: 0.4557 - val_loss: 1.1354 - val_acc: 0.7138 Epoch 3/100 23/23 [==============================] - 1s 22ms/step - loss: 1.1689 - acc: 0.6013 - val_loss: 1.0066 - val_acc: 0.7571 Epoch 4/100 23/23 [==============================] - 1s 21ms/step - loss: 1.0543 - acc: 0.6552 - val_loss: 0.9160 - val_acc: 0.7837 Epoch 5/100 23/23 [==============================] - 0s 21ms/step - loss: 0.9651 - acc: 0.7052 - val_loss: 0.8558 - val_acc: 0.8020 Epoch 6/100 23/23 [==============================] - 1s 21ms/step - loss: 0.8970 - acc: 0.7174 - val_loss: 0.8080 - val_acc: 0.8070 Epoch 7/100 23/23 [==============================] - 0s 21ms/step - loss: 0.8532 - acc: 0.7261 - val_loss: 0.7701 - val_acc: 0.8136 Epoch 8/100 23/23 [==============================] - 0s 21ms/step - loss: 0.8034 - acc: 0.7501 - val_loss: 0.7439 - val_acc: 0.8186 Epoch 9/100 23/23 [==============================] - 0s 21ms/step - loss: 0.7700 - acc: 0.7595 - val_loss: 0.7234 - val_acc: 0.8170 Epoch 10/100 23/23 [==============================] - 0s 21ms/step - loss: 0.7318 - acc: 0.7769 - val_loss: 0.7011 - val_acc: 0.8220 Epoch 11/100 23/23 [==============================] - 0s 21ms/step - loss: 0.7104 - acc: 0.7744 - val_loss: 0.6860 - val_acc: 0.8170 Epoch 12/100 23/23 [==============================] - 1s 21ms/step - loss: 0.6866 - acc: 0.7855 - val_loss: 0.6704 - val_acc: 0.8186 Epoch 13/100 23/23 [==============================] - 1s 21ms/step - loss: 0.6556 - acc: 0.8008 - val_loss: 0.6608 - val_acc: 0.8170 Epoch 14/100 23/23 [==============================] - 0s 21ms/step - loss: 0.6414 - acc: 0.8008 - val_loss: 0.6503 - val_acc: 0.8220 Epoch 15/100 23/23 [==============================] - 0s 20ms/step - loss: 0.6263 - acc: 0.8040 - val_loss: 0.6414 - val_acc: 0.8186 Epoch 16/100 23/23 [==============================] - 1s 21ms/step - loss: 0.6033 - acc: 0.8154 - val_loss: 0.6329 - val_acc: 0.8186 Epoch 17/100 23/23 [==============================] - 1s 21ms/step - loss: 0.5963 - acc: 0.8123 - val_loss: 0.6289 - val_acc: 0.8186 Epoch 18/100 23/23 [==============================] - 0s 21ms/step - loss: 0.5828 - acc: 0.8172 - val_loss: 0.6238 - val_acc: 0.8220 Epoch 19/100 23/23 [==============================] - 0s 21ms/step - loss: 0.5665 - acc: 0.8273 - val_loss: 0.6200 - val_acc: 0.8220 Epoch 20/100 23/23 [==============================] - 0s 20ms/step - loss: 0.5523 - acc: 0.8297 - val_loss: 0.6109 - val_acc: 0.8186 Epoch 21/100 23/23 [==============================] - 0s 20ms/step - loss: 0.5522 - acc: 0.8200 - val_loss: 0.6076 - val_acc: 0.8253 Epoch 22/100 23/23 [==============================] - 1s 21ms/step - loss: 0.5363 - acc: 0.8352 - val_loss: 0.6013 - val_acc: 0.8186 Epoch 23/100 23/23 [==============================] - 0s 20ms/step - loss: 0.5273 - acc: 0.8412 - val_loss: 0.5968 - val_acc: 0.8136 Epoch 24/100 23/23 [==============================] - 1s 
21ms/step - loss: 0.5172 - acc: 0.8339 - val_loss: 0.5954 - val_acc: 0.8153 Epoch 25/100 23/23 [==============================] - 1s 22ms/step - loss: 0.5123 - acc: 0.8429 - val_loss: 0.5902 - val_acc: 0.8153 Epoch 26/100 23/23 [==============================] - 1s 21ms/step - loss: 0.5066 - acc: 0.8415 - val_loss: 0.5906 - val_acc: 0.8153 Epoch 27/100 23/23 [==============================] - 1s 21ms/step - loss: 0.5015 - acc: 0.8373 - val_loss: 0.5833 - val_acc: 0.8136 Epoch 28/100 23/23 [==============================] - 0s 21ms/step - loss: 0.4879 - acc: 0.8432 - val_loss: 0.5832 - val_acc: 0.8103 Epoch 29/100 23/23 [==============================] - 0s 20ms/step - loss: 0.4840 - acc: 0.8537 - val_loss: 0.5767 - val_acc: 0.8186 Epoch 30/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4793 - acc: 0.8530 - val_loss: 0.5753 - val_acc: 0.8103 Epoch 31/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4718 - acc: 0.8554 - val_loss: 0.5758 - val_acc: 0.8103 Epoch 32/100 23/23 [==============================] - 0s 20ms/step - loss: 0.4649 - acc: 0.8554 - val_loss: 0.5706 - val_acc: 0.8103 Epoch 33/100 23/23 [==============================] - 0s 20ms/step - loss: 0.4565 - acc: 0.8554 - val_loss: 0.5689 - val_acc: 0.8120 Epoch 34/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4492 - acc: 0.8589 - val_loss: 0.5679 - val_acc: 0.8053 Epoch 35/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4467 - acc: 0.8606 - val_loss: 0.5680 - val_acc: 0.8087 Epoch 36/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4383 - acc: 0.8644 - val_loss: 0.5634 - val_acc: 0.8037 Epoch 37/100 23/23 [==============================] - 1s 20ms/step - loss: 0.4451 - acc: 0.8641 - val_loss: 0.5635 - val_acc: 0.8037 Epoch 38/100 23/23 [==============================] - 1s 22ms/step - loss: 0.4393 - acc: 0.8620 - val_loss: 0.5616 - val_acc: 0.8037 Epoch 39/100 23/23 [==============================] - 0s 21ms/step - loss: 0.4256 - acc: 0.8710 - val_loss: 0.5607 - val_acc: 0.8020 Epoch 40/100 23/23 [==============================] - 0s 21ms/step - loss: 0.4296 - acc: 0.8669 - val_loss: 0.5612 - val_acc: 0.8037 Epoch 41/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4196 - acc: 0.8742 - val_loss: 0.5590 - val_acc: 0.8020 Epoch 42/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4203 - acc: 0.8658 - val_loss: 0.5556 - val_acc: 0.8053 Epoch 43/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4124 - acc: 0.8697 - val_loss: 0.5585 - val_acc: 0.8053 Epoch 44/100 23/23 [==============================] - 0s 20ms/step - loss: 0.4110 - acc: 0.8735 - val_loss: 0.5552 - val_acc: 0.8020 Epoch 45/100 23/23 [==============================] - 0s 20ms/step - loss: 0.4065 - acc: 0.8683 - val_loss: 0.5535 - val_acc: 0.8020 Epoch 46/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3998 - acc: 0.8787 - val_loss: 0.5526 - val_acc: 0.8003 Epoch 47/100 23/23 [==============================] - 1s 21ms/step - loss: 0.4038 - acc: 0.8700 - val_loss: 0.5546 - val_acc: 0.7970 Epoch 48/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3977 - acc: 0.8804 - val_loss: 0.5536 - val_acc: 0.7987 Epoch 49/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3948 - acc: 0.8797 - val_loss: 0.5490 - val_acc: 0.7970 Epoch 50/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3940 - acc: 0.8763 - val_loss: 0.5458 - val_acc: 0.7987 Epoch 51/100 23/23 
[==============================] - 0s 21ms/step - loss: 0.3905 - acc: 0.8763 - val_loss: 0.5507 - val_acc: 0.7987 Epoch 52/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3802 - acc: 0.8808 - val_loss: 0.5480 - val_acc: 0.7920 Epoch 53/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3822 - acc: 0.8797 - val_loss: 0.5467 - val_acc: 0.8003 Epoch 54/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3822 - acc: 0.8825 - val_loss: 0.5473 - val_acc: 0.7937 Epoch 55/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3826 - acc: 0.8783 - val_loss: 0.5440 - val_acc: 0.7953 Epoch 56/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3765 - acc: 0.8808 - val_loss: 0.5435 - val_acc: 0.7937 Epoch 57/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3805 - acc: 0.8839 - val_loss: 0.5466 - val_acc: 0.7953 Epoch 58/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3737 - acc: 0.8856 - val_loss: 0.5429 - val_acc: 0.7953 Epoch 59/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3716 - acc: 0.8902 - val_loss: 0.5454 - val_acc: 0.7937 Epoch 60/100 23/23 [==============================] - 1s 20ms/step - loss: 0.3771 - acc: 0.8797 - val_loss: 0.5477 - val_acc: 0.7953 Epoch 61/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3555 - acc: 0.8926 - val_loss: 0.5444 - val_acc: 0.7953 Epoch 62/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3645 - acc: 0.8832 - val_loss: 0.5461 - val_acc: 0.7953 Epoch 63/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3595 - acc: 0.8902 - val_loss: 0.5407 - val_acc: 0.7937 Epoch 64/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3666 - acc: 0.8839 - val_loss: 0.5412 - val_acc: 0.7987 Epoch 65/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3548 - acc: 0.8905 - val_loss: 0.5450 - val_acc: 0.7970 Epoch 66/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3550 - acc: 0.8902 - val_loss: 0.5410 - val_acc: 0.7970 Epoch 67/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3438 - acc: 0.8919 - val_loss: 0.5416 - val_acc: 0.7987 Epoch 68/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3502 - acc: 0.8950 - val_loss: 0.5441 - val_acc: 0.7987 Epoch 69/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3484 - acc: 0.8895 - val_loss: 0.5423 - val_acc: 0.7970 Epoch 70/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3486 - acc: 0.8891 - val_loss: 0.5391 - val_acc: 0.7953 Epoch 71/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3412 - acc: 0.8957 - val_loss: 0.5396 - val_acc: 0.7937 Epoch 72/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3377 - acc: 0.8992 - val_loss: 0.5394 - val_acc: 0.7937 Epoch 73/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3434 - acc: 0.8933 - val_loss: 0.5454 - val_acc: 0.7953 Epoch 74/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3430 - acc: 0.8933 - val_loss: 0.5420 - val_acc: 0.7953 Epoch 75/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3439 - acc: 0.8881 - val_loss: 0.5402 - val_acc: 0.7937 Epoch 76/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3357 - acc: 0.8964 - val_loss: 0.5400 - val_acc: 0.7920 Epoch 77/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3382 - acc: 0.8940 - val_loss: 0.5432 
- val_acc: 0.7903 Epoch 78/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3355 - acc: 0.8950 - val_loss: 0.5440 - val_acc: 0.7920 Epoch 79/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3348 - acc: 0.8950 - val_loss: 0.5394 - val_acc: 0.7920 Epoch 80/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3308 - acc: 0.8964 - val_loss: 0.5406 - val_acc: 0.7903 Epoch 81/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3288 - acc: 0.8943 - val_loss: 0.5400 - val_acc: 0.7953 Epoch 82/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3290 - acc: 0.8999 - val_loss: 0.5392 - val_acc: 0.7953 Epoch 83/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3308 - acc: 0.8936 - val_loss: 0.5409 - val_acc: 0.7903 Epoch 84/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3316 - acc: 0.8947 - val_loss: 0.5359 - val_acc: 0.7920 Epoch 85/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3264 - acc: 0.8936 - val_loss: 0.5360 - val_acc: 0.7937 Epoch 86/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3202 - acc: 0.8950 - val_loss: 0.5399 - val_acc: 0.7903 Epoch 87/100 23/23 [==============================] - 1s 22ms/step - loss: 0.3272 - acc: 0.8982 - val_loss: 0.5382 - val_acc: 0.7920 Epoch 88/100 23/23 [==============================] - 1s 23ms/step - loss: 0.3207 - acc: 0.8985 - val_loss: 0.5405 - val_acc: 0.7920 Epoch 89/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3206 - acc: 0.8971 - val_loss: 0.5405 - val_acc: 0.7937 Epoch 90/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3253 - acc: 0.8975 - val_loss: 0.5347 - val_acc: 0.7937 Epoch 91/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3175 - acc: 0.8992 - val_loss: 0.5310 - val_acc: 0.7937 Epoch 92/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3288 - acc: 0.8929 - val_loss: 0.5338 - val_acc: 0.7937 Epoch 93/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3169 - acc: 0.9006 - val_loss: 0.5399 - val_acc: 0.7887 Epoch 94/100 23/23 [==============================] - 1s 22ms/step - loss: 0.3133 - acc: 0.8975 - val_loss: 0.5399 - val_acc: 0.7903 Epoch 95/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3089 - acc: 0.9075 - val_loss: 0.5369 - val_acc: 0.7903 Epoch 96/100 23/23 [==============================] - 0s 20ms/step - loss: 0.3056 - acc: 0.9002 - val_loss: 0.5347 - val_acc: 0.7937 Epoch 97/100 23/23 [==============================] - 0s 21ms/step - loss: 0.3130 - acc: 0.9034 - val_loss: 0.5382 - val_acc: 0.7920 Epoch 98/100 23/23 [==============================] - 1s 22ms/step - loss: 0.3098 - acc: 0.8964 - val_loss: 0.5374 - val_acc: 0.7920 Epoch 99/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3095 - acc: 0.9009 - val_loss: 0.5368 - val_acc: 0.7937 Epoch 100/100 23/23 [==============================] - 1s 21ms/step - loss: 0.3091 - acc: 0.9037 - val_loss: 0.5400 - val_acc: 0.7887
The accuracy looks good, but it's important to run the evaluation step on the test data and verify your model achieved good results on unseen data.
print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model
28/28 [==============================] - 12s 404ms/step - loss: 0.6626 - acc: 0.7761
[0.6626318693161011, 0.7761194109916687]
Understanding your model

When training a classifier, it's useful to see the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier is performing on the test data.

Model Maker already creates the confusion matrix for you.
def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")
confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
Testing the model [Optional]

You can test the model on a sample audio from the test dataset just to see the results.

First you get the serving model.
serving_model = model.create_serving_model()
print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>]
Model's output shape and type: [<KerasTensor: shape=(1, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(1, 5) dtype=float32 (created by layer 'sequential')>]
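If you want to confirm those shapes before using real audio, a minimal sketch like the one below (not part of the original notebook) feeds one window of silence through the serving model and checks both heads:

dummy_window = tf.zeros([1, 15600], dtype=tf.float32)  # window size from the input shape above
yamnet_scores, bird_scores = serving_model(dummy_window)
print(yamnet_scores.shape)  # (1, 521): YAMNet's AudioSet classes
print(bird_scores.shape)    # (1, 5): your custom bird classes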
Coming back to the random audio you loaded earlier

# if you want to try another file just uncomment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: Red Crossbill
Bird code: redcro
The model you created has a fixed input window.

For a given audio file, you'll have to split it into windows of data of the expected size. The last window might need to be filled with zeros.
sample_rate, audio_data = wavfile.read(random_audio, 'rb')
audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]
splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)
print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/redcro/XC64752.wav
Original size of the audio data: 1210848
Number of windows for inference: 78
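The window count checks out: since pad_end=True makes tf.signal.frame zero-pad the last window, you get ceil(1210848 / 15600) = 78 windows. A one-line check (not in the original notebook):

import math
print(math.ceil(len(audio_data) / input_size))  # 78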
You'll loop over all the split audio and apply the model for each one of them.

The model you've just trained has 2 outputs: the original YAMNet output and the one you've just trained. This is important because a real-world environment is more complicated than just bird sounds. You can use YAMNet's output to filter out non-relevant audio. For example, on the birds use case, if YAMNet is not classifying Birds or Animals, this might show that the output from your model might have an irrelevant classification.

Below, both outputs are printed to make it easier to understand their relation. Most of the mistakes that your model makes are when YAMNet's prediction is not related to your domain (e.g. birds).
print(random_audio)
results = []
print('Result of the window ith: your model class -> score, (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  t = spec._yamnet_labels()[spec_result_index]
  result_str = f'Result of the window {i}: ' \
               f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
               f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)
results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/redcro/XC64752.wav Result of the window ith: your model class -> score, (spec class -> score) Result of the window 0: redcro -> 0.772, (Wild animals -> 0.997) Result of the window 1: redcro -> 0.912, (Wild animals -> 0.794) Result of the window 2: redcro -> 0.679, (Environmental noise -> 0.545) Result of the window 3: redcro -> 0.910, (Wild animals -> 0.975) Result of the window 4: redcro -> 0.863, (Animal -> 0.911) Result of the window 5: redcro -> 0.794, (Animal -> 0.757) Result of the window 6: redcro -> 0.953, (Animal -> 0.929) Result of the window 7: redcro -> 0.887, (Wild animals -> 0.837) Result of the window 8: redcro -> 0.905, (Wild animals -> 0.925) Result of the window 9: houspa -> 0.568, (Animal -> 0.777) Result of the window 10: redcro -> 0.724, (Bird -> 0.997) Result of the window 11: houspa -> 0.585, (Animal -> 0.954) Result of the window 12: azaspi1 -> 0.621, (Animal -> 0.849) Result of the window 13: redcro -> 0.873, (Wild animals -> 0.888) Result of the window 14: redcro -> 0.940, (Bird -> 0.869) Result of the window 15: redcro -> 0.827, (Animal -> 0.773) Result of the window 16: redcro -> 0.596, (Animal -> 0.732) Result of the window 17: redcro -> 0.928, (Animal -> 0.909) Result of the window 18: redcro -> 0.791, (Animal -> 0.742) Result of the window 19: redcro -> 0.874, (Animal -> 0.906) Result of the window 20: houspa -> 0.487, (Animal -> 0.490) Result of the window 21: redcro -> 0.991, (Animal -> 0.959) Result of the window 22: redcro -> 0.691, (Animal -> 0.710) Result of the window 23: chcant2 -> 0.996, (Water -> 0.601) Result of the window 24: chcant2 -> 0.516, (Outside, rural or natural -> 0.209) Result of the window 25: chcant2 -> 0.888, (Stream -> 0.690) Result of the window 26: azaspi1 -> 0.691, (Animal -> 0.677) Result of the window 27: redcro -> 0.996, (Animal -> 0.933) Result of the window 28: redcro -> 0.921, (Bird vocalization, bird call, bird song -> 0.784) Result of the window 29: redcro -> 0.775, (Animal -> 0.857) Result of the window 30: redcro -> 0.987, (Animal -> 0.977) Result of the window 31: chcant2 -> 0.744, (Insect -> 0.543) Result of the window 32: chcant2 -> 0.586, (Environmental noise -> 0.429) Result of the window 33: chcant2 -> 0.704, (Outside, rural or natural -> 0.406) Result of the window 34: chcant2 -> 0.688, (Environmental noise -> 0.780) Result of the window 35: redcro -> 0.505, (Environmental noise -> 0.574) Result of the window 36: chcant2 -> 0.908, (Animal -> 0.375) Result of the window 37: chcant2 -> 0.812, (Outside, rural or natural -> 0.392) Result of the window 38: redcro -> 0.933, (Animal -> 0.938) Result of the window 39: redcro -> 0.744, (Wild animals -> 0.868) Result of the window 40: redcro -> 0.664, (Wild animals -> 0.954) Result of the window 41: redcro -> 0.548, (Animal -> 0.905) Result of the window 42: redcro -> 0.746, (Animal -> 0.948) Result of the window 43: redcro -> 0.970, (Animal -> 0.989) Result of the window 44: redcro -> 0.827, (Animal -> 0.857) Result of the window 45: redcro -> 0.911, (Animal -> 0.978) Result of the window 46: redcro -> 0.983, (Animal -> 0.982) Result of the window 47: chcant2 -> 0.701, (Outside, rural or natural -> 0.357) Result of the window 48: redcro -> 0.879, (Animal -> 0.948) Result of the window 49: redcro -> 0.968, (Animal -> 0.983) Result of the window 50: redcro -> 0.975, (Bird vocalization, bird call, bird song -> 0.752) Result of the window 51: redcro -> 0.814, (Animal -> 0.818) Result of the window 
52: chcant2 -> 0.398, (Environmental noise -> 0.657) Result of the window 53: chcant2 -> 0.676, (Outside, rural or natural -> 0.335) Result of the window 54: chcant2 -> 0.716, (White noise -> 0.358) Result of the window 55: chcant2 -> 0.565, (Outside, rural or natural -> 0.380) Result of the window 56: wbwwre1 -> 0.795, (Animal -> 0.922) Result of the window 57: chcant2 -> 0.857, (Environmental noise -> 0.328) Result of the window 58: chcant2 -> 0.955, (Outside, rural or natural -> 0.299) Result of the window 59: chcant2 -> 0.968, (Rustle -> 0.258) Result of the window 60: chcant2 -> 0.948, (Outside, rural or natural -> 0.192) Result of the window 61: chcant2 -> 0.563, (Animal -> 0.357) Result of the window 62: houspa -> 0.603, (Wild animals -> 0.802) Result of the window 63: chcant2 -> 0.797, (Insect -> 0.575) Result of the window 64: redcro -> 0.811, (Wild animals -> 0.978) Result of the window 65: chcant2 -> 0.750, (Environmental noise -> 0.507) Result of the window 66: houspa -> 0.519, (Animal -> 0.902) Result of the window 67: redcro -> 0.998, (Animal -> 0.988) Result of the window 68: houspa -> 0.841, (Animal -> 0.997) Result of the window 69: redcro -> 0.901, (Animal -> 0.997) Result of the window 70: houspa -> 0.942, (Animal -> 0.964) Result of the window 71: redcro -> 0.912, (Animal -> 0.983) Result of the window 72: redcro -> 0.912, (Animal -> 0.762) Result of the window 73: houspa -> 0.638, (Animal -> 0.916) Result of the window 74: redcro -> 0.730, (Wild animals -> 0.762) Result of the window 75: redcro -> 0.969, (Wild animals -> 0.880) Result of the window 76: chcant2 -> 0.471, (Wild animals -> 0.555) Result of the window 77: chcant2 -> 0.793, (Outside, rural or natural -> 0.366) Mean result: redcro -> 0.5561891794204712
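Building on the filtering idea above, one possible refinement is to drop windows whose top YAMNet class doesn't look animal-related before averaging. This is just a sketch, not part of the original notebook, and the set of "relevant" labels below is an assumption you'd tune for your own use case:

# Hypothetical filtering sketch: keep only windows whose top YAMNet class
# is in a hand-picked set of animal/bird-related labels, then average.
relevant_labels = {'Animal', 'Bird', 'Wild animals',
                   'Bird vocalization, bird call, bird song'}  # assumed set
filtered_results = []
for data in splitted_audio_data:
  yamnet_output, inference = serving_model(data)
  top_yamnet_label = spec._yamnet_labels()[tf.argmax(yamnet_output[0])]
  if top_yamnet_label in relevant_labels:
    filtered_results.append(inference[0].numpy())

if filtered_results:
  mean_filtered = np.array(filtered_results).mean(axis=0)
  print(f'Filtered mean result: {test_data.index_to_label[mean_filtered.argmax()]}'
        f' -> {mean_filtered.max():.3f}')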
Exporting the model

The last step is exporting your model to be used on embedded devices or in the browser.

The export method exports both formats for you.
models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2021-11-02 12:50:55.630878: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/tmpy4w2awkd/assets
INFO:tensorflow:Assets written to: /tmp/tmpy4w2awkd/assets
2021-11-02 12:51:00.841619: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-11-02 12:51:00.841671: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
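Before shipping the .tflite file, you can sanity-check it with the TensorFlow Lite interpreter. This is a minimal sketch, not part of the original notebook; since the model was exported with keep_yamnet_and_custom_heads=True it has two outputs, so inspect output_details to pick the head you need:

interpreter = tf.lite.Interpreter(model_path='./birds_models/my_birds_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One window of silence with the expected shape and dtype.
silence = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], silence)
interpreter.invoke()
for detail in output_details:
  print(detail['name'], interpreter.get_tensor(detail['index']).shape)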
You can also export the SavedModel version for serving or to use in a Python environment.
model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
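To double-check the SavedModel export, you can load it back and run one window through it. A minimal sketch, assuming the paths from the cell above (tf.keras.models.load_model can read a SavedModel directory):

reloaded_model = tf.keras.models.load_model(os.path.join(models_path, 'saved_model'))
window = tf.zeros([1, 15600], dtype=tf.float32)
yamnet_scores, bird_scores = reloaded_model(window)
print(yamnet_scores.shape, bird_scores.shape)  # expect (1, 521) (1, 5)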
Next steps

You did it.

Now your new model can be deployed on mobile devices using the TFLite Task Library AudioClassifier API.

You can also try the same process with your own data with different classes; here is the documentation for Model Maker for Audio Classification.

Also learn from the end-to-end reference apps: Android, iOS.