
Transfer learning for the audio domain with TensorFlow Lite Model Maker



In this Colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a custom audio classification model.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model with a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data required and shortens the training time.

This notebook is part of the Customize an audio model and deploy it on Android codelab.

You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.JS model that can be used for inference in the browser, and a SavedModel version that you can use for serving.

Install the dependencies

sudo apt -y install libportaudio2
pip install tflite-model-maker

Import TensorFlow, Model Maker and other libraries

Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside from those, the others are for audio manipulation, playback, and visualization.

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
TensorFlow Version: 2.9.1
Model Maker Version: 0.4.0

Birds dataset

The Birds dataset is an educational collection of songs from 5 types of birds:

  • White-breasted Wood-Wren
  • House Sparrow
  • Red Crossbill
  • Chestnut-crowned Antpitta
  • Azara's Spinetail

The original audio comes from Xeno-canto, a website dedicated to sharing bird sounds from all over the world.

Let's start by downloading the data.

birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip
343680986/343680986 [==============================] - 3s 0us/step

Explore the data

The audio is already split into train and test folders. Inside each split folder there's one folder per bird, using its bird_code as the folder name.

The audio is all mono with a 16 kHz sample rate.
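As a minimal, standard-library-only sketch of what "mono, 16 kHz" means in practice, the snippet below writes a one-second synthetic WAV in that format and reads its header back; the dataset files advertise the same properties in their headers. The 440 Hz tone and its amplitude are arbitrary illustration values.

```python
import io
import math
import struct
import wave

SAMPLE_RATE = 16000
buf = io.BytesIO()
with wave.open(buf, 'wb') as w:
    w.setnchannels(1)        # mono
    w.setsampwidth(2)        # 16-bit PCM, like the dataset files
    w.setframerate(SAMPLE_RATE)
    tone = (int(1000 * math.sin(2 * math.pi * 440 * i / SAMPLE_RATE))
            for i in range(SAMPLE_RATE))   # 1 s of a 440 Hz tone
    w.writeframes(b''.join(struct.pack('<h', s) for s in tone))

buf.seek(0)
with wave.open(buf, 'rb') as w:
    print(w.getnchannels(), w.getframerate())  # 1 16000
```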

For more information about each file, you can read the metadata.csv file. It contains the authors, links, and some more details for all the files. You won't need to read it yourself in this tutorial.
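Since each file's class is encoded in its parent folder, a quick way to check class balance is to tally the parent-folder names. A small sketch, with hypothetical paths mirroring the dataset layout; in the notebook you would pass glob.glob(os.path.join(data_dir, 'train/*/*.wav')) instead:

```python
import collections
import os

def count_per_class(wav_paths):
    """Tally .wav files per class, where the class is the parent folder name."""
    return collections.Counter(
        os.path.basename(os.path.dirname(p)) for p in wav_paths)

# Hypothetical paths, for illustration only.
paths = [
    'dataset/small_birds_dataset/train/houspa/XC1.wav',
    'dataset/small_birds_dataset/train/houspa/XC2.wav',
    'dataset/small_birds_dataset/train/redcro/XC3.wav',
]
print(count_per_class(paths))  # Counter({'houspa': 2, 'redcro': 1})
```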

# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',  
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",   
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', #   Alejandro Bayer Tamayo from Armenia, Colombia 
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', #    Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', #   Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path


def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path)

  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created

Play some audio

To get a better understanding of the data, let's listen to a random audio file from the test split.

Note: later in this notebook, you'll run inference on this audio for testing purposes.

random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: Azara's Spinetail
Bird code: azaspi1


Train the model

When working with audio in Model Maker, you have to start with a model spec. This is the base model that your new model will extract information from to learn about the new classes. It also affects how the dataset is transformed to respect the model spec's parameters, such as the sample rate and number of channels.

YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be at 16 kHz and with 1 channel.

You don't need to do any resampling yourself; Model Maker takes care of that for you.

  • frame_length decides how long each training sample is. In this case, 6 * EXPECTED_WAVEFORM_LENGTH.

  • frame_step decides how far apart consecutive training samples start. In this case, the i-th sample starts 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)-th sample.

The reason to set these values is to work around a limitation of real-world datasets.

For example, in the bird dataset, birds don't sing all the time. They sing, rest, and sing again, with noise in between. Having a long frame helps capture the singing, but setting it too long reduces the number of samples available for training.
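To see how these two values interact, here is a back-of-the-envelope sketch. It assumes EXPECTED_WAVEFORM_LENGTH is 15600 samples (0.975 s at 16 kHz, the window YAMNet expects); the 30-second clip length is just an example.

```python
# How many full training windows one clip yields for a given
# frame_length / frame_step combination.
EXPECTED_WAVEFORM_LENGTH = 15600  # assumed: 0.975 s at 16 kHz

def window_starts(total_samples, frame_length, frame_step):
    """Start offsets of every full window that fits inside the clip."""
    starts = []
    start = 0
    while start + frame_length <= total_samples:
        starts.append(start)
        start += frame_step
    return starts

frame_length = 6 * EXPECTED_WAVEFORM_LENGTH  # each sample is ~5.85 s long
frame_step = 3 * EXPECTED_WAVEFORM_LENGTH    # consecutive samples overlap by half
clip_samples = 30 * 16000                    # a 30 s recording at 16 kHz

starts = window_starts(clip_samples, frame_length, frame_step)
print(len(starts))  # 9 overlapping training samples from one clip
```

Halving frame_step roughly doubles the number of samples per clip; doubling frame_length captures longer songs but makes fewer windows fit.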

spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmpfs/tmp/tmpo32sd7ga

Load the data

Model Maker has the API to load the data from a folder and provide it in the format the model spec expects.

The train and test splits are based on the folders. The validation dataset is created as 20% of the train split.

Note: using cache=True is important to make training faster later, but it also requires more RAM to hold the data. For the Birds dataset that's not a problem since it's only 300 MB, but if you use your own data you have to pay attention to it.

train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

Train the model

The audio_classifier has the create method that creates a model and starts training it.

You can customize many parameters; for more information you can read more details in the documentation.

On this first try you'll use all the default configurations and train for 100 epochs.

Note: the first epoch takes longer than all the others because that's when the cache is created. After that, each epoch takes close to 1 second.

batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 classification_head (Dense)  (None, 5)                5125      
                                                                 
=================================================================
Total params: 5,125
Trainable params: 5,125
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
21/21 [==============================] - 19s 805ms/step - loss: 1.4962 - acc: 0.3230 - val_loss: 1.2149 - val_acc: 0.6796
Epoch 2/100
21/21 [==============================] - 0s 12ms/step - loss: 1.2849 - acc: 0.5033 - val_loss: 1.0718 - val_acc: 0.7058
Epoch 3/100
21/21 [==============================] - 0s 13ms/step - loss: 1.1563 - acc: 0.5890 - val_loss: 0.9997 - val_acc: 0.7662
Epoch 4/100
21/21 [==============================] - 0s 14ms/step - loss: 1.0463 - acc: 0.6567 - val_loss: 0.9374 - val_acc: 0.7902
Epoch 5/100
21/21 [==============================] - 0s 12ms/step - loss: 0.9721 - acc: 0.6817 - val_loss: 0.8918 - val_acc: 0.8084
Epoch 6/100
21/21 [==============================] - 0s 12ms/step - loss: 0.9137 - acc: 0.7067 - val_loss: 0.8554 - val_acc: 0.8164
Epoch 7/100
21/21 [==============================] - 0s 13ms/step - loss: 0.8674 - acc: 0.7205 - val_loss: 0.8218 - val_acc: 0.8255
Epoch 8/100
21/21 [==============================] - 0s 13ms/step - loss: 0.8287 - acc: 0.7297 - val_loss: 0.7996 - val_acc: 0.8084
Epoch 9/100
21/21 [==============================] - 0s 12ms/step - loss: 0.7939 - acc: 0.7436 - val_loss: 0.7799 - val_acc: 0.7959
Epoch 10/100
21/21 [==============================] - 0s 11ms/step - loss: 0.7605 - acc: 0.7559 - val_loss: 0.7614 - val_acc: 0.7868
Epoch 11/100
21/21 [==============================] - 0s 12ms/step - loss: 0.7357 - acc: 0.7670 - val_loss: 0.7482 - val_acc: 0.7731
Epoch 12/100
21/21 [==============================] - 0s 12ms/step - loss: 0.7088 - acc: 0.7701 - val_loss: 0.7333 - val_acc: 0.7651
Epoch 13/100
21/21 [==============================] - 0s 13ms/step - loss: 0.6868 - acc: 0.7778 - val_loss: 0.7232 - val_acc: 0.7594
Epoch 14/100
21/21 [==============================] - 0s 11ms/step - loss: 0.6713 - acc: 0.7905 - val_loss: 0.7127 - val_acc: 0.7537
Epoch 15/100
21/21 [==============================] - 0s 11ms/step - loss: 0.6535 - acc: 0.7839 - val_loss: 0.7071 - val_acc: 0.7423
Epoch 16/100
21/21 [==============================] - 0s 13ms/step - loss: 0.6312 - acc: 0.7974 - val_loss: 0.7011 - val_acc: 0.7343
Epoch 17/100
21/21 [==============================] - 0s 16ms/step - loss: 0.6159 - acc: 0.8147 - val_loss: 0.6943 - val_acc: 0.7286
Epoch 18/100
21/21 [==============================] - 0s 12ms/step - loss: 0.6044 - acc: 0.8151 - val_loss: 0.6923 - val_acc: 0.7218
Epoch 19/100
21/21 [==============================] - 0s 13ms/step - loss: 0.5876 - acc: 0.8193 - val_loss: 0.6850 - val_acc: 0.7184
Epoch 20/100
21/21 [==============================] - 0s 12ms/step - loss: 0.5810 - acc: 0.8224 - val_loss: 0.6806 - val_acc: 0.7184
Epoch 21/100
21/21 [==============================] - 0s 13ms/step - loss: 0.5721 - acc: 0.8181 - val_loss: 0.6784 - val_acc: 0.7206
Epoch 22/100
21/21 [==============================] - 0s 13ms/step - loss: 0.5659 - acc: 0.8205 - val_loss: 0.6742 - val_acc: 0.7229
Epoch 23/100
21/21 [==============================] - 0s 13ms/step - loss: 0.5439 - acc: 0.8366 - val_loss: 0.6706 - val_acc: 0.7206
Epoch 24/100
21/21 [==============================] - 0s 11ms/step - loss: 0.5451 - acc: 0.8331 - val_loss: 0.6727 - val_acc: 0.7161
Epoch 25/100
21/21 [==============================] - 0s 11ms/step - loss: 0.5348 - acc: 0.8354 - val_loss: 0.6688 - val_acc: 0.7195
Epoch 26/100
21/21 [==============================] - 0s 13ms/step - loss: 0.5315 - acc: 0.8354 - val_loss: 0.6693 - val_acc: 0.7172
Epoch 27/100
21/21 [==============================] - 0s 14ms/step - loss: 0.5148 - acc: 0.8435 - val_loss: 0.6681 - val_acc: 0.7161
Epoch 28/100
21/21 [==============================] - 0s 11ms/step - loss: 0.5165 - acc: 0.8385 - val_loss: 0.6661 - val_acc: 0.7172
Epoch 29/100
21/21 [==============================] - 0s 12ms/step - loss: 0.5062 - acc: 0.8420 - val_loss: 0.6618 - val_acc: 0.7172
Epoch 30/100
21/21 [==============================] - 0s 14ms/step - loss: 0.4978 - acc: 0.8389 - val_loss: 0.6630 - val_acc: 0.7149
Epoch 31/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4925 - acc: 0.8512 - val_loss: 0.6670 - val_acc: 0.7115
Epoch 32/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4777 - acc: 0.8516 - val_loss: 0.6626 - val_acc: 0.7149
Epoch 33/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4887 - acc: 0.8458 - val_loss: 0.6643 - val_acc: 0.7138
Epoch 34/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4764 - acc: 0.8535 - val_loss: 0.6627 - val_acc: 0.7138
Epoch 35/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4691 - acc: 0.8531 - val_loss: 0.6661 - val_acc: 0.7115
Epoch 36/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4617 - acc: 0.8639 - val_loss: 0.6625 - val_acc: 0.7115
Epoch 37/100
21/21 [==============================] - 0s 11ms/step - loss: 0.4634 - acc: 0.8551 - val_loss: 0.6595 - val_acc: 0.7149
Epoch 38/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4556 - acc: 0.8581 - val_loss: 0.6635 - val_acc: 0.7092
Epoch 39/100
21/21 [==============================] - 0s 14ms/step - loss: 0.4539 - acc: 0.8562 - val_loss: 0.6586 - val_acc: 0.7138
Epoch 40/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4456 - acc: 0.8651 - val_loss: 0.6580 - val_acc: 0.7138
Epoch 41/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4488 - acc: 0.8581 - val_loss: 0.6625 - val_acc: 0.7138
Epoch 42/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4400 - acc: 0.8662 - val_loss: 0.6663 - val_acc: 0.7104
Epoch 43/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4358 - acc: 0.8693 - val_loss: 0.6630 - val_acc: 0.7092
Epoch 44/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4312 - acc: 0.8685 - val_loss: 0.6606 - val_acc: 0.7115
Epoch 45/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4296 - acc: 0.8716 - val_loss: 0.6667 - val_acc: 0.7058
Epoch 46/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4183 - acc: 0.8697 - val_loss: 0.6571 - val_acc: 0.7138
Epoch 47/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4231 - acc: 0.8727 - val_loss: 0.6603 - val_acc: 0.7138
Epoch 48/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4116 - acc: 0.8727 - val_loss: 0.6650 - val_acc: 0.7115
Epoch 49/100
21/21 [==============================] - 0s 12ms/step - loss: 0.4191 - acc: 0.8720 - val_loss: 0.6620 - val_acc: 0.7115
Epoch 50/100
21/21 [==============================] - 0s 11ms/step - loss: 0.4076 - acc: 0.8793 - val_loss: 0.6676 - val_acc: 0.7070
Epoch 51/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4059 - acc: 0.8762 - val_loss: 0.6679 - val_acc: 0.7058
Epoch 52/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4015 - acc: 0.8758 - val_loss: 0.6654 - val_acc: 0.7070
Epoch 53/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4113 - acc: 0.8670 - val_loss: 0.6664 - val_acc: 0.7058
Epoch 54/100
21/21 [==============================] - 0s 13ms/step - loss: 0.4038 - acc: 0.8735 - val_loss: 0.6765 - val_acc: 0.7058
Epoch 55/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3982 - acc: 0.8800 - val_loss: 0.6735 - val_acc: 0.7058
Epoch 56/100
21/21 [==============================] - 0s 14ms/step - loss: 0.3925 - acc: 0.8735 - val_loss: 0.6768 - val_acc: 0.7047
Epoch 57/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3877 - acc: 0.8777 - val_loss: 0.6714 - val_acc: 0.7058
Epoch 58/100
21/21 [==============================] - 0s 15ms/step - loss: 0.3987 - acc: 0.8747 - val_loss: 0.6818 - val_acc: 0.7035
Epoch 59/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3891 - acc: 0.8762 - val_loss: 0.6782 - val_acc: 0.7058
Epoch 60/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3909 - acc: 0.8716 - val_loss: 0.6776 - val_acc: 0.7058
Epoch 61/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3913 - acc: 0.8827 - val_loss: 0.6830 - val_acc: 0.7035
Epoch 62/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3781 - acc: 0.8889 - val_loss: 0.6849 - val_acc: 0.7024
Epoch 63/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3827 - acc: 0.8724 - val_loss: 0.6837 - val_acc: 0.7035
Epoch 64/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3856 - acc: 0.8762 - val_loss: 0.6901 - val_acc: 0.7024
Epoch 65/100
21/21 [==============================] - 0s 14ms/step - loss: 0.3726 - acc: 0.8827 - val_loss: 0.6908 - val_acc: 0.7024
Epoch 66/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3748 - acc: 0.8762 - val_loss: 0.6839 - val_acc: 0.7047
Epoch 67/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3765 - acc: 0.8866 - val_loss: 0.6855 - val_acc: 0.7035
Epoch 68/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3707 - acc: 0.8897 - val_loss: 0.6868 - val_acc: 0.7058
Epoch 69/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3687 - acc: 0.8835 - val_loss: 0.6927 - val_acc: 0.7024
Epoch 70/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3703 - acc: 0.8824 - val_loss: 0.7016 - val_acc: 0.7047
Epoch 71/100
21/21 [==============================] - 0s 14ms/step - loss: 0.3663 - acc: 0.8843 - val_loss: 0.6936 - val_acc: 0.7047
Epoch 72/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3622 - acc: 0.8858 - val_loss: 0.7080 - val_acc: 0.7035
Epoch 73/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3574 - acc: 0.8912 - val_loss: 0.7048 - val_acc: 0.7047
Epoch 74/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3651 - acc: 0.8816 - val_loss: 0.6980 - val_acc: 0.7024
Epoch 75/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3580 - acc: 0.8900 - val_loss: 0.7062 - val_acc: 0.7035
Epoch 76/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3535 - acc: 0.8877 - val_loss: 0.7155 - val_acc: 0.7035
Epoch 77/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3559 - acc: 0.8904 - val_loss: 0.7099 - val_acc: 0.7070
Epoch 78/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3549 - acc: 0.8870 - val_loss: 0.7055 - val_acc: 0.7070
Epoch 79/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3552 - acc: 0.8862 - val_loss: 0.7049 - val_acc: 0.7058
Epoch 80/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3492 - acc: 0.8847 - val_loss: 0.7150 - val_acc: 0.7058
Epoch 81/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3480 - acc: 0.8916 - val_loss: 0.7075 - val_acc: 0.7058
Epoch 82/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3484 - acc: 0.8912 - val_loss: 0.7089 - val_acc: 0.7058
Epoch 83/100
21/21 [==============================] - 0s 14ms/step - loss: 0.3438 - acc: 0.8939 - val_loss: 0.7106 - val_acc: 0.7070
Epoch 84/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3492 - acc: 0.8866 - val_loss: 0.7172 - val_acc: 0.7081
Epoch 85/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3389 - acc: 0.8923 - val_loss: 0.7150 - val_acc: 0.7081
Epoch 86/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3429 - acc: 0.8858 - val_loss: 0.7230 - val_acc: 0.7081
Epoch 87/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3437 - acc: 0.8904 - val_loss: 0.7251 - val_acc: 0.7081
Epoch 88/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3409 - acc: 0.8931 - val_loss: 0.7194 - val_acc: 0.7092
Epoch 89/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3380 - acc: 0.8850 - val_loss: 0.7271 - val_acc: 0.7081
Epoch 90/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3420 - acc: 0.8881 - val_loss: 0.7180 - val_acc: 0.7070
Epoch 91/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3376 - acc: 0.8954 - val_loss: 0.7335 - val_acc: 0.7081
Epoch 92/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3365 - acc: 0.8870 - val_loss: 0.7347 - val_acc: 0.7058
Epoch 93/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3360 - acc: 0.8923 - val_loss: 0.7172 - val_acc: 0.7092
Epoch 94/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3312 - acc: 0.8908 - val_loss: 0.7342 - val_acc: 0.7081
Epoch 95/100
21/21 [==============================] - 0s 11ms/step - loss: 0.3377 - acc: 0.8904 - val_loss: 0.7291 - val_acc: 0.7092
Epoch 96/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3329 - acc: 0.8966 - val_loss: 0.7362 - val_acc: 0.7070
Epoch 97/100
21/21 [==============================] - 0s 13ms/step - loss: 0.3288 - acc: 0.9008 - val_loss: 0.7318 - val_acc: 0.7081
Epoch 98/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3301 - acc: 0.8897 - val_loss: 0.7399 - val_acc: 0.7081
Epoch 99/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3215 - acc: 0.8973 - val_loss: 0.7356 - val_acc: 0.7092
Epoch 100/100
21/21 [==============================] - 0s 12ms/step - loss: 0.3264 - acc: 0.8958 - val_loss: 0.7370 - val_acc: 0.7081

The accuracy looks good, but it's important to run the evaluation step on the test data and verify that your model achieves good results on unseen data.

print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model
28/28 [==============================] - 5s 145ms/step - loss: 0.8493 - acc: 0.7692
[0.8492956757545471, 0.7692307829856873]

Understand the model

When training a classifier, it's useful to look at the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier performs on the test data.

Model Maker already creates the confusion matrix for you.

def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1, keepdims=True)  # normalize each row
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
1/1 [==============================] - 0s 158ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 53ms/step
1/1 [==============================] - 0s 54ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 62ms/step
1/1 [==============================] - 0s 54ms/step
1/1 [==============================] - 0s 57ms/step
1/1 [==============================] - 0s 56ms/step
1/1 [==============================] - 0s 49ms/step
1/1 [==============================] - 0s 35ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 50ms/step
1/1 [==============================] - 0s 54ms/step
1/1 [==============================] - 0s 51ms/step
1/1 [==============================] - 0s 58ms/step
1/1 [==============================] - 0s 56ms/step
1/1 [==============================] - 0s 65ms/step
1/1 [==============================] - 0s 57ms/step
1/1 [==============================] - 0s 40ms/step
1/1 [==============================] - 0s 38ms/step
1/1 [==============================] - 0s 22ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 42ms/step


Test the model [Optional]

You can try the model on a sample audio from the test dataset just to see the results.

First, you get the serving model.

serving_model = model.create_serving_model()

print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>]
Model's output shape and type: [<KerasTensor: shape=(None, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(None, 5) dtype=float32 (created by layer 'sequential')>]

Coming back to the random audio you loaded earlier:

# if you want to try another file just uncomment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: House Sparrow
Bird code: houspa


The model you created has a fixed input window.

For a given audio file, you'll have to split it into windows of data of the expected size. The last window might need to be filled with zeros.

sample_rate, audio_data = wavfile.read(random_audio)

audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]

splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)

print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/dataset/small_birds_dataset/test/houspa/XC564822.wav
Original size of the audio data: 863968
Number of windows for inference: 56
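The split-with-padding done by tf.signal.frame above can be mimicked in plain Python to make the zero-padding of the last window explicit. This is a sketch of the semantics, not the actual implementation:

```python
def frame_with_padding(samples, frame_length, frame_step):
    """Split a list into windows, zero-padding the last partial window,
    mirroring tf.signal.frame(..., pad_end=True, pad_value=0)."""
    frames = []
    start = 0
    while start < len(samples):
        window = samples[start:start + frame_length]
        window = window + [0] * (frame_length - len(window))  # pad the tail
        frames.append(window)
        start += frame_step
    return frames

print(frame_with_padding([1, 2, 3, 4, 5], frame_length=3, frame_step=2))
# [[1, 2, 3], [3, 4, 5], [5, 0, 0]]
```

In the notebook frame_step equals frame_length (the serving model's input size), so the windows don't overlap; the overlap in this toy example just shows the general behavior.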

You'll loop over all of the split audio and apply the model to each one of them.

The model you've just trained has two outputs: the original YAMNet output and the one you've just trained. This is important because a real-world environment is more complicated than just bird sounds. You can use YAMNet's output to filter out non-relevant audio. For example, on the birds use case, if YAMNet is not classifying Birds or Animals, this might show that the output from your model could be an irrelevant classification.

Below, both outputs are printed to make it easier to understand their relation. Most of the mistakes your model makes happen when YAMNet's prediction is not related to your domain (e.g. birds).
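One way to act on that observation is to average the custom head's scores only over windows whose YAMNet top label looks relevant. The sketch below does this in plain Python; the label set and the per-window scores are illustrative values, not real model output:

```python
# Keep only windows whose YAMNet top label is bird-related before averaging.
BIRD_RELATED = {
    'Bird', 'Bird vocalization, bird call, bird song',
    'Animal', 'Wild animals',
}

def filtered_mean(per_window_scores, yamnet_top_labels):
    """Average per-class scores over the windows YAMNet deems relevant."""
    kept = [scores for scores, label in zip(per_window_scores, yamnet_top_labels)
            if label in BIRD_RELATED]
    if not kept:
        return None  # no relevant window: better to report nothing than guess
    n = len(kept)
    return [sum(col) / n for col in zip(*kept)]

# Illustrative two-class scores for three windows.
scores = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]
labels = ['Bird', 'Silence', 'Wild animals']
print(filtered_mean(scores, labels))  # the 'Silence' window is dropped
```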

print(random_audio)

results = []
print('Result of the window ith:  your model class -> score,  (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  t = spec._yamnet_labels()[spec_result_index]
  result_str = f'Result of the window {i}: ' \
  f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
  f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)


results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/dataset/small_birds_dataset/test/houspa/XC564822.wav
Result of the window ith:  your model class -> score,  (spec class -> score)
Result of the window 0:   houspa -> 0.900,     (Bird -> 0.916)
Result of the window 1:   houspa -> 0.860,     (Bird -> 0.812)
Result of the window 2:   houspa -> 0.576,     (Wild animals -> 0.858)
Result of the window 3:   houspa -> 0.889,     (Bird -> 0.956)
Result of the window 4:   houspa -> 0.933,     (Bird vocalization, bird call, bird song -> 0.981)
Result of the window 5:   houspa -> 0.922,     (Bird -> 0.954)
Result of the window 6:   houspa -> 0.773,     (Animal -> 0.953)
Result of the window 7:   houspa -> 0.946,     (Bird -> 0.967)
Result of the window 8:   houspa -> 0.757,     (Bird vocalization, bird call, bird song -> 0.603)
Result of the window 9:   houspa -> 0.857,     (Bird vocalization, bird call, bird song -> 0.885)
Result of the window 10:  houspa -> 0.877,     (Bird -> 0.853)
Result of the window 11:  houspa -> 0.920,     (Bird -> 0.915)
Result of the window 12:  houspa -> 0.779,     (Wild animals -> 0.947)
Result of the window 13:  houspa -> 0.786,     (Bird -> 0.852)
Result of the window 14:  houspa -> 0.616,     (Bird vocalization, bird call, bird song -> 0.928)
Result of the window 15:  houspa -> 0.994,     (Bird -> 0.999)
Result of the window 16:  houspa -> 0.891,     (Bird vocalization, bird call, bird song -> 0.967)
Result of the window 17:  houspa -> 0.969,     (Bird -> 0.953)
Result of the window 18:  houspa -> 0.712,     (Bird -> 0.898)
Result of the window 19:  redcro -> 0.574,     (Wild animals -> 0.952)
Result of the window 20:  houspa -> 0.517,     (Bird vocalization, bird call, bird song -> 0.974)
Result of the window 21:  redcro -> 0.589,     (Environmental noise -> 0.503)
Result of the window 22:  houspa -> 0.745,     (Bird -> 0.897)
Result of the window 23:  wbwwre1 -> 0.574,    (Wild animals -> 0.974)
Result of the window 24:  houspa -> 0.983,     (Bird -> 0.876)
Result of the window 25:  chcant2 -> 0.946,    (Silence -> 0.998)
Result of the window 26:  houspa -> 0.806,     (Wild animals -> 0.981)
Result of the window 27:  houspa -> 0.959,     (Bird -> 0.980)
Result of the window 28:  houspa -> 0.879,     (Bird -> 0.983)
Result of the window 29:  houspa -> 0.985,     (Bird -> 0.986)
Result of the window 30:  houspa -> 0.997,     (Bird -> 0.998)
Result of the window 31:  houspa -> 0.968,     (Bird -> 0.992)
Result of the window 32:  redcro -> 0.836,     (Environmental noise -> 0.669)
Result of the window 33:  houspa -> 0.647,     (Bird vocalization, bird call, bird song -> 0.855)
Result of the window 34:  houspa -> 0.397,     (Wild animals -> 0.927)
Result of the window 35:  houspa -> 0.952,     (Bird -> 0.955)
Result of the window 36:  houspa -> 0.819,     (Wild animals -> 0.969)
Result of the window 37:  redcro -> 0.342,     (Bird vocalization, bird call, bird song -> 0.852)
Result of the window 38:  houspa -> 0.594,     (Bird vocalization, bird call, bird song -> 0.496)
Result of the window 39:  houspa -> 0.811,     (Bird -> 0.982)
Result of the window 40:  wbwwre1 -> 0.839,    (Wild animals -> 0.986)
Result of the window 41:  houspa -> 0.985,     (Bird -> 0.973)
Result of the window 42:  wbwwre1 -> 0.616,    (Wild animals -> 0.994)
Result of the window 43:  houspa -> 0.948,     (Bird -> 0.987)
Result of the window 44:  redcro -> 0.545,     (Wild animals -> 0.902)
Result of the window 45:  houspa -> 0.974,     (Wild animals -> 0.980)
Result of the window 46:  houspa -> 0.979,     (Bird -> 0.995)
Result of the window 47:  houspa -> 0.977,     (Bird -> 0.977)
Result of the window 48:  redcro -> 0.943,     (Bird vocalization, bird call, bird song -> 0.851)
Result of the window 49:  houspa -> 0.892,     (Wild animals -> 0.985)
Result of the window 50:  houspa -> 0.984,     (Bird -> 0.980)
Result of the window 51:  houspa -> 0.980,     (Bird -> 0.969)
Result of the window 52:  houspa -> 0.968,     (Bird -> 0.977)
Result of the window 53:  houspa -> 0.920,     (Animal -> 0.996)
Result of the window 54:  houspa -> 0.809,     (Bird vocalization, bird call, bird song -> 0.708)
Result of the window 55:  houspa -> 0.634,     (Bird -> 0.962)
Mean result: houspa -> 0.7262401580810547

Export the model

The last step is to export your model to be used on embedded devices or in the browser.

The export method exports both formats for you.

models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmph4szi63q/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmph4szi63q/assets
2022-08-11 17:49:45.348759: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-08-11 17:49:45.348875: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite

You can also export the SavedModel version for serving or for use in a Python environment.

model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
INFO:tensorflow:Saving labels in ./birds_models/labels.txt

Next steps

You did it.

Now your new model can be deployed on mobile devices using the TFLite AudioClassifier Task API.

You can also try the same process with your own data with different classes; here is the documentation for Model Maker for Audio Classification.

You can also learn from the end-to-end reference apps: Android, iOS.