Transfer Learning for the Audio Domain with TensorFlow Lite Model Maker

In this colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a custom audio classification model.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.

This notebook is part of the codelab Customize an Audio model and deploy on Android.

You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.js model that can be used for inference in the browser, and a SavedModel version that you can use for serving.

Installing dependencies

sudo apt -y install libportaudio2
pip install tflite-model-maker

Import TensorFlow, Model Maker and other libraries

Among the dependencies, you'll use TensorFlow and Model Maker. The others are for audio manipulation, playback, and visualization.

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
TensorFlow Version: 2.8.3
Model Maker Version: 0.4.2

The Birds dataset

The Birds dataset is an educational collection of songs from 5 types of birds:

  • White-breasted Wood-Wren
  • House Sparrow
  • Red Crossbill
  • Chestnut-crowned Antpitta
  • Azara's Spinetail

The original audio came from Xeno-canto, a website dedicated to sharing bird sounds from all over the world.

Let's start by downloading the data.

birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip
343687168/343680986 [==============================] - 3s 0us/step
343695360/343680986 [==============================] - 3s 0us/step

Explore the data

The audio files are already split into train and test folders. Inside each split folder, there's one folder for each bird, named with its bird_code.

The audio files are all mono with a 16 kHz sample rate.

For more information about each file, you can read the metadata.csv file. It contains each file's author, license, and some other details. You won't need to read it yourself in this tutorial.
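The folder layout described above can be checked with a few lines of standard-library Python. This is a minimal sketch that assumes the `<split>/<bird_code>/<clip>.wav` layout; `count_clips_per_class` is a hypothetical helper name, not part of Model Maker.

```python
from collections import Counter
from pathlib import Path


def count_clips_per_class(split_dir):
    """Count .wav files in each class subfolder of a split folder.

    Assumes the layout <split_dir>/<bird_code>/<clip>.wav.
    """
    counts = Counter()
    for wav_path in Path(split_dir).glob('*/*.wav'):
        # The parent folder name is the bird code, which is the class label.
        counts[wav_path.parent.name] += 1
    return dict(counts)
```

Once the dataset is downloaded, calling `count_clips_per_class('./dataset/small_birds_dataset/train')` would show how many clips each bird has, which is a quick way to spot class imbalance.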

# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',  
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",   
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', #   Alejandro Bayer Tamayo from Armenia, Colombia 
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', #    Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', #   Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path


def show_bird_data(audio_path):
  # Read the sample rate and the raw samples from the WAV file.
  sample_rate, audio_data = wavfile.read(audio_path)

  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created

Playing some audio

To get a better understanding of the data, let's listen to a random audio file from the test split.

random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: House Sparrow
Bird code: houspa

Training the Model

When using Model Maker for audio, you have to start with a model spec. This is the base model from which your new model will extract information to learn about the new classes. It also determines how the dataset is transformed to match the spec's parameters, such as sample rate and number of channels.

YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be at 16 kHz with 1 channel.

You don't need to do any resampling yourself. Model Maker takes care of that for you.
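If you want to confirm what format your own recordings are in before handing them to Model Maker, a check like the following can help. It is a sketch using Python's standard `wave` module, which reads uncompressed PCM WAV files only; `describe_wav` and `matches_yamnet_input` are hypothetical helper names.

```python
import wave


def describe_wav(path):
    """Return (sample_rate_hz, num_channels, duration_seconds) of a PCM WAV file."""
    with wave.open(path, 'rb') as wav_file:
        sample_rate = wav_file.getframerate()
        num_channels = wav_file.getnchannels()
        duration = wav_file.getnframes() / sample_rate
    return sample_rate, num_channels, duration


def matches_yamnet_input(path, expected_rate=16000, expected_channels=1):
    """True if the file is already 16 kHz mono, the format stated above."""
    sample_rate, num_channels, _ = describe_wav(path)
    return sample_rate == expected_rate and num_channels == expected_channels
```

A file that fails this check is not an error; it only means it will need to be converted to the expected format before it reaches the model.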

  • frame_length decides how long each training sample is. In this case, 6 * EXPECTED_WAVEFORM_LENGTH.

  • frame_step decides how far apart the training samples are. In this case, the ith sample starts 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)th sample.

These values are set to work around some limitations of real-world datasets.

For example, in the bird dataset, birds don't sing all the time. They sing, rest and sing again, with noises in between. Having a long frame would help capture the singing, but setting it too long will reduce the number of samples for training.
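The trade-off between frame length and the number of training samples can be made concrete with a little arithmetic. The sketch below assumes that a trailing window that does not fit completely is dropped, and that EXPECTED_WAVEFORM_LENGTH is 15600 samples (matching the serving model's input shape printed later in this tutorial); `count_training_frames` is a hypothetical helper, not Model Maker's implementation.

```python
def count_training_frames(num_samples, frame_length, frame_step):
    """Number of full windows of frame_length that fit when hopping by frame_step.

    Assumption: a trailing window that does not fit completely is dropped.
    """
    if num_samples < frame_length:
        return 0
    return (num_samples - frame_length) // frame_step + 1


# Assumed value of EXPECTED_WAVEFORM_LENGTH (15600 samples, 0.975 s at 16 kHz).
WAVEFORM_LENGTH = 15600

clip_samples = 60 * 16000  # a hypothetical 60-second clip at 16 kHz

# Short frames, no overlap.
short_frames = count_training_frames(clip_samples, 3 * WAVEFORM_LENGTH, 3 * WAVEFORM_LENGTH)
# Long frames, no overlap: roughly half as many samples.
long_frames = count_training_frames(clip_samples, 6 * WAVEFORM_LENGTH, 6 * WAVEFORM_LENGTH)
# Long frames with a step of half the frame length, so consecutive frames overlap.
overlapping_frames = count_training_frames(clip_samples, 6 * WAVEFORM_LENGTH, 3 * WAVEFORM_LENGTH)

print(short_frames, long_frames, overlapping_frames)
```

Under these assumptions, doubling the frame length without overlap roughly halves the sample count, while a step smaller than the frame length recovers most of it.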

spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmpfs/tmp/tmpdqml6_zs

Loading the data

Model Maker has an API to load the data from a folder and convert it to the format the model spec expects.

The train and test split are based on the folders. The validation dataset will be created as 20% of the train split.
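To make the 80/20 proportion concrete, here is a plain-Python illustration of splitting a list by a fraction. It is only a sketch of the idea: `split_by_fraction` is a hypothetical helper, and it makes no claim about how Model Maker's own split is implemented (for example, whether it shuffles).

```python
def split_by_fraction(items, fraction):
    """Split a sequence into two parts; the first holds `fraction` of the items."""
    if not 0.0 < fraction < 1.0:
        raise ValueError('fraction must be between 0 and 1')
    cut = int(len(items) * fraction)
    return items[:cut], items[cut:]


clips = [f'clip_{i}.wav' for i in range(10)]  # hypothetical file names
train_part, validation_part = split_by_fraction(clips, 0.8)
```

With 10 items and a fraction of 0.8, the first part holds 8 items and the second holds the remaining 2.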

train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

Training the model

The audio_classifier module has a create method that builds a model and immediately starts training it.

You can customize many parameters; for details, see the documentation.

On this first try, you'll use the default configuration and train for 100 epochs.

batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 classification_head (Dense)  (None, 5)                5125      
                                                                 
=================================================================
Total params: 5,125
Trainable params: 5,125
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
24/24 [==============================] - 17s 606ms/step - loss: 1.4590 - acc: 0.3536 - val_loss: 1.3477 - val_acc: 0.4064
Epoch 2/100
24/24 [==============================] - 1s 23ms/step - loss: 1.2257 - acc: 0.5408 - val_loss: 1.1705 - val_acc: 0.6164
Epoch 3/100
24/24 [==============================] - 0s 17ms/step - loss: 1.0755 - acc: 0.6418 - val_loss: 1.0612 - val_acc: 0.6895
Epoch 4/100
24/24 [==============================] - 0s 18ms/step - loss: 0.9710 - acc: 0.6938 - val_loss: 0.9898 - val_acc: 0.7146
Epoch 5/100
24/24 [==============================] - 0s 18ms/step - loss: 0.8907 - acc: 0.7178 - val_loss: 0.9328 - val_acc: 0.7374
Epoch 6/100
24/24 [==============================] - 0s 17ms/step - loss: 0.8365 - acc: 0.7342 - val_loss: 0.8891 - val_acc: 0.7352
Epoch 7/100
24/24 [==============================] - 0s 19ms/step - loss: 0.7848 - acc: 0.7582 - val_loss: 0.8534 - val_acc: 0.7374
Epoch 8/100
24/24 [==============================] - 0s 17ms/step - loss: 0.7488 - acc: 0.7635 - val_loss: 0.8233 - val_acc: 0.7443
Epoch 9/100
24/24 [==============================] - 0s 18ms/step - loss: 0.7179 - acc: 0.7780 - val_loss: 0.7978 - val_acc: 0.7466
Epoch 10/100
24/24 [==============================] - 0s 18ms/step - loss: 0.6794 - acc: 0.7872 - val_loss: 0.7773 - val_acc: 0.7534
Epoch 11/100
24/24 [==============================] - 0s 18ms/step - loss: 0.6671 - acc: 0.7888 - val_loss: 0.7548 - val_acc: 0.7603
Epoch 12/100
24/24 [==============================] - 1s 22ms/step - loss: 0.6424 - acc: 0.7944 - val_loss: 0.7377 - val_acc: 0.7626
Epoch 13/100
24/24 [==============================] - 0s 19ms/step - loss: 0.6171 - acc: 0.8099 - val_loss: 0.7221 - val_acc: 0.7626
Epoch 14/100
24/24 [==============================] - 0s 19ms/step - loss: 0.6030 - acc: 0.8072 - val_loss: 0.7095 - val_acc: 0.7626
Epoch 15/100
24/24 [==============================] - 0s 17ms/step - loss: 0.5921 - acc: 0.8138 - val_loss: 0.6965 - val_acc: 0.7694
Epoch 16/100
24/24 [==============================] - 0s 17ms/step - loss: 0.5762 - acc: 0.8214 - val_loss: 0.6847 - val_acc: 0.7671
Epoch 17/100
24/24 [==============================] - 0s 18ms/step - loss: 0.5582 - acc: 0.8227 - val_loss: 0.6730 - val_acc: 0.7694
Epoch 18/100
24/24 [==============================] - 0s 18ms/step - loss: 0.5459 - acc: 0.8211 - val_loss: 0.6637 - val_acc: 0.7671
Epoch 19/100
24/24 [==============================] - 0s 18ms/step - loss: 0.5336 - acc: 0.8411 - val_loss: 0.6524 - val_acc: 0.7694
Epoch 20/100
24/24 [==============================] - 0s 17ms/step - loss: 0.5250 - acc: 0.8368 - val_loss: 0.6428 - val_acc: 0.7831
Epoch 21/100
24/24 [==============================] - 0s 18ms/step - loss: 0.5185 - acc: 0.8296 - val_loss: 0.6375 - val_acc: 0.7831
Epoch 22/100
24/24 [==============================] - 0s 17ms/step - loss: 0.5071 - acc: 0.8385 - val_loss: 0.6279 - val_acc: 0.7854
Epoch 23/100
24/24 [==============================] - 0s 18ms/step - loss: 0.5013 - acc: 0.8424 - val_loss: 0.6232 - val_acc: 0.7831
Epoch 24/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4867 - acc: 0.8556 - val_loss: 0.6129 - val_acc: 0.7854
Epoch 25/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4801 - acc: 0.8461 - val_loss: 0.6090 - val_acc: 0.7854
Epoch 26/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4722 - acc: 0.8553 - val_loss: 0.6001 - val_acc: 0.7877
Epoch 27/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4712 - acc: 0.8520 - val_loss: 0.5957 - val_acc: 0.7831
Epoch 28/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4636 - acc: 0.8562 - val_loss: 0.5883 - val_acc: 0.7945
Epoch 29/100
24/24 [==============================] - 0s 19ms/step - loss: 0.4562 - acc: 0.8526 - val_loss: 0.5828 - val_acc: 0.8037
Epoch 30/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4476 - acc: 0.8618 - val_loss: 0.5789 - val_acc: 0.7968
Epoch 31/100
24/24 [==============================] - 0s 19ms/step - loss: 0.4454 - acc: 0.8592 - val_loss: 0.5738 - val_acc: 0.8037
Epoch 32/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4405 - acc: 0.8589 - val_loss: 0.5693 - val_acc: 0.8059
Epoch 33/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4353 - acc: 0.8589 - val_loss: 0.5627 - val_acc: 0.8082
Epoch 34/100
24/24 [==============================] - 0s 19ms/step - loss: 0.4366 - acc: 0.8576 - val_loss: 0.5628 - val_acc: 0.8037
Epoch 35/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4284 - acc: 0.8645 - val_loss: 0.5585 - val_acc: 0.8059
Epoch 36/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4250 - acc: 0.8632 - val_loss: 0.5535 - val_acc: 0.8128
Epoch 37/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4216 - acc: 0.8684 - val_loss: 0.5503 - val_acc: 0.8059
Epoch 38/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4163 - acc: 0.8645 - val_loss: 0.5456 - val_acc: 0.8082
Epoch 39/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4121 - acc: 0.8687 - val_loss: 0.5420 - val_acc: 0.8151
Epoch 40/100
24/24 [==============================] - 0s 17ms/step - loss: 0.4122 - acc: 0.8648 - val_loss: 0.5403 - val_acc: 0.8082
Epoch 41/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4047 - acc: 0.8704 - val_loss: 0.5368 - val_acc: 0.8082
Epoch 42/100
24/24 [==============================] - 0s 18ms/step - loss: 0.4002 - acc: 0.8674 - val_loss: 0.5322 - val_acc: 0.8128
Epoch 43/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3990 - acc: 0.8760 - val_loss: 0.5302 - val_acc: 0.8105
Epoch 44/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3899 - acc: 0.8766 - val_loss: 0.5282 - val_acc: 0.8128
Epoch 45/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3892 - acc: 0.8740 - val_loss: 0.5291 - val_acc: 0.8105
Epoch 46/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3961 - acc: 0.8724 - val_loss: 0.5213 - val_acc: 0.8105
Epoch 47/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3801 - acc: 0.8783 - val_loss: 0.5192 - val_acc: 0.8105
Epoch 48/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3889 - acc: 0.8750 - val_loss: 0.5163 - val_acc: 0.8128
Epoch 49/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3799 - acc: 0.8786 - val_loss: 0.5151 - val_acc: 0.8128
Epoch 50/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3767 - acc: 0.8724 - val_loss: 0.5109 - val_acc: 0.8151
Epoch 51/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3752 - acc: 0.8766 - val_loss: 0.5093 - val_acc: 0.8151
Epoch 52/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3731 - acc: 0.8839 - val_loss: 0.5073 - val_acc: 0.8151
Epoch 53/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3715 - acc: 0.8766 - val_loss: 0.5057 - val_acc: 0.8151
Epoch 54/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3627 - acc: 0.8829 - val_loss: 0.5040 - val_acc: 0.8151
Epoch 55/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3716 - acc: 0.8763 - val_loss: 0.5002 - val_acc: 0.8151
Epoch 56/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3644 - acc: 0.8766 - val_loss: 0.4991 - val_acc: 0.8151
Epoch 57/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3655 - acc: 0.8809 - val_loss: 0.4956 - val_acc: 0.8196
Epoch 58/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3628 - acc: 0.8796 - val_loss: 0.4946 - val_acc: 0.8219
Epoch 59/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3675 - acc: 0.8789 - val_loss: 0.4903 - val_acc: 0.8265
Epoch 60/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3564 - acc: 0.8832 - val_loss: 0.4888 - val_acc: 0.8196
Epoch 61/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3588 - acc: 0.8799 - val_loss: 0.4876 - val_acc: 0.8265
Epoch 62/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3536 - acc: 0.8799 - val_loss: 0.4870 - val_acc: 0.8265
Epoch 63/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3471 - acc: 0.8875 - val_loss: 0.4843 - val_acc: 0.8288
Epoch 64/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3497 - acc: 0.8859 - val_loss: 0.4837 - val_acc: 0.8288
Epoch 65/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3441 - acc: 0.8803 - val_loss: 0.4826 - val_acc: 0.8265
Epoch 66/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3479 - acc: 0.8872 - val_loss: 0.4800 - val_acc: 0.8288
Epoch 67/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3482 - acc: 0.8839 - val_loss: 0.4802 - val_acc: 0.8242
Epoch 68/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3486 - acc: 0.8832 - val_loss: 0.4793 - val_acc: 0.8288
Epoch 69/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3428 - acc: 0.8852 - val_loss: 0.4777 - val_acc: 0.8311
Epoch 70/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3450 - acc: 0.8895 - val_loss: 0.4772 - val_acc: 0.8311
Epoch 71/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3433 - acc: 0.8862 - val_loss: 0.4773 - val_acc: 0.8288
Epoch 72/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3301 - acc: 0.8954 - val_loss: 0.4737 - val_acc: 0.8311
Epoch 73/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3333 - acc: 0.8918 - val_loss: 0.4738 - val_acc: 0.8311
Epoch 74/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3376 - acc: 0.8849 - val_loss: 0.4721 - val_acc: 0.8333
Epoch 75/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3391 - acc: 0.8878 - val_loss: 0.4737 - val_acc: 0.8311
Epoch 76/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3372 - acc: 0.8891 - val_loss: 0.4727 - val_acc: 0.8333
Epoch 77/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3301 - acc: 0.8888 - val_loss: 0.4714 - val_acc: 0.8288
Epoch 78/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3372 - acc: 0.8875 - val_loss: 0.4670 - val_acc: 0.8333
Epoch 79/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3327 - acc: 0.8921 - val_loss: 0.4665 - val_acc: 0.8311
Epoch 80/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3340 - acc: 0.8898 - val_loss: 0.4660 - val_acc: 0.8356
Epoch 81/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3343 - acc: 0.8885 - val_loss: 0.4656 - val_acc: 0.8311
Epoch 82/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3295 - acc: 0.8891 - val_loss: 0.4648 - val_acc: 0.8288
Epoch 83/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3235 - acc: 0.8944 - val_loss: 0.4614 - val_acc: 0.8333
Epoch 84/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3238 - acc: 0.8938 - val_loss: 0.4636 - val_acc: 0.8311
Epoch 85/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3285 - acc: 0.8908 - val_loss: 0.4601 - val_acc: 0.8333
Epoch 86/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3187 - acc: 0.8911 - val_loss: 0.4602 - val_acc: 0.8311
Epoch 87/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3180 - acc: 0.8964 - val_loss: 0.4545 - val_acc: 0.8356
Epoch 88/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3290 - acc: 0.8852 - val_loss: 0.4577 - val_acc: 0.8356
Epoch 89/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3121 - acc: 0.8967 - val_loss: 0.4578 - val_acc: 0.8333
Epoch 90/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3221 - acc: 0.8901 - val_loss: 0.4550 - val_acc: 0.8379
Epoch 91/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3156 - acc: 0.9003 - val_loss: 0.4550 - val_acc: 0.8333
Epoch 92/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3186 - acc: 0.8961 - val_loss: 0.4540 - val_acc: 0.8379
Epoch 93/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3258 - acc: 0.8941 - val_loss: 0.4506 - val_acc: 0.8402
Epoch 94/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3075 - acc: 0.8970 - val_loss: 0.4513 - val_acc: 0.8356
Epoch 95/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3169 - acc: 0.8947 - val_loss: 0.4512 - val_acc: 0.8356
Epoch 96/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3063 - acc: 0.8987 - val_loss: 0.4477 - val_acc: 0.8425
Epoch 97/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3114 - acc: 0.8924 - val_loss: 0.4466 - val_acc: 0.8402
Epoch 98/100
24/24 [==============================] - 0s 17ms/step - loss: 0.3031 - acc: 0.9003 - val_loss: 0.4484 - val_acc: 0.8402
Epoch 99/100
24/24 [==============================] - 0s 18ms/step - loss: 0.3099 - acc: 0.8984 - val_loss: 0.4479 - val_acc: 0.8356
Epoch 100/100
24/24 [==============================] - 0s 19ms/step - loss: 0.3100 - acc: 0.8941 - val_loss: 0.4479 - val_acc: 0.8379

The accuracy looks good, but it's important to run the evaluation step on the test data and verify that your model achieves good results on unseen data.

print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model
28/28 [==============================] - 4s 122ms/step - loss: 0.8122 - acc: 0.7865
[0.8122328519821167, 0.7864523530006409]

Understanding your model

When training a classifier, it's useful to see the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier is performing on test data.

Model Maker already creates the confusion matrix for you.
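A confusion matrix of raw counts is easier to read once each row is divided by its total, so that every row shows how one true class was distributed across the predictions. The following is a small NumPy sketch of that row normalization, using made-up counts; `normalize_rows` is a hypothetical helper name.

```python
import numpy as np


def normalize_rows(confusion):
    """Divide each row by its total so entry [i, j] is the share of class i predicted as j."""
    confusion = np.asarray(confusion, dtype=float)
    # keepdims=True keeps the totals as a column, so row i is divided by total i.
    row_totals = confusion.sum(axis=1, keepdims=True)
    # Guard against classes with no test samples to avoid division by zero.
    safe_totals = np.where(row_totals == 0, 1.0, row_totals)
    return confusion / safe_totals


# Made-up counts for three classes (rows: true label, columns: predicted label).
counts = [[8, 1, 1],
          [2, 6, 2],
          [0, 0, 5]]
normalized = normalize_rows(counts)
```

After normalization, the diagonal entries are the per-class recall, and each row with at least one sample sums to 1.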

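def show_confusion_matrix(confusion, test_labels):
  """Normalize the confusion matrix by row and plot it as a heatmap."""
  # keepdims=True keeps the row totals as a column so each row is divided by its own total.
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1, keepdims=True)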
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)

Testing the model [Optional]

You can try the model on a sample audio from the test dataset just to see the results.

First you get the serving model.

serving_model = model.create_serving_model()

print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>]
Model's output shape and type: [<KerasTensor: shape=(None, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(None, 5) dtype=float32 (created by layer 'sequential')>]

Now return to the random audio file you loaded earlier, or pick a new one.

# The line below picks a new random file; comment it out to reuse the one loaded earlier.
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: White-breasted Wood-Wren
Bird code: wbwwre1

The model created has a fixed input window.

For a given audio file, you'll have to split it into windows of the expected size. The last window might need to be padded with zeros.
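The windowing step can be pictured with plain NumPy. This is a sketch of non-overlapping windows with a zero-padded tail, written to illustrate the idea rather than to reproduce any TensorFlow function; `split_into_windows` is a hypothetical helper name.

```python
import numpy as np


def split_into_windows(samples, window_size):
    """Split a 1-D signal into non-overlapping windows, zero-padding the last one."""
    samples = np.asarray(samples, dtype=np.float32)
    num_windows = -(-len(samples) // window_size)  # ceiling division
    padded = np.zeros(num_windows * window_size, dtype=np.float32)
    padded[:len(samples)] = samples
    return padded.reshape(num_windows, window_size)


# Ten samples in windows of four: two full windows plus one padded window.
windows = split_into_windows(np.ones(10), window_size=4)
```

Here `windows` has shape (3, 4), and the last row holds two real samples followed by two zeros.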

sample_rate, audio_data = wavfile.read(random_audio)

audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]

splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)

print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/tensorflow/lite/g3doc/models/modify/model_maker/dataset/small_birds_dataset/test/wbwwre1/XC519211.wav
Original size of the audio data: 728160
Number of windows for inference: 47

You'll loop over all the audio windows and apply the model to each one.

The model you've just trained has 2 outputs: the original YAMNet output and the one you've just trained. This is important because real-world environments are more complicated than just bird sounds. You can use YAMNet's output to filter out irrelevant audio. For example, in the birds use case, if YAMNet is not classifying a window as Birds or Animals, the classification from your model for that window may be irrelevant.

Both outputs are printed below to make their relation easier to understand. Most of the mistakes your model makes occur when YAMNet's prediction is not related to your domain (e.g., birds).
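One way to act on this idea is to average your model's scores only over windows where YAMNet's top label is relevant to your domain. The sketch below uses plain Python and made-up scores; the set of labels treated as bird-related is a hypothetical choice, and `mean_scores_for_relevant_windows` is not part of Model Maker.

```python
def mean_scores_for_relevant_windows(windows, relevant_labels):
    """Average per-class scores over windows whose YAMNet top label is relevant.

    `windows` is a list of (yamnet_top_label, class_scores) pairs.
    Returns None when no window passes the filter.
    """
    kept = [scores for label, scores in windows if label in relevant_labels]
    if not kept:
        return None
    num_classes = len(kept[0])
    return [sum(scores[c] for scores in kept) / len(kept)
            for c in range(num_classes)]


# Hypothetical choice of YAMNet labels treated as bird-related.
RELEVANT = {'Bird', 'Bird vocalization, bird call, bird song', 'Animal', 'Wild animals'}

# Made-up scores for two classes across four windows.
example_windows = [
    ('Bird', [0.9, 0.1]),
    ('Silence', [0.2, 0.8]),
    ('Speech', [0.1, 0.9]),
    ('Animal', [0.7, 0.3]),
]
filtered_mean = mean_scores_for_relevant_windows(example_windows, RELEVANT)
```

In this example the Silence and Speech windows are ignored, so the average reflects only the windows where YAMNet heard something animal-like. Which labels count as relevant is a design decision you would tune for your own data.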

print(random_audio)

results = []
print('Result of the window ith:  your model class -> score,  (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  yamnet_label = spec._yamnet_labels()[spec_result_index]
  result_str = f'Result of the window {i}: ' \
  f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
  f'\t({yamnet_label} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)


results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/tensorflow/lite/g3doc/models/modify/model_maker/dataset/small_birds_dataset/test/wbwwre1/XC519211.wav
Result of the window ith:  your model class -> score,  (spec class -> score)
Result of the window 0:   wbwwre1 -> 0.947,    (Aircraft -> 0.193)
Result of the window 1:   wbwwre1 -> 0.969,    (Outside, rural or natural -> 0.385)
Result of the window 2:   wbwwre1 -> 0.982,    (Outside, rural or natural -> 0.342)
Result of the window 3:   wbwwre1 -> 0.994,    (Cricket -> 0.630)
Result of the window 4:   wbwwre1 -> 0.912,    (Insect -> 0.671)
Result of the window 5:   wbwwre1 -> 0.983,    (Cricket -> 0.892)
Result of the window 6:   wbwwre1 -> 0.975,    (Cricket -> 0.872)
Result of the window 7:   wbwwre1 -> 1.000,    (Bird vocalization, bird call, bird song -> 0.661)
Result of the window 8:   wbwwre1 -> 1.000,    (Wild animals -> 0.351)
Result of the window 9:   wbwwre1 -> 1.000,    (Bird -> 0.445)
Result of the window 10:  wbwwre1 -> 0.884,    (Bird vocalization, bird call, bird song -> 0.758)
Result of the window 11:  wbwwre1 -> 0.996,    (Theremin -> 0.388)
Result of the window 12:  wbwwre1 -> 1.000,    (Bird -> 0.342)
Result of the window 13:  wbwwre1 -> 0.999,    (Bird vocalization, bird call, bird song -> 0.192)
Result of the window 14:  wbwwre1 -> 1.000,    (Bird -> 0.457)
Result of the window 15:  wbwwre1 -> 0.987,    (Bird vocalization, bird call, bird song -> 0.416)
Result of the window 16:  wbwwre1 -> 0.999,    (Bird vocalization, bird call, bird song -> 0.200)
Result of the window 17:  wbwwre1 -> 0.999,    (Bird vocalization, bird call, bird song -> 0.704)
Result of the window 18:  wbwwre1 -> 0.991,    (Bird vocalization, bird call, bird song -> 0.418)
Result of the window 19:  wbwwre1 -> 0.989,    (Bird -> 0.718)
Result of the window 20:  wbwwre1 -> 0.973,    (Animal -> 0.570)
Result of the window 21:  wbwwre1 -> 0.962,    (Environmental noise -> 0.430)
Result of the window 22:  wbwwre1 -> 0.932,    (Animal -> 0.597)
Result of the window 23:  wbwwre1 -> 0.969,    (Bird vocalization, bird call, bird song -> 0.727)
Result of the window 24:  wbwwre1 -> 0.920,    (Cricket -> 0.629)
Result of the window 25:  wbwwre1 -> 0.877,    (Outside, rural or natural -> 0.189)
Result of the window 26:  wbwwre1 -> 0.962,    (Insect -> 0.318)
Result of the window 27:  wbwwre1 -> 0.871,    (Outside, rural or natural -> 0.431)
Result of the window 28:  wbwwre1 -> 0.983,    (Cricket -> 0.383)
Result of the window 29:  wbwwre1 -> 0.984,    (Cricket -> 0.602)
Result of the window 30:  redcro -> 0.455,     (Silence -> 1.000)
Result of the window 31:  azaspi1 -> 0.373,    (Silence -> 1.000)
Result of the window 32:  chcant2 -> 0.835,    (Silence -> 1.000)
Result of the window 33:  chcant2 -> 0.911,    (Speech -> 0.962)
Result of the window 34:  chcant2 -> 0.588,    (Speech -> 0.876)
Result of the window 35:  chcant2 -> 0.319,    (Speech -> 0.977)
Result of the window 36:  azaspi1 -> 0.444,    (Silence -> 1.000)
Result of the window 37:  chcant2 -> 0.999,    (Silence -> 1.000)
Result of the window 38:  houspa -> 0.585,     (Bird -> 0.880)
Result of the window 39:  wbwwre1 -> 0.898,    (Bird -> 0.809)
Result of the window 40:  houspa -> 0.472,     (Animal -> 0.556)
Result of the window 41:  wbwwre1 -> 0.977,    (Bird vocalization, bird call, bird song -> 0.708)
Result of the window 42:  wbwwre1 -> 0.949,    (Insect -> 0.266)
Result of the window 43:  wbwwre1 -> 0.980,    (Cricket -> 0.928)
Result of the window 44:  wbwwre1 -> 0.992,    (Cricket -> 0.945)
Result of the window 45:  chcant2 -> 0.571,    (Speech -> 0.158)
Result of the window 46:  chcant2 -> 0.787,    (Silence -> 0.844)
Mean result: wbwwre1 -> 0.7325854897499084

Exporting the model

The last step is exporting your model to be used on embedded devices or in the browser.

The export method handles the conversion for you; the call below writes a TFLite model.

models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models
2022-10-20 12:11:30.821235: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphxc1u85r/assets
2022-10-20 12:11:36.989028: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-10-20 12:11:36.989080: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite

You can also export the SavedModel version for serving or for use in a Python environment.

model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
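
The exported labels file can be used to turn a score vector back into a bird code. This sketch assumes labels.txt holds one label per line, in the same order as the model's output classes; that format is an assumption here, and both helper names are hypothetical.

```python
def load_labels(labels_path):
    """Read class labels from a text file, assuming one label per line."""
    with open(labels_path, encoding='utf-8') as labels_file:
        return [line.strip() for line in labels_file if line.strip()]


def label_for_scores(scores, labels):
    """Return the label with the highest score."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best_index]
```

For example, `label_for_scores(mean_scores, load_labels('./birds_models/labels.txt'))` would name the most likely bird for a vector of averaged scores, provided the assumption about the file format holds.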

Next Steps

You did it.

Now your new model can be deployed on mobile devices using the TFLite AudioClassifier Task API.

You can also try the same process with your own data and different classes; see the Model Maker documentation for Audio Classification.

You can also learn from the end-to-end reference apps for Android and iOS.