
Transfer learning for the audio domain with TensorFlow Lite Model Maker


In this colab notebook, you'll learn how to use TensorFlow Lite Model Maker to train a custom audio classification model.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model with a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.

It is part of the Codelab to Customize an Audio model and deploy on Android.

You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.JS model that can be used for inference in the browser, and also a SavedModel version that you can use for serving.

Installing dependencies

 pip install tflite-model-maker

Import TensorFlow, Model Maker and other libraries

Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside from those, the others are for audio manipulation, playing and visualizations.

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)
TensorFlow Version: 2.6.0
Model Maker Version: 0.3.2

The Birds dataset

The Birds dataset is an educational collection of 5 types of bird songs:

  • White-breasted Wood-Wren
  • House Sparrow
  • Red Crossbill
  • Chestnut-crowned Antpitta
  • Azara's Spinetail

The original audio came from Xeno-canto, a website dedicated to sharing bird sounds from all over the world.

Let's start by downloading the data.

birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip
343687168/343680986 [==============================] - 6s 0us/step
343695360/343680986 [==============================] - 6s 0us/step

Explore the data

The audio files are already split into train and test folders. Inside each split folder there is one folder per bird, named with its bird_code.

The audio files are all mono with a 16 kHz sample rate.

For more information about each file, you can read the metadata.csv file. It contains all the files' authors, licenses and some more information. You won't need to read it yourself in this tutorial.
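
If you do want to peek at it, here is a minimal sketch with pandas (it assumes metadata.csv sits at the root of the extracted small_birds_dataset folder; adjust the path if it lives elsewhere):

import pandas as pd

# Assumed location of the metadata file inside the extracted dataset.
metadata = pd.read_csv('./dataset/small_birds_dataset/metadata.csv')
print(metadata.columns.tolist())  # see what information is available
print(metadata.head())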

# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',  
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",   
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', #   Alejandro Bayer Tamayo from Armenia, Colombia 
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', #    Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', #   Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path


def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path, 'rb')

  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created

Playing some audio

To get a better understanding of the data, let's listen to a random audio file from the test split.

random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: White-breasted Wood-Wren
Bird code: wbwwre1

[Image output: photo of the bird]

[Plot output: waveform of the audio clip]

Training the model

When using Model Maker for audio, you have to start with a model spec. This is the base model that your new model will extract information from to learn about the new classes. It also affects how the dataset will be transformed to respect the model spec's parameters, such as sample rate and number of channels.

YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be 16 kHz audio with 1 channel.

You don't need to do any resampling yourself. Model Maker takes care of that for you.
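
If you ever need to prepare audio outside of Model Maker, here is a minimal resampling sketch with scipy (the 16 kHz target and the mixdown-to-mono step mirror YAMNet's expectations; the helper name is just illustrative):

import numpy as np
from scipy.io import wavfile
from scipy.signal import resample_poly

def load_as_16k_mono(path, target_sr=16000):
  sr, data = wavfile.read(path)
  data = data.astype(np.float32)
  if data.ndim > 1:
    # mix multi-channel audio down to mono
    data = data.mean(axis=1)
  if sr != target_sr:
    # polyphase resampling to the target sample rate
    data = resample_poly(data, up=target_sr, down=sr)
  return target_sr, data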

  • frame_length decides how long each training sample is. In this case, 6 * EXPECTED_WAVEFORM_LENGTH.

  • frame_step decides how far apart consecutive training samples start. In this case, the ith sample starts 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)th sample.

The reason to set these values is to work around some limitations of real-world datasets.

For example, in the birds dataset, birds don't sing all the time. They sing, rest, and sing again, with noises in between. Having a long frame helps capture the singing, but setting it too long reduces the number of samples available for training.
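
As a rough worked example (assuming EXPECTED_WAVEFORM_LENGTH is YAMNet's 15600-sample input at 16 kHz, which matches the serving model's input shape shown later):

sample_rate = 16000  # YAMNet's expected sample rate
window = audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH  # 15600 samples
print(f'One YAMNet window: {window / sample_rate:.3f} s')             # ~0.975 s
print(f'frame_length = 6 windows: {6 * window / sample_rate:.2f} s')  # ~5.85 s
print(f'frame_step = 3 windows: {3 * window / sample_rate:.2f} s')    # ~2.92 s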

spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmp/tmp3s72bspo

Loading the data

Model Maker has an API to load the data from a folder and convert it into the format expected by the model spec.

The train and test splits are based on the folders. The validation dataset will be created as 20% of the train split.

train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

Training the model

The audio_classifier has the create method that creates a model and already starts training it.

You can customize many parameters; for more information you can read more details in the documentation.

On this first try you'll use all the default configurations and train for 100 epochs.

batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
classification_head (Dense)  (None, 5)                 5125      
=================================================================
Total params: 5,125
Trainable params: 5,125
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
22/22 [==============================] - 54s 2s/step - loss: 1.5431 - acc: 0.3074 - val_loss: 1.3118 - val_acc: 0.4735
Epoch 2/100
22/22 [==============================] - 1s 25ms/step - loss: 1.3122 - acc: 0.4780 - val_loss: 1.1328 - val_acc: 0.5794
Epoch 3/100
22/22 [==============================] - 1s 25ms/step - loss: 1.1547 - acc: 0.5975 - val_loss: 1.0275 - val_acc: 0.6374
Epoch 4/100
22/22 [==============================] - 1s 24ms/step - loss: 1.0504 - acc: 0.6419 - val_loss: 0.9516 - val_acc: 0.6890
Epoch 5/100
22/22 [==============================] - 1s 24ms/step - loss: 0.9666 - acc: 0.6944 - val_loss: 0.8896 - val_acc: 0.7303
Epoch 6/100
22/22 [==============================] - 1s 24ms/step - loss: 0.9079 - acc: 0.7159 - val_loss: 0.8458 - val_acc: 0.7432
Epoch 7/100
22/22 [==============================] - 1s 25ms/step - loss: 0.8554 - acc: 0.7395 - val_loss: 0.8056 - val_acc: 0.7639
Epoch 8/100
22/22 [==============================] - 1s 25ms/step - loss: 0.8174 - acc: 0.7525 - val_loss: 0.7809 - val_acc: 0.7665
Epoch 9/100
22/22 [==============================] - 1s 24ms/step - loss: 0.7840 - acc: 0.7555 - val_loss: 0.7554 - val_acc: 0.7781
Epoch 10/100
22/22 [==============================] - 1s 24ms/step - loss: 0.7447 - acc: 0.7721 - val_loss: 0.7342 - val_acc: 0.7794
Epoch 11/100
22/22 [==============================] - 1s 25ms/step - loss: 0.7215 - acc: 0.7773 - val_loss: 0.7145 - val_acc: 0.7871
Epoch 12/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6953 - acc: 0.7932 - val_loss: 0.7034 - val_acc: 0.7897
Epoch 13/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6811 - acc: 0.7888 - val_loss: 0.6884 - val_acc: 0.7961
Epoch 14/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6605 - acc: 0.7987 - val_loss: 0.6744 - val_acc: 0.7974
Epoch 15/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6435 - acc: 0.7999 - val_loss: 0.6638 - val_acc: 0.7974
Epoch 16/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6244 - acc: 0.8106 - val_loss: 0.6524 - val_acc: 0.7987
Epoch 17/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6080 - acc: 0.8232 - val_loss: 0.6443 - val_acc: 0.7987
Epoch 18/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5917 - acc: 0.8202 - val_loss: 0.6336 - val_acc: 0.8013
Epoch 19/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5825 - acc: 0.8187 - val_loss: 0.6272 - val_acc: 0.8000
Epoch 20/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5776 - acc: 0.8287 - val_loss: 0.6182 - val_acc: 0.8052
Epoch 21/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5671 - acc: 0.8328 - val_loss: 0.6142 - val_acc: 0.8052
Epoch 22/100
22/22 [==============================] - 1s 23ms/step - loss: 0.5542 - acc: 0.8346 - val_loss: 0.6131 - val_acc: 0.8026
Epoch 23/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5474 - acc: 0.8372 - val_loss: 0.6074 - val_acc: 0.8039
Epoch 24/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5343 - acc: 0.8435 - val_loss: 0.6005 - val_acc: 0.8090
Epoch 25/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5324 - acc: 0.8376 - val_loss: 0.5926 - val_acc: 0.8103
Epoch 26/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5225 - acc: 0.8428 - val_loss: 0.5878 - val_acc: 0.8103
Epoch 27/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5215 - acc: 0.8431 - val_loss: 0.5848 - val_acc: 0.8116
Epoch 28/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5120 - acc: 0.8420 - val_loss: 0.5819 - val_acc: 0.8116
Epoch 29/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5023 - acc: 0.8465 - val_loss: 0.5796 - val_acc: 0.8155
Epoch 30/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4998 - acc: 0.8550 - val_loss: 0.5741 - val_acc: 0.8155
Epoch 31/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4994 - acc: 0.8509 - val_loss: 0.5722 - val_acc: 0.8142
Epoch 32/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4880 - acc: 0.8505 - val_loss: 0.5688 - val_acc: 0.8142
Epoch 33/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4851 - acc: 0.8513 - val_loss: 0.5615 - val_acc: 0.8181
Epoch 34/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4847 - acc: 0.8472 - val_loss: 0.5584 - val_acc: 0.8181
Epoch 35/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4721 - acc: 0.8609 - val_loss: 0.5584 - val_acc: 0.8194
Epoch 36/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4668 - acc: 0.8524 - val_loss: 0.5575 - val_acc: 0.8219
Epoch 37/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4682 - acc: 0.8531 - val_loss: 0.5530 - val_acc: 0.8232
Epoch 38/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4615 - acc: 0.8576 - val_loss: 0.5504 - val_acc: 0.8232
Epoch 39/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4636 - acc: 0.8535 - val_loss: 0.5526 - val_acc: 0.8219
Epoch 40/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4585 - acc: 0.8561 - val_loss: 0.5500 - val_acc: 0.8206
Epoch 41/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4480 - acc: 0.8613 - val_loss: 0.5412 - val_acc: 0.8206
Epoch 42/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4573 - acc: 0.8587 - val_loss: 0.5411 - val_acc: 0.8194
Epoch 43/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4523 - acc: 0.8565 - val_loss: 0.5362 - val_acc: 0.8194
Epoch 44/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4367 - acc: 0.8635 - val_loss: 0.5403 - val_acc: 0.8206
Epoch 45/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4318 - acc: 0.8631 - val_loss: 0.5366 - val_acc: 0.8194
Epoch 46/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4263 - acc: 0.8679 - val_loss: 0.5349 - val_acc: 0.8194
Epoch 47/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4430 - acc: 0.8572 - val_loss: 0.5362 - val_acc: 0.8194
Epoch 48/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4332 - acc: 0.8646 - val_loss: 0.5281 - val_acc: 0.8219
Epoch 49/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4229 - acc: 0.8687 - val_loss: 0.5304 - val_acc: 0.8206
Epoch 50/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4287 - acc: 0.8627 - val_loss: 0.5257 - val_acc: 0.8206
Epoch 51/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4119 - acc: 0.8824 - val_loss: 0.5283 - val_acc: 0.8181
Epoch 52/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4271 - acc: 0.8653 - val_loss: 0.5233 - val_acc: 0.8206
Epoch 53/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4050 - acc: 0.8676 - val_loss: 0.5226 - val_acc: 0.8206
Epoch 54/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4003 - acc: 0.8768 - val_loss: 0.5142 - val_acc: 0.8245
Epoch 55/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4087 - acc: 0.8775 - val_loss: 0.5170 - val_acc: 0.8206
Epoch 56/100
22/22 [==============================] - 1s 26ms/step - loss: 0.4029 - acc: 0.8731 - val_loss: 0.5187 - val_acc: 0.8194
Epoch 57/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4063 - acc: 0.8672 - val_loss: 0.5136 - val_acc: 0.8219
Epoch 58/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3998 - acc: 0.8787 - val_loss: 0.5132 - val_acc: 0.8245
Epoch 59/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3911 - acc: 0.8798 - val_loss: 0.5113 - val_acc: 0.8245
Epoch 60/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4022 - acc: 0.8742 - val_loss: 0.5044 - val_acc: 0.8284
Epoch 61/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3910 - acc: 0.8757 - val_loss: 0.5090 - val_acc: 0.8245
Epoch 62/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3869 - acc: 0.8798 - val_loss: 0.5129 - val_acc: 0.8245
Epoch 63/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3880 - acc: 0.8809 - val_loss: 0.5070 - val_acc: 0.8271
Epoch 64/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3905 - acc: 0.8757 - val_loss: 0.5047 - val_acc: 0.8284
Epoch 65/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3886 - acc: 0.8801 - val_loss: 0.5094 - val_acc: 0.8271
Epoch 66/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3877 - acc: 0.8746 - val_loss: 0.5037 - val_acc: 0.8284
Epoch 67/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3846 - acc: 0.8775 - val_loss: 0.5022 - val_acc: 0.8258
Epoch 68/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3890 - acc: 0.8831 - val_loss: 0.5015 - val_acc: 0.8284
Epoch 69/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3861 - acc: 0.8731 - val_loss: 0.5001 - val_acc: 0.8284
Epoch 70/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3802 - acc: 0.8812 - val_loss: 0.5001 - val_acc: 0.8284
Epoch 71/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3760 - acc: 0.8783 - val_loss: 0.4961 - val_acc: 0.8297
Epoch 72/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3755 - acc: 0.8742 - val_loss: 0.4937 - val_acc: 0.8310
Epoch 73/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3774 - acc: 0.8798 - val_loss: 0.5006 - val_acc: 0.8310
Epoch 74/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3756 - acc: 0.8787 - val_loss: 0.4955 - val_acc: 0.8297
Epoch 75/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3780 - acc: 0.8724 - val_loss: 0.4950 - val_acc: 0.8297
Epoch 76/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3794 - acc: 0.8746 - val_loss: 0.4930 - val_acc: 0.8297
Epoch 77/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3614 - acc: 0.8809 - val_loss: 0.4915 - val_acc: 0.8323
Epoch 78/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3644 - acc: 0.8783 - val_loss: 0.4911 - val_acc: 0.8297
Epoch 79/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3680 - acc: 0.8787 - val_loss: 0.4871 - val_acc: 0.8335
Epoch 80/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3637 - acc: 0.8805 - val_loss: 0.4851 - val_acc: 0.8310
Epoch 81/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3647 - acc: 0.8875 - val_loss: 0.4866 - val_acc: 0.8335
Epoch 82/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3602 - acc: 0.8827 - val_loss: 0.4855 - val_acc: 0.8335
Epoch 83/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3544 - acc: 0.8868 - val_loss: 0.4849 - val_acc: 0.8335
Epoch 84/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3616 - acc: 0.8861 - val_loss: 0.4843 - val_acc: 0.8297
Epoch 85/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3642 - acc: 0.8820 - val_loss: 0.4801 - val_acc: 0.8310
Epoch 86/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3660 - acc: 0.8787 - val_loss: 0.4783 - val_acc: 0.8310
Epoch 87/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3470 - acc: 0.8868 - val_loss: 0.4863 - val_acc: 0.8348
Epoch 88/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3525 - acc: 0.8872 - val_loss: 0.4812 - val_acc: 0.8297
Epoch 89/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3622 - acc: 0.8768 - val_loss: 0.4858 - val_acc: 0.8297
Epoch 90/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3623 - acc: 0.8764 - val_loss: 0.4811 - val_acc: 0.8323
Epoch 91/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3501 - acc: 0.8857 - val_loss: 0.4820 - val_acc: 0.8310
Epoch 92/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3457 - acc: 0.8920 - val_loss: 0.4803 - val_acc: 0.8335
Epoch 93/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3566 - acc: 0.8831 - val_loss: 0.4810 - val_acc: 0.8310
Epoch 94/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3542 - acc: 0.8846 - val_loss: 0.4808 - val_acc: 0.8323
Epoch 95/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3480 - acc: 0.8883 - val_loss: 0.4757 - val_acc: 0.8335
Epoch 96/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3441 - acc: 0.8883 - val_loss: 0.4810 - val_acc: 0.8348
Epoch 97/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3539 - acc: 0.8846 - val_loss: 0.4742 - val_acc: 0.8361
Epoch 98/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3461 - acc: 0.8864 - val_loss: 0.4741 - val_acc: 0.8310
Epoch 99/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3323 - acc: 0.8964 - val_loss: 0.4802 - val_acc: 0.8348
Epoch 100/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3497 - acc: 0.8812 - val_loss: 0.4761 - val_acc: 0.8323

The accuracy looks good, but it's important to run the evaluation step on the test data and verify that your model achieved good results on unseen data.

print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model
28/28 [==============================] - 14s 469ms/step - loss: 0.7833 - acc: 0.7750
[0.7832977175712585, 0.7749713063240051]

Understanding your model

When training a classifier, it's useful to look at the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier is performing on the test data.

Model Maker already creates the confusion matrix for you.

def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)[:, np.newaxis]  # normalize each row to sum to 1
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)

[Plot output: normalized confusion matrix heatmap]

Testing the model [Optional]

You can try the model on a sample audio clip from the test dataset just to see the results.

First you get the serving model.

serving_model = model.create_serving_model()

print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>]
Model's output shape and type: [<KerasTensor: shape=(1, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(1, 5) dtype=float32 (created by layer 'sequential')>]

Back to the random audio you loaded earlier

# if you want to try another file, just run this cell again to pick a new random one
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: House Sparrow
Bird code: houspa

[Image output: photo of the bird]

[Plot output: waveform of the audio clip]

The created model has a fixed input window.

For a given audio file, you'll have to split it into windows of data of the expected size. The last window might need to be padded with zeros.

sample_rate, audio_data = wavfile.read(random_audio, 'rb')

audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]

splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)

print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/houspa/XC565683.wav
Original size of the audio data: 520992
Number of windows for inference: 34

You'll loop over all of the split audio windows and apply the model to each one of them.

The model you've just trained has 2 outputs: the original YAMNet output and the one you've just trained. This is important because a real-world environment is more complicated than just bird sounds. You can use the YAMNet output to filter out non-relevant audio; for example, on the birds use case, if YAMNet is not classifying Birds or Animals, this might indicate that the output from your model is an irrelevant classification.

Below, both outputs are printed to make it easier to understand their relationship. Most of the mistakes your model makes occur when YAMNet's prediction is not related to your domain (e.g.: birds). A sketch of that filtering idea is shown after the results below.

print(random_audio)

results = []
print('Result of the window ith:  your model class -> score,  (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  t = spec._yamnet_labels()[spec_result_index]
  result_str = f'Result of the window {i}: ' \
  f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
  f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)


results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/houspa/XC565683.wav
Result of the window ith:  your model class -> score,  (spec class -> score)
Result of the window 0:   redcro -> 0.941,     (Silence -> 1.000)
Result of the window 1:   redcro -> 0.508,     (Animal -> 0.940)
Result of the window 2:   redcro -> 0.994,     (Silence -> 1.000)
Result of the window 3:   redcro -> 0.876,     (Silence -> 1.000)
Result of the window 4:   houspa -> 0.958,     (Bird -> 0.974)
Result of the window 5:   redcro -> 0.991,     (Silence -> 1.000)
Result of the window 6:   redcro -> 0.990,     (Silence -> 1.000)
Result of the window 7:   redcro -> 0.424,     (Bird -> 0.944)
Result of the window 8:   houspa -> 0.744,     (Bird -> 0.974)
Result of the window 9:   chcant2 -> 0.998,    (Silence -> 1.000)
Result of the window 10:  chcant2 -> 0.821,    (Silence -> 1.000)
Result of the window 11:  redcro -> 0.860,     (Silence -> 1.000)
Result of the window 12:  redcro -> 0.569,     (Bird -> 0.970)
Result of the window 13:  redcro -> 0.973,     (Silence -> 1.000)
Result of the window 14:  redcro -> 0.973,     (Wild animals -> 0.930)
Result of the window 15:  chcant2 -> 0.532,    (Silence -> 1.000)
Result of the window 16:  redcro -> 0.960,     (Silence -> 1.000)
Result of the window 17:  redcro -> 0.662,     (Bird -> 0.911)
Result of the window 18:  redcro -> 0.899,     (Silence -> 1.000)
Result of the window 19:  redcro -> 0.959,     (Silence -> 1.000)
Result of the window 20:  redcro -> 0.968,     (Silence -> 1.000)
Result of the window 21:  houspa -> 0.994,     (Bird -> 0.990)
Result of the window 22:  redcro -> 0.967,     (Silence -> 1.000)
Result of the window 23:  redcro -> 0.991,     (Silence -> 1.000)
Result of the window 24:  redcro -> 0.981,     (Silence -> 1.000)
Result of the window 25:  houspa -> 0.604,     (Bird -> 0.987)
Result of the window 26:  redcro -> 0.944,     (Silence -> 1.000)
Result of the window 27:  redcro -> 0.938,     (Silence -> 1.000)
Result of the window 28:  redcro -> 0.907,     (Silence -> 1.000)
Result of the window 29:  redcro -> 0.937,     (Silence -> 1.000)
Result of the window 30:  redcro -> 0.940,     (Silence -> 1.000)
Result of the window 31:  redcro -> 0.566,     (Bird -> 0.990)
Result of the window 32:  redcro -> 0.672,     (Silence -> 1.000)
Result of the window 33:  redcro -> 0.776,     (Silence -> 1.000)
Mean result: redcro -> 0.7189410328865051
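
As a sketch of that filtering idea (the set of "relevant" YAMNet class names below is just an assumption chosen for illustration, not something defined by the library), you could average only the windows whose top YAMNet class looks bird-related:

# Hypothetical allowlist of YAMNet classes considered relevant for this use case.
RELEVANT_YAMNET_CLASSES = {'Bird', 'Animal', 'Wild animals'}

filtered_scores = []
for data in splitted_audio_data:
  yamnet_output, inference = serving_model(data)
  top_yamnet_class = spec._yamnet_labels()[tf.argmax(yamnet_output[0])]
  if top_yamnet_class in RELEVANT_YAMNET_CLASSES:
    filtered_scores.append(inference[0].numpy())

if filtered_scores:
  mean_scores = np.mean(filtered_scores, axis=0)
  top = mean_scores.argmax()
  print(f'Filtered result: {test_data.index_to_label[top]} -> {mean_scores[top]:.3f}')
else:
  print('YAMNet did not flag any window as relevant.')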

Exporting the model

The last step is exporting your model to be used on embedded devices or in the browser.

The export method exports both formats for you.

models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2021-10-07 11:56:43.268990: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/tmplmr5go_i/assets
INFO:tensorflow:Assets written to: /tmp/tmplmr5go_i/assets
2021-10-07 11:56:49.292362: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-10-07 11:56:49.292431: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
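
To sanity-check the exported file from Python, here is a minimal sketch using tf.lite.Interpreter (because the model keeps both the YAMNet head and the custom head, the custom output is located by its class count rather than by position; the input shape is read from the interpreter itself):

interpreter = tf.lite.Interpreter(model_path='./birds_models/my_birds_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print('Inputs:', [(d['shape'], d['dtype']) for d in input_details])
print('Outputs:', [(d['shape'], d['dtype']) for d in output_details])

# Run one window of audio through the model, reshaped to the expected input shape.
window = tf.reshape(tf.cast(splitted_audio_data[0], tf.float32),
                    input_details[0]['shape']).numpy()
interpreter.set_tensor(input_details[0]['index'], window)
interpreter.invoke()

# Pick the output whose last dimension matches the number of custom classes,
# assuming its label order matches test_data.index_to_label.
custom_output = next(d for d in output_details
                     if d['shape'][-1] == len(test_data.index_to_label))
scores = interpreter.get_tensor(custom_output['index'])[0]
print(f'Top class: {test_data.index_to_label[scores.argmax()]} -> {scores.max():.3f}')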

You can also export the SavedModel version to serve it or use it in a Python environment.

model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
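
A minimal sketch for reloading that SavedModel in Python (it only inspects the default serving signature, since the exact input and output names depend on how the model was traced):

reloaded = tf.saved_model.load('./birds_models/saved_model')
serving_fn = reloaded.signatures['serving_default']
# Inspect what the signature expects and returns before wiring it into a server.
print(serving_fn.structured_input_signature)
print(serving_fn.structured_outputs)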

Next steps

You did it!

Now your new model can be deployed on mobile devices using the TFLite Task Library AudioClassifier API.

You can also try the same process with your own data with different classes; here is the documentation for Model Maker for Audio Classification.

Also learn from the end-to-end reference apps: Android, iOS.