In this colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a custom audio classification model.
The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.
It is part of the Codelab to Customize an Audio model and deploy on Android.
You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.js model that can be used for inference in the browser, and also a SavedModel version that you can use for serving.
Installing dependencies
sudo apt -y install libportaudio2
pip install tflite-model-maker
Import TensorFlow, Model Maker and other libraries
Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside from those, the others are for audio manipulation, playback, and visualization.
import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import glob
import random
from IPython.display import Audio, Image
from scipy.io import wavfile
print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
TensorFlow Version: 2.8.3 Model Maker Version: 0.4.2
The Birds dataset
The Birds dataset is an educational collection of songs from 5 types of birds:
- White-breasted Wood-Wren
- House Sparrow
- Red Crossbill
- Chestnut-crowned Antpitta
- Azara's Spinetail
The original audio came from Xeno-canto which is a website dedicated to sharing bird sounds from all over the world.
Let's start by downloading the data.
birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip 343687168/343680986 [==============================] - 3s 0us/step 343695360/343680986 [==============================] - 3s 0us/step
Explore the data
The audio files are already split into train and test folders. Inside each split folder, there's one folder for each bird, named by its bird_code.
The audio files are all mono, with a 16kHz sample rate.
For more information about each file, you can read the metadata.csv file. It contains each file's author, license, and some more information. You won't need to read it yourself in this tutorial.
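If you want to sanity-check the format yourself, here is a minimal sketch (assuming the dataset was extracted to ./dataset/small_birds_dataset, the same path used below) that reads one training file and reports its sample rate and channel count:

# Minimal sanity check: confirm one file is mono and sampled at 16kHz.
import glob
from scipy.io import wavfile

sample_path = glob.glob('./dataset/small_birds_dataset/train/*/*.wav')[0]
sample_rate, audio_data = wavfile.read(sample_path)

print(f'File: {sample_path}')
print(f'Sample rate: {sample_rate} Hz')   # expected: 16000
print(f'Channels: {audio_data.shape[1] if audio_data.ndim > 1 else 1}')   # expected: 1 (mono)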
# @title [Run this] Util functions and data structures.
data_dir = './dataset/small_birds_dataset'
bird_code_to_name = {
'wbwwre1': 'White-breasted Wood-Wren',
'houspa': 'House Sparrow',
'redcro': 'Red Crossbill',
'chcant2': 'Chestnut-crowned Antpitta',
'azaspi1': "Azara's Spinetail",
}
birds_images = {
'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', # Alejandro Bayer Tamayo from Armenia, Colombia
'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', # Diliff
'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', # Elaine R. Wilson, www.naturespicsonline.com
'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', # Mike's Birds from Riverside, CA, US
'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}
test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))
def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path

def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path, 'rb')
  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created
Playing some audio
To get a better understanding of the data, let's listen to a random audio file from the test split.
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: House Sparrow Bird code: houspa
Training the Model
When using Model Maker for audio, you have to start with a model spec. This is the base model that your new model will extract information from to learn about the new classes. It also affects how the dataset will be transformed to respect the model spec's parameters, such as the sample rate and number of channels.
YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.
Its input is expected to be at 16kHz with 1 channel.
You don't need to do any resampling yourself; Model Maker takes care of that for you.
- frame_length decides how long each training sample is. In this case, each sample is 6 * EXPECTED_WAVEFORM_LENGTH long.
- frame_step decides how far apart consecutive training samples start. In this case, the ith sample starts 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)th sample.

The reason to set these values is to work around some limitations of real-world datasets.

For example, in the bird dataset, birds don't sing all the time. They sing, rest and sing again, with noises in between. Having a long frame helps capture the singing, but setting it too long reduces the number of samples for training.
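To put these numbers in perspective, here is a quick back-of-the-envelope sketch (assuming EXPECTED_WAVEFORM_LENGTH is 15600 samples, i.e. 0.975s at 16kHz, which matches the serving model's input shape shown later):

# Back-of-the-envelope: convert the frame parameters from samples to seconds.
# (15600 samples is an assumption based on YAMNet's 0.975s window at 16kHz.)
expected_waveform_length = 15600   # samples per YAMNet window
sample_rate = 16000                # Hz

print(f'frame_length: {6 * expected_waveform_length / sample_rate:.2f}s per training sample')
print(f'frame_step:   {3 * expected_waveform_length / sample_rate:.2f}s between consecutive sample starts')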
spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmpfs/tmp/tmpdqml6_zs
Loading the data
Model Maker has an API to load the data from a folder and put it in the format expected by the model spec.
The train and test splits are based on the folders. The validation dataset will be created as 20% of the train split.
train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)
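As a quick check, you can inspect how many examples ended up in each split and which labels were found (a small sketch; len() on the DataLoader is assumed here, while index_to_label is the same attribute used later for the confusion matrix):

# Quick look at the splits; len() on the DataLoader is assumed to return the
# number of examples, and index_to_label lists the class names.
print(f'Training examples:   {len(train_data)}')
print(f'Validation examples: {len(validation_data)}')
print(f'Test examples:       {len(test_data)}')
print(f'Labels: {train_data.index_to_label}')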
Training the model
The audio_classifier has a create method that creates a model and starts training it.
You can customize many parameters; for more information you can read more details in the documentation.
On this first try, you'll use all the default configurations and train for 100 epochs.
batch_size = 128
epochs = 100
print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= classification_head (Dense) (None, 5) 5125 ================================================================= Total params: 5,125 Trainable params: 5,125 Non-trainable params: 0 _________________________________________________________________ Epoch 1/100 24/24 [==============================] - 17s 606ms/step - loss: 1.4590 - acc: 0.3536 - val_loss: 1.3477 - val_acc: 0.4064 Epoch 2/100 24/24 [==============================] - 1s 23ms/step - loss: 1.2257 - acc: 0.5408 - val_loss: 1.1705 - val_acc: 0.6164 Epoch 3/100 24/24 [==============================] - 0s 17ms/step - loss: 1.0755 - acc: 0.6418 - val_loss: 1.0612 - val_acc: 0.6895 Epoch 4/100 24/24 [==============================] - 0s 18ms/step - loss: 0.9710 - acc: 0.6938 - val_loss: 0.9898 - val_acc: 0.7146 Epoch 5/100 24/24 [==============================] - 0s 18ms/step - loss: 0.8907 - acc: 0.7178 - val_loss: 0.9328 - val_acc: 0.7374 Epoch 6/100 24/24 [==============================] - 0s 17ms/step - loss: 0.8365 - acc: 0.7342 - val_loss: 0.8891 - val_acc: 0.7352 Epoch 7/100 24/24 [==============================] - 0s 19ms/step - loss: 0.7848 - acc: 0.7582 - val_loss: 0.8534 - val_acc: 0.7374 Epoch 8/100 24/24 [==============================] - 0s 17ms/step - loss: 0.7488 - acc: 0.7635 - val_loss: 0.8233 - val_acc: 0.7443 Epoch 9/100 24/24 [==============================] - 0s 18ms/step - loss: 0.7179 - acc: 0.7780 - val_loss: 0.7978 - val_acc: 0.7466 Epoch 10/100 24/24 [==============================] - 0s 18ms/step - loss: 0.6794 - acc: 0.7872 - val_loss: 0.7773 - val_acc: 0.7534 Epoch 11/100 24/24 [==============================] - 0s 18ms/step - loss: 0.6671 - acc: 0.7888 - val_loss: 0.7548 - val_acc: 0.7603 Epoch 12/100 24/24 [==============================] - 1s 22ms/step - loss: 0.6424 - acc: 0.7944 - val_loss: 0.7377 - val_acc: 0.7626 Epoch 13/100 24/24 [==============================] - 0s 19ms/step - loss: 0.6171 - acc: 0.8099 - val_loss: 0.7221 - val_acc: 0.7626 Epoch 14/100 24/24 [==============================] - 0s 19ms/step - loss: 0.6030 - acc: 0.8072 - val_loss: 0.7095 - val_acc: 0.7626 Epoch 15/100 24/24 [==============================] - 0s 17ms/step - loss: 0.5921 - acc: 0.8138 - val_loss: 0.6965 - val_acc: 0.7694 Epoch 16/100 24/24 [==============================] - 0s 17ms/step - loss: 0.5762 - acc: 0.8214 - val_loss: 0.6847 - val_acc: 0.7671 Epoch 17/100 24/24 [==============================] - 0s 18ms/step - loss: 0.5582 - acc: 0.8227 - val_loss: 0.6730 - val_acc: 0.7694 Epoch 18/100 24/24 [==============================] - 0s 18ms/step - loss: 0.5459 - acc: 0.8211 - val_loss: 0.6637 - val_acc: 0.7671 Epoch 19/100 24/24 [==============================] - 0s 18ms/step - loss: 0.5336 - acc: 0.8411 - val_loss: 0.6524 - val_acc: 0.7694 Epoch 20/100 24/24 [==============================] - 0s 17ms/step - loss: 0.5250 - acc: 0.8368 - val_loss: 0.6428 - val_acc: 0.7831 Epoch 21/100 24/24 [==============================] - 0s 18ms/step - loss: 0.5185 - acc: 0.8296 - val_loss: 0.6375 - val_acc: 0.7831 Epoch 22/100 24/24 [==============================] - 0s 17ms/step - loss: 0.5071 - acc: 0.8385 - val_loss: 0.6279 - val_acc: 0.7854 Epoch 23/100 24/24 [==============================] - 0s 18ms/step - loss: 0.5013 - acc: 0.8424 - val_loss: 0.6232 - val_acc: 0.7831 Epoch 24/100 24/24 [==============================] - 
0s 18ms/step - loss: 0.4867 - acc: 0.8556 - val_loss: 0.6129 - val_acc: 0.7854 Epoch 25/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4801 - acc: 0.8461 - val_loss: 0.6090 - val_acc: 0.7854 Epoch 26/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4722 - acc: 0.8553 - val_loss: 0.6001 - val_acc: 0.7877 Epoch 27/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4712 - acc: 0.8520 - val_loss: 0.5957 - val_acc: 0.7831 Epoch 28/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4636 - acc: 0.8562 - val_loss: 0.5883 - val_acc: 0.7945 Epoch 29/100 24/24 [==============================] - 0s 19ms/step - loss: 0.4562 - acc: 0.8526 - val_loss: 0.5828 - val_acc: 0.8037 Epoch 30/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4476 - acc: 0.8618 - val_loss: 0.5789 - val_acc: 0.7968 Epoch 31/100 24/24 [==============================] - 0s 19ms/step - loss: 0.4454 - acc: 0.8592 - val_loss: 0.5738 - val_acc: 0.8037 Epoch 32/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4405 - acc: 0.8589 - val_loss: 0.5693 - val_acc: 0.8059 Epoch 33/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4353 - acc: 0.8589 - val_loss: 0.5627 - val_acc: 0.8082 Epoch 34/100 24/24 [==============================] - 0s 19ms/step - loss: 0.4366 - acc: 0.8576 - val_loss: 0.5628 - val_acc: 0.8037 Epoch 35/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4284 - acc: 0.8645 - val_loss: 0.5585 - val_acc: 0.8059 Epoch 36/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4250 - acc: 0.8632 - val_loss: 0.5535 - val_acc: 0.8128 Epoch 37/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4216 - acc: 0.8684 - val_loss: 0.5503 - val_acc: 0.8059 Epoch 38/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4163 - acc: 0.8645 - val_loss: 0.5456 - val_acc: 0.8082 Epoch 39/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4121 - acc: 0.8687 - val_loss: 0.5420 - val_acc: 0.8151 Epoch 40/100 24/24 [==============================] - 0s 17ms/step - loss: 0.4122 - acc: 0.8648 - val_loss: 0.5403 - val_acc: 0.8082 Epoch 41/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4047 - acc: 0.8704 - val_loss: 0.5368 - val_acc: 0.8082 Epoch 42/100 24/24 [==============================] - 0s 18ms/step - loss: 0.4002 - acc: 0.8674 - val_loss: 0.5322 - val_acc: 0.8128 Epoch 43/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3990 - acc: 0.8760 - val_loss: 0.5302 - val_acc: 0.8105 Epoch 44/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3899 - acc: 0.8766 - val_loss: 0.5282 - val_acc: 0.8128 Epoch 45/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3892 - acc: 0.8740 - val_loss: 0.5291 - val_acc: 0.8105 Epoch 46/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3961 - acc: 0.8724 - val_loss: 0.5213 - val_acc: 0.8105 Epoch 47/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3801 - acc: 0.8783 - val_loss: 0.5192 - val_acc: 0.8105 Epoch 48/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3889 - acc: 0.8750 - val_loss: 0.5163 - val_acc: 0.8128 Epoch 49/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3799 - acc: 0.8786 - val_loss: 0.5151 - val_acc: 0.8128 Epoch 50/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3767 - acc: 0.8724 - val_loss: 0.5109 - val_acc: 0.8151 Epoch 51/100 
24/24 [==============================] - 0s 18ms/step - loss: 0.3752 - acc: 0.8766 - val_loss: 0.5093 - val_acc: 0.8151 Epoch 52/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3731 - acc: 0.8839 - val_loss: 0.5073 - val_acc: 0.8151 Epoch 53/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3715 - acc: 0.8766 - val_loss: 0.5057 - val_acc: 0.8151 Epoch 54/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3627 - acc: 0.8829 - val_loss: 0.5040 - val_acc: 0.8151 Epoch 55/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3716 - acc: 0.8763 - val_loss: 0.5002 - val_acc: 0.8151 Epoch 56/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3644 - acc: 0.8766 - val_loss: 0.4991 - val_acc: 0.8151 Epoch 57/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3655 - acc: 0.8809 - val_loss: 0.4956 - val_acc: 0.8196 Epoch 58/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3628 - acc: 0.8796 - val_loss: 0.4946 - val_acc: 0.8219 Epoch 59/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3675 - acc: 0.8789 - val_loss: 0.4903 - val_acc: 0.8265 Epoch 60/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3564 - acc: 0.8832 - val_loss: 0.4888 - val_acc: 0.8196 Epoch 61/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3588 - acc: 0.8799 - val_loss: 0.4876 - val_acc: 0.8265 Epoch 62/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3536 - acc: 0.8799 - val_loss: 0.4870 - val_acc: 0.8265 Epoch 63/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3471 - acc: 0.8875 - val_loss: 0.4843 - val_acc: 0.8288 Epoch 64/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3497 - acc: 0.8859 - val_loss: 0.4837 - val_acc: 0.8288 Epoch 65/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3441 - acc: 0.8803 - val_loss: 0.4826 - val_acc: 0.8265 Epoch 66/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3479 - acc: 0.8872 - val_loss: 0.4800 - val_acc: 0.8288 Epoch 67/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3482 - acc: 0.8839 - val_loss: 0.4802 - val_acc: 0.8242 Epoch 68/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3486 - acc: 0.8832 - val_loss: 0.4793 - val_acc: 0.8288 Epoch 69/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3428 - acc: 0.8852 - val_loss: 0.4777 - val_acc: 0.8311 Epoch 70/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3450 - acc: 0.8895 - val_loss: 0.4772 - val_acc: 0.8311 Epoch 71/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3433 - acc: 0.8862 - val_loss: 0.4773 - val_acc: 0.8288 Epoch 72/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3301 - acc: 0.8954 - val_loss: 0.4737 - val_acc: 0.8311 Epoch 73/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3333 - acc: 0.8918 - val_loss: 0.4738 - val_acc: 0.8311 Epoch 74/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3376 - acc: 0.8849 - val_loss: 0.4721 - val_acc: 0.8333 Epoch 75/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3391 - acc: 0.8878 - val_loss: 0.4737 - val_acc: 0.8311 Epoch 76/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3372 - acc: 0.8891 - val_loss: 0.4727 - val_acc: 0.8333 Epoch 77/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3301 - acc: 0.8888 - val_loss: 
0.4714 - val_acc: 0.8288 Epoch 78/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3372 - acc: 0.8875 - val_loss: 0.4670 - val_acc: 0.8333 Epoch 79/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3327 - acc: 0.8921 - val_loss: 0.4665 - val_acc: 0.8311 Epoch 80/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3340 - acc: 0.8898 - val_loss: 0.4660 - val_acc: 0.8356 Epoch 81/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3343 - acc: 0.8885 - val_loss: 0.4656 - val_acc: 0.8311 Epoch 82/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3295 - acc: 0.8891 - val_loss: 0.4648 - val_acc: 0.8288 Epoch 83/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3235 - acc: 0.8944 - val_loss: 0.4614 - val_acc: 0.8333 Epoch 84/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3238 - acc: 0.8938 - val_loss: 0.4636 - val_acc: 0.8311 Epoch 85/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3285 - acc: 0.8908 - val_loss: 0.4601 - val_acc: 0.8333 Epoch 86/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3187 - acc: 0.8911 - val_loss: 0.4602 - val_acc: 0.8311 Epoch 87/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3180 - acc: 0.8964 - val_loss: 0.4545 - val_acc: 0.8356 Epoch 88/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3290 - acc: 0.8852 - val_loss: 0.4577 - val_acc: 0.8356 Epoch 89/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3121 - acc: 0.8967 - val_loss: 0.4578 - val_acc: 0.8333 Epoch 90/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3221 - acc: 0.8901 - val_loss: 0.4550 - val_acc: 0.8379 Epoch 91/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3156 - acc: 0.9003 - val_loss: 0.4550 - val_acc: 0.8333 Epoch 92/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3186 - acc: 0.8961 - val_loss: 0.4540 - val_acc: 0.8379 Epoch 93/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3258 - acc: 0.8941 - val_loss: 0.4506 - val_acc: 0.8402 Epoch 94/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3075 - acc: 0.8970 - val_loss: 0.4513 - val_acc: 0.8356 Epoch 95/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3169 - acc: 0.8947 - val_loss: 0.4512 - val_acc: 0.8356 Epoch 96/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3063 - acc: 0.8987 - val_loss: 0.4477 - val_acc: 0.8425 Epoch 97/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3114 - acc: 0.8924 - val_loss: 0.4466 - val_acc: 0.8402 Epoch 98/100 24/24 [==============================] - 0s 17ms/step - loss: 0.3031 - acc: 0.9003 - val_loss: 0.4484 - val_acc: 0.8402 Epoch 99/100 24/24 [==============================] - 0s 18ms/step - loss: 0.3099 - acc: 0.8984 - val_loss: 0.4479 - val_acc: 0.8356 Epoch 100/100 24/24 [==============================] - 0s 19ms/step - loss: 0.3100 - acc: 0.8941 - val_loss: 0.4479 - val_acc: 0.8379
The accuracy looks good, but it's important to run the evaluation step on the test data and verify your model achieved good results on unseen data.
print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model 28/28 [==============================] - 4s 122ms/step - loss: 0.8122 - acc: 0.7865 [0.8122328519821167, 0.7864523530006409]
Understanding your model
When training a classifier, it's useful to see the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier is performing on test data.
Model Maker already creates the confusion matrix for you.
def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  # Normalize each row (true label) so each cell is a fraction of that class's examples.
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)[:, np.newaxis]
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")
confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
Testing the model [Optional]
You can try the model on a sample audio file from the test dataset just to see the results.
First, you get the serving model.
serving_model = model.create_serving_model()
print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>] Model's output shape and type: [<KerasTensor: shape=(None, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(None, 5) dtype=float32 (created by layer 'sequential')>]
Coming back to the random audio you loaded earlier:
# if you want to try another file just uncomment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: White-breasted Wood-Wren Bird code: wbwwre1
The created model has a fixed input window.
For a given audio file, you'll have to split it into windows of data of the expected size. The last window might need to be padded with zeros.
sample_rate, audio_data = wavfile.read(random_audio, 'rb')
audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]
splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)
print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/tensorflow/lite/g3doc/models/modify/model_maker/dataset/small_birds_dataset/test/wbwwre1/XC519211.wav Original size of the audio data: 728160 Number of windows for inference: 47
You'll loop over all the split audio windows and apply the model to each one of them.
The model you've just trained has 2 outputs: the original YAMNet output and the one you've just trained. This is important because a real-world environment is more complicated than just bird sounds. You can use YAMNet's output to filter out non-relevant audio. For example, in the birds use case, if YAMNet is not classifying Birds or Animals, this might indicate that the output from your model is an irrelevant classification.
Below, both outputs are printed to make it easier to understand their relation. Most of the mistakes that your model makes happen when YAMNet's prediction is not related to your domain (e.g. birds).
print(random_audio)

results = []
print('Result of the window ith: your model class -> score, (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  result_str = f'Result of the window {i}: ' \
               f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
               f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)

results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/tensorflow/lite/g3doc/models/modify/model_maker/dataset/small_birds_dataset/test/wbwwre1/XC519211.wav Result of the window ith: your model class -> score, (spec class -> score) Result of the window 0: wbwwre1 -> 0.947, (Aircraft -> 0.193) Result of the window 1: wbwwre1 -> 0.969, (Outside, rural or natural -> 0.385) Result of the window 2: wbwwre1 -> 0.982, (Outside, rural or natural -> 0.342) Result of the window 3: wbwwre1 -> 0.994, (Cricket -> 0.630) Result of the window 4: wbwwre1 -> 0.912, (Insect -> 0.671) Result of the window 5: wbwwre1 -> 0.983, (Cricket -> 0.892) Result of the window 6: wbwwre1 -> 0.975, (Cricket -> 0.872) Result of the window 7: wbwwre1 -> 1.000, (Bird vocalization, bird call, bird song -> 0.661) Result of the window 8: wbwwre1 -> 1.000, (Wild animals -> 0.351) Result of the window 9: wbwwre1 -> 1.000, (Bird -> 0.445) Result of the window 10: wbwwre1 -> 0.884, (Bird vocalization, bird call, bird song -> 0.758) Result of the window 11: wbwwre1 -> 0.996, (Theremin -> 0.388) Result of the window 12: wbwwre1 -> 1.000, (Bird -> 0.342) Result of the window 13: wbwwre1 -> 0.999, (Bird vocalization, bird call, bird song -> 0.192) Result of the window 14: wbwwre1 -> 1.000, (Bird -> 0.457) Result of the window 15: wbwwre1 -> 0.987, (Bird vocalization, bird call, bird song -> 0.416) Result of the window 16: wbwwre1 -> 0.999, (Bird vocalization, bird call, bird song -> 0.200) Result of the window 17: wbwwre1 -> 0.999, (Bird vocalization, bird call, bird song -> 0.704) Result of the window 18: wbwwre1 -> 0.991, (Bird vocalization, bird call, bird song -> 0.418) Result of the window 19: wbwwre1 -> 0.989, (Bird -> 0.718) Result of the window 20: wbwwre1 -> 0.973, (Animal -> 0.570) Result of the window 21: wbwwre1 -> 0.962, (Environmental noise -> 0.430) Result of the window 22: wbwwre1 -> 0.932, (Animal -> 0.597) Result of the window 23: wbwwre1 -> 0.969, (Bird vocalization, bird call, bird song -> 0.727) Result of the window 24: wbwwre1 -> 0.920, (Cricket -> 0.629) Result of the window 25: wbwwre1 -> 0.877, (Outside, rural or natural -> 0.189) Result of the window 26: wbwwre1 -> 0.962, (Insect -> 0.318) Result of the window 27: wbwwre1 -> 0.871, (Outside, rural or natural -> 0.431) Result of the window 28: wbwwre1 -> 0.983, (Cricket -> 0.383) Result of the window 29: wbwwre1 -> 0.984, (Cricket -> 0.602) Result of the window 30: redcro -> 0.455, (Silence -> 1.000) Result of the window 31: azaspi1 -> 0.373, (Silence -> 1.000) Result of the window 32: chcant2 -> 0.835, (Silence -> 1.000) Result of the window 33: chcant2 -> 0.911, (Speech -> 0.962) Result of the window 34: chcant2 -> 0.588, (Speech -> 0.876) Result of the window 35: chcant2 -> 0.319, (Speech -> 0.977) Result of the window 36: azaspi1 -> 0.444, (Silence -> 1.000) Result of the window 37: chcant2 -> 0.999, (Silence -> 1.000) Result of the window 38: houspa -> 0.585, (Bird -> 0.880) Result of the window 39: wbwwre1 -> 0.898, (Bird -> 0.809) Result of the window 40: houspa -> 0.472, (Animal -> 0.556) Result of the window 41: wbwwre1 -> 0.977, (Bird vocalization, bird call, bird song -> 0.708) Result of the window 42: wbwwre1 -> 0.949, (Insect -> 0.266) Result of the window 43: wbwwre1 -> 0.980, (Cricket -> 0.928) Result of the window 44: wbwwre1 -> 0.992, (Cricket -> 0.945) Result of the window 45: chcant2 -> 0.571, (Speech -> 0.158) Result of the window 46: chcant2 -> 0.787, (Silence -> 0.844) Mean result: wbwwre1 -> 0.7325854897499084
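As a sketch of the filtering idea mentioned above, you could average only the windows whose top YAMNet class looks bird-related. The set of class names below is an illustrative assumption, not something Model Maker provides:

# Hypothetical post-processing: keep only windows whose top YAMNet class
# suggests relevant audio (birds/animals), then average your model's scores.
RELEVANT_YAMNET_CLASSES = {
    'Bird', 'Bird vocalization, bird call, bird song', 'Animal', 'Wild animals'}

yamnet_labels = spec._yamnet_labels()
filtered_results = []
for data in splitted_audio_data:
  yamnet_output, inference = serving_model(data)
  if yamnet_labels[tf.argmax(yamnet_output[0])] in RELEVANT_YAMNET_CLASSES:
    filtered_results.append(inference[0].numpy())

if filtered_results:
  mean_filtered = np.mean(filtered_results, axis=0)
  best_index = mean_filtered.argmax()
  print(f'Filtered mean result: {test_data.index_to_label[best_index]} -> {mean_filtered[best_index]:.3f}')
else:
  print('No window was classified as bird/animal related by YAMNet.')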
Exporting the model
The last step is exporting your model to be used on embedded devices or in the browser.
The export method exports both formats for you.
models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')
model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models 2022-10-20 12:11:30.821235: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphxc1u85r/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphxc1u85r/assets 2022-10-20 12:11:36.989028: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format. 2022-10-20 12:11:36.989080: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency. INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
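To double-check the exported file on your development machine, here is a minimal sketch using the TensorFlow Lite Interpreter. Since the model was exported with keep_yamnet_and_custom_heads=True, expect two outputs; the input shape is read from the model rather than assumed:

# Minimal sketch: run a single (silent) window through the exported TFLite model
# and print the shape of each output (YAMNet head and the custom birds head).
interpreter = tf.lite.Interpreter(model_path='./birds_models/my_birds_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
waveform = np.zeros(input_details['shape'], dtype=np.float32)  # a window of silence

interpreter.set_tensor(input_details['index'], waveform)
interpreter.invoke()

for output in interpreter.get_output_details():
  scores = interpreter.get_tensor(output['index'])
  print(f"Output '{output['name']}': shape {scores.shape}")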
You can also export the SavedModel version for serving or for use in a Python environment.
model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets INFO:tensorflow:Saving labels in ./birds_models/labels.txt INFO:tensorflow:Saving labels in ./birds_models/labels.txt
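To confirm the SavedModel works in a plain Python environment, here is a minimal sketch that reloads it through the standard serving_default signature (the signature name is the TensorFlow default, not something documented by Model Maker):

# Minimal sketch: reload the exported SavedModel and run one zero-filled window.
reloaded = tf.saved_model.load('./birds_models/saved_model')
serving_fn = reloaded.signatures['serving_default']

dummy_window = tf.zeros([1, 15600], dtype=tf.float32)  # matches the (None, 15600) input seen earlier
outputs = serving_fn(dummy_window)
for name, tensor in outputs.items():
  print(name, tensor.shape)   # expect a (1, 521) YAMNet head and a (1, 5) custom head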
Next Steps
You did it.
Now your new model can be deployed on mobile devices using the TFLite AudioClassifier Task API.
You can also try the same process with your own data and different classes; here is the documentation for Model Maker for Audio Classification.