
Transfer Learning for the Audio Domain with TensorFlow Lite Model Maker


In this colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a custom audio classification model.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model with a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.

This is part of the Codelab to Customize an Audio model and deploy on Android.

You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.JS model that can be used for inference in the browser, and also a SavedModel version that you can use for serving.

Installing dependencies

 pip install tflite-model-maker

Import TensorFlow, Model Maker and other libraries

Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside from those, the others are for audio manipulation, playing, and visualization.

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)
TensorFlow Version: 2.6.0
Model Maker Version: 0.3.2

The Birds dataset

The Birds dataset is an educational collection of 5 types of bird songs:

  • White-breasted Wood-Wren
  • House Sparrow
  • Red Crossbill
  • Chestnut-crowned Antpitta
  • Azara's Spinetail

The original audio comes from Xeno-canto, a website dedicated to sharing bird sounds from all over the world.

Let's start by downloading the data.

birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
                                                cache_dir='./',
                                                cache_subdir='dataset',
                                                extract=True)
Downloading data from https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip
343687168/343680986 [==============================] - 6s 0us/step
343695360/343680986 [==============================] - 6s 0us/step

Explore the data

The audio is already split into train and test folders. Inside each split folder, there's one folder for each bird, using its bird_code as the name.

The audio is all mono with a 16kHz sample rate.

For more information about each file, you can read the metadata.csv file. It contains all the file authors, licenses and some more information. You won't need to read it yourself in this tutorial.
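Before training, it can also help to see how many clips each class has. The snippet below is a small sketch, assuming the train/<bird_code> and test/<bird_code> layout described above (the dataset_dir value matches the folder used in the next cell).

# Hedged sketch: count the .wav files per split and per bird, assuming the
# train/<bird_code> and test/<bird_code> folder layout described above.
import glob
import os

dataset_dir = './dataset/small_birds_dataset'
for split in ('train', 'test'):
  print(split)
  for class_dir in sorted(glob.glob(os.path.join(dataset_dir, split, '*'))):
    n_files = len(glob.glob(os.path.join(class_dir, '*.wav')))
    print(f'  {os.path.basename(class_dir)}: {n_files} files')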

# @title [Run this] Util functions and data structures.

data_dir = './dataset/small_birds_dataset'

bird_code_to_name = {
  'wbwwre1': 'White-breasted Wood-Wren',
  'houspa': 'House Sparrow',
  'redcro': 'Red Crossbill',  
  'chcant2': 'Chestnut-crowned Antpitta',
  'azaspi1': "Azara's Spinetail",   
}

birds_images = {
  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', #   Alejandro Bayer Tamayo from Armenia, Colombia 
  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', #    Diliff
  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com
  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', #   Mike's Birds from Riverside, CA, US
  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}

test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))

def get_random_audio_file():
  test_list = glob.glob(test_files)
  random_audio_path = random.choice(test_list)
  return random_audio_path


def show_bird_data(audio_path):
  sample_rate, audio_data = wavfile.read(audio_path)

  bird_code = audio_path.split('/')[-2]
  print(f'Bird name: {bird_code_to_name[bird_code]}')
  print(f'Bird code: {bird_code}')
  display(Image(birds_images[bird_code]))

  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
  plt.title(plttitle)
  plt.plot(audio_data)
  display(Audio(audio_data, rate=sample_rate))

print('functions and data structures created')
functions and data structures created

Playing some audio

To have a better understanding of the data, let's listen to a random audio file from the test split.

random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: White-breasted Wood-Wren
Bird code: wbwwre1


Training the Model

When using Model Maker for audio, you have to start with a model spec. This is the base model that your new model will extract information from to learn about the new classes. It also affects how the dataset will be transformed to respect the spec's parameters, like sample rate and number of channels.

YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be at 16kHz and with 1 channel.

You don't need to do any resampling yourself. Model Maker takes care of that for you.
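Just to illustrate what that preprocessing involves (you won't need this here), below is a hedged sketch of converting audio to 16kHz mono with scipy; the 44.1kHz stereo source is a made-up example.

# Illustrative only: Model Maker already resamples for you. This shows one
# way a hypothetical 44.1 kHz stereo clip could be converted to 16 kHz mono.
from scipy import signal
import numpy as np

source_rate = 44100                           # hypothetical source rate
stereo = np.random.randn(source_rate * 2, 2)  # 2s of fake stereo audio

mono = stereo.mean(axis=1)                    # average channels down to mono
# Rational resampling: 16000/44100 reduces to 160/441.
resampled = signal.resample_poly(mono, up=160, down=441)
print(len(mono), '->', len(resampled))        # 88200 -> 32000 samples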

  • frame_length decides how long each training sample is. In this case, 6 * EXPECTED_WAVEFORM_LENGTH.

  • frame_step decides how far apart the training samples are. In this case, the i-th sample will start 3 * EXPECTED_WAVEFORM_LENGTH after the (i-1)-th sample.

The reason to set these values is to work around some limitations of real-world datasets.

For example, in the birds dataset, birds don't sing all the time. They sing, rest, and sing again, with noises in between. Having a long frame helps capture the singing, but setting it too long reduces the number of samples for training.

spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
INFO:tensorflow:Checkpoints are stored in /tmp/tmp3s72bspo
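To get a feel for these values in seconds: YAMNet's expected waveform length is 15600 samples at 16kHz (the serving model's input shape later in this notebook confirms the 15600), so a quick back-of-the-envelope calculation gives the window sizes below.

# Quick sanity check (illustrative): convert the frame parameters to seconds.
length = audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH  # 15600 samples
sample_rate = 16000  # YAMNet expects 16 kHz input

print(f'frame_length: {6 * length / sample_rate:.3f}s')  # ~5.850s per sample
print(f'frame_step:   {3 * length / sample_rate:.3f}s')  # ~2.925s between starts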

Loading the data

Model Maker has the API to load data from a folder and convert it to the format the model spec expects.

The train and test splits are based on the folders. The validation dataset will be created as 20% of the train split.

train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

Training the model

audio_classifier has the create method, which creates the model and starts training it.

You can customize many parameters; you can read more details in the documentation.

On this first try you'll use all the default configurations and train for 100 epochs.

batch_size = 128
epochs = 100

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)
Training the model
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
classification_head (Dense)  (None, 5)                 5125      
=================================================================
Total params: 5,125
Trainable params: 5,125
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
22/22 [==============================] - 54s 2s/step - loss: 1.5431 - acc: 0.3074 - val_loss: 1.3118 - val_acc: 0.4735
Epoch 2/100
22/22 [==============================] - 1s 25ms/step - loss: 1.3122 - acc: 0.4780 - val_loss: 1.1328 - val_acc: 0.5794
Epoch 3/100
22/22 [==============================] - 1s 25ms/step - loss: 1.1547 - acc: 0.5975 - val_loss: 1.0275 - val_acc: 0.6374
Epoch 4/100
22/22 [==============================] - 1s 24ms/step - loss: 1.0504 - acc: 0.6419 - val_loss: 0.9516 - val_acc: 0.6890
Epoch 5/100
22/22 [==============================] - 1s 24ms/step - loss: 0.9666 - acc: 0.6944 - val_loss: 0.8896 - val_acc: 0.7303
Epoch 6/100
22/22 [==============================] - 1s 24ms/step - loss: 0.9079 - acc: 0.7159 - val_loss: 0.8458 - val_acc: 0.7432
Epoch 7/100
22/22 [==============================] - 1s 25ms/step - loss: 0.8554 - acc: 0.7395 - val_loss: 0.8056 - val_acc: 0.7639
Epoch 8/100
22/22 [==============================] - 1s 25ms/step - loss: 0.8174 - acc: 0.7525 - val_loss: 0.7809 - val_acc: 0.7665
Epoch 9/100
22/22 [==============================] - 1s 24ms/step - loss: 0.7840 - acc: 0.7555 - val_loss: 0.7554 - val_acc: 0.7781
Epoch 10/100
22/22 [==============================] - 1s 24ms/step - loss: 0.7447 - acc: 0.7721 - val_loss: 0.7342 - val_acc: 0.7794
Epoch 11/100
22/22 [==============================] - 1s 25ms/step - loss: 0.7215 - acc: 0.7773 - val_loss: 0.7145 - val_acc: 0.7871
Epoch 12/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6953 - acc: 0.7932 - val_loss: 0.7034 - val_acc: 0.7897
Epoch 13/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6811 - acc: 0.7888 - val_loss: 0.6884 - val_acc: 0.7961
Epoch 14/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6605 - acc: 0.7987 - val_loss: 0.6744 - val_acc: 0.7974
Epoch 15/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6435 - acc: 0.7999 - val_loss: 0.6638 - val_acc: 0.7974
Epoch 16/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6244 - acc: 0.8106 - val_loss: 0.6524 - val_acc: 0.7987
Epoch 17/100
22/22 [==============================] - 1s 24ms/step - loss: 0.6080 - acc: 0.8232 - val_loss: 0.6443 - val_acc: 0.7987
Epoch 18/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5917 - acc: 0.8202 - val_loss: 0.6336 - val_acc: 0.8013
Epoch 19/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5825 - acc: 0.8187 - val_loss: 0.6272 - val_acc: 0.8000
Epoch 20/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5776 - acc: 0.8287 - val_loss: 0.6182 - val_acc: 0.8052
Epoch 21/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5671 - acc: 0.8328 - val_loss: 0.6142 - val_acc: 0.8052
Epoch 22/100
22/22 [==============================] - 1s 23ms/step - loss: 0.5542 - acc: 0.8346 - val_loss: 0.6131 - val_acc: 0.8026
Epoch 23/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5474 - acc: 0.8372 - val_loss: 0.6074 - val_acc: 0.8039
Epoch 24/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5343 - acc: 0.8435 - val_loss: 0.6005 - val_acc: 0.8090
Epoch 25/100
22/22 [==============================] - 1s 24ms/step - loss: 0.5324 - acc: 0.8376 - val_loss: 0.5926 - val_acc: 0.8103
Epoch 26/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5225 - acc: 0.8428 - val_loss: 0.5878 - val_acc: 0.8103
Epoch 27/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5215 - acc: 0.8431 - val_loss: 0.5848 - val_acc: 0.8116
Epoch 28/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5120 - acc: 0.8420 - val_loss: 0.5819 - val_acc: 0.8116
Epoch 29/100
22/22 [==============================] - 1s 25ms/step - loss: 0.5023 - acc: 0.8465 - val_loss: 0.5796 - val_acc: 0.8155
Epoch 30/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4998 - acc: 0.8550 - val_loss: 0.5741 - val_acc: 0.8155
Epoch 31/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4994 - acc: 0.8509 - val_loss: 0.5722 - val_acc: 0.8142
Epoch 32/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4880 - acc: 0.8505 - val_loss: 0.5688 - val_acc: 0.8142
Epoch 33/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4851 - acc: 0.8513 - val_loss: 0.5615 - val_acc: 0.8181
Epoch 34/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4847 - acc: 0.8472 - val_loss: 0.5584 - val_acc: 0.8181
Epoch 35/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4721 - acc: 0.8609 - val_loss: 0.5584 - val_acc: 0.8194
Epoch 36/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4668 - acc: 0.8524 - val_loss: 0.5575 - val_acc: 0.8219
Epoch 37/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4682 - acc: 0.8531 - val_loss: 0.5530 - val_acc: 0.8232
Epoch 38/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4615 - acc: 0.8576 - val_loss: 0.5504 - val_acc: 0.8232
Epoch 39/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4636 - acc: 0.8535 - val_loss: 0.5526 - val_acc: 0.8219
Epoch 40/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4585 - acc: 0.8561 - val_loss: 0.5500 - val_acc: 0.8206
Epoch 41/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4480 - acc: 0.8613 - val_loss: 0.5412 - val_acc: 0.8206
Epoch 42/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4573 - acc: 0.8587 - val_loss: 0.5411 - val_acc: 0.8194
Epoch 43/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4523 - acc: 0.8565 - val_loss: 0.5362 - val_acc: 0.8194
Epoch 44/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4367 - acc: 0.8635 - val_loss: 0.5403 - val_acc: 0.8206
Epoch 45/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4318 - acc: 0.8631 - val_loss: 0.5366 - val_acc: 0.8194
Epoch 46/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4263 - acc: 0.8679 - val_loss: 0.5349 - val_acc: 0.8194
Epoch 47/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4430 - acc: 0.8572 - val_loss: 0.5362 - val_acc: 0.8194
Epoch 48/100
22/22 [==============================] - 1s 24ms/step - loss: 0.4332 - acc: 0.8646 - val_loss: 0.5281 - val_acc: 0.8219
Epoch 49/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4229 - acc: 0.8687 - val_loss: 0.5304 - val_acc: 0.8206
Epoch 50/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4287 - acc: 0.8627 - val_loss: 0.5257 - val_acc: 0.8206
Epoch 51/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4119 - acc: 0.8824 - val_loss: 0.5283 - val_acc: 0.8181
Epoch 52/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4271 - acc: 0.8653 - val_loss: 0.5233 - val_acc: 0.8206
Epoch 53/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4050 - acc: 0.8676 - val_loss: 0.5226 - val_acc: 0.8206
Epoch 54/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4003 - acc: 0.8768 - val_loss: 0.5142 - val_acc: 0.8245
Epoch 55/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4087 - acc: 0.8775 - val_loss: 0.5170 - val_acc: 0.8206
Epoch 56/100
22/22 [==============================] - 1s 26ms/step - loss: 0.4029 - acc: 0.8731 - val_loss: 0.5187 - val_acc: 0.8194
Epoch 57/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4063 - acc: 0.8672 - val_loss: 0.5136 - val_acc: 0.8219
Epoch 58/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3998 - acc: 0.8787 - val_loss: 0.5132 - val_acc: 0.8245
Epoch 59/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3911 - acc: 0.8798 - val_loss: 0.5113 - val_acc: 0.8245
Epoch 60/100
22/22 [==============================] - 1s 25ms/step - loss: 0.4022 - acc: 0.8742 - val_loss: 0.5044 - val_acc: 0.8284
Epoch 61/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3910 - acc: 0.8757 - val_loss: 0.5090 - val_acc: 0.8245
Epoch 62/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3869 - acc: 0.8798 - val_loss: 0.5129 - val_acc: 0.8245
Epoch 63/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3880 - acc: 0.8809 - val_loss: 0.5070 - val_acc: 0.8271
Epoch 64/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3905 - acc: 0.8757 - val_loss: 0.5047 - val_acc: 0.8284
Epoch 65/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3886 - acc: 0.8801 - val_loss: 0.5094 - val_acc: 0.8271
Epoch 66/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3877 - acc: 0.8746 - val_loss: 0.5037 - val_acc: 0.8284
Epoch 67/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3846 - acc: 0.8775 - val_loss: 0.5022 - val_acc: 0.8258
Epoch 68/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3890 - acc: 0.8831 - val_loss: 0.5015 - val_acc: 0.8284
Epoch 69/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3861 - acc: 0.8731 - val_loss: 0.5001 - val_acc: 0.8284
Epoch 70/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3802 - acc: 0.8812 - val_loss: 0.5001 - val_acc: 0.8284
Epoch 71/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3760 - acc: 0.8783 - val_loss: 0.4961 - val_acc: 0.8297
Epoch 72/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3755 - acc: 0.8742 - val_loss: 0.4937 - val_acc: 0.8310
Epoch 73/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3774 - acc: 0.8798 - val_loss: 0.5006 - val_acc: 0.8310
Epoch 74/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3756 - acc: 0.8787 - val_loss: 0.4955 - val_acc: 0.8297
Epoch 75/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3780 - acc: 0.8724 - val_loss: 0.4950 - val_acc: 0.8297
Epoch 76/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3794 - acc: 0.8746 - val_loss: 0.4930 - val_acc: 0.8297
Epoch 77/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3614 - acc: 0.8809 - val_loss: 0.4915 - val_acc: 0.8323
Epoch 78/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3644 - acc: 0.8783 - val_loss: 0.4911 - val_acc: 0.8297
Epoch 79/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3680 - acc: 0.8787 - val_loss: 0.4871 - val_acc: 0.8335
Epoch 80/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3637 - acc: 0.8805 - val_loss: 0.4851 - val_acc: 0.8310
Epoch 81/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3647 - acc: 0.8875 - val_loss: 0.4866 - val_acc: 0.8335
Epoch 82/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3602 - acc: 0.8827 - val_loss: 0.4855 - val_acc: 0.8335
Epoch 83/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3544 - acc: 0.8868 - val_loss: 0.4849 - val_acc: 0.8335
Epoch 84/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3616 - acc: 0.8861 - val_loss: 0.4843 - val_acc: 0.8297
Epoch 85/100
22/22 [==============================] - 1s 24ms/step - loss: 0.3642 - acc: 0.8820 - val_loss: 0.4801 - val_acc: 0.8310
Epoch 86/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3660 - acc: 0.8787 - val_loss: 0.4783 - val_acc: 0.8310
Epoch 87/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3470 - acc: 0.8868 - val_loss: 0.4863 - val_acc: 0.8348
Epoch 88/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3525 - acc: 0.8872 - val_loss: 0.4812 - val_acc: 0.8297
Epoch 89/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3622 - acc: 0.8768 - val_loss: 0.4858 - val_acc: 0.8297
Epoch 90/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3623 - acc: 0.8764 - val_loss: 0.4811 - val_acc: 0.8323
Epoch 91/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3501 - acc: 0.8857 - val_loss: 0.4820 - val_acc: 0.8310
Epoch 92/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3457 - acc: 0.8920 - val_loss: 0.4803 - val_acc: 0.8335
Epoch 93/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3566 - acc: 0.8831 - val_loss: 0.4810 - val_acc: 0.8310
Epoch 94/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3542 - acc: 0.8846 - val_loss: 0.4808 - val_acc: 0.8323
Epoch 95/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3480 - acc: 0.8883 - val_loss: 0.4757 - val_acc: 0.8335
Epoch 96/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3441 - acc: 0.8883 - val_loss: 0.4810 - val_acc: 0.8348
Epoch 97/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3539 - acc: 0.8846 - val_loss: 0.4742 - val_acc: 0.8361
Epoch 98/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3461 - acc: 0.8864 - val_loss: 0.4741 - val_acc: 0.8310
Epoch 99/100
22/22 [==============================] - 1s 26ms/step - loss: 0.3323 - acc: 0.8964 - val_loss: 0.4802 - val_acc: 0.8348
Epoch 100/100
22/22 [==============================] - 1s 25ms/step - loss: 0.3497 - acc: 0.8812 - val_loss: 0.4761 - val_acc: 0.8323

The accuracy looks good, but it's important to run the evaluation step on the test data and verify your model achieves good results on unseen data.

print('Evaluating the model')
model.evaluate(test_data)
Evaluating the model
28/28 [==============================] - 14s 469ms/step - loss: 0.7833 - acc: 0.7750
[0.7832977175712585, 0.7749713063240051]

Understanding your model

When training a classifier, it's useful to see the confusion matrix. The confusion matrix gives you detailed knowledge of how your classifier performs on the test data.

Model Maker already creates the confusion matrix for you.

def show_confusion_matrix(confusion, test_labels):
  """Compute confusion matrix and normalize."""
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)[:, np.newaxis]  # normalize each row
  axis_labels = test_labels
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)


Testing the model [Optional]

You can try the model on a sample audio clip from the test dataset just to see the results.

First, you get the serving model.

serving_model = model.create_serving_model()

print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
Model's input shape and type: [<KerasTensor: shape=(None, 15600) dtype=float32 (created by layer 'audio')>]
Model's output shape and type: [<KerasTensor: shape=(1, 521) dtype=float32 (created by layer 'keras_layer')>, <KerasTensor: shape=(1, 5) dtype=float32 (created by layer 'sequential')>]

Coming back to the random audio you loaded earlier

# if you want to try another file just uncomment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)
Bird name: House Sparrow
Bird code: houspa


The model created has a fixed input window.

For a given audio file, you'll have to split it into windows of data of the expected size. The last window might need to be padded with zeros.

sample_rate, audio_data = wavfile.read(random_audio)

audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]

splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)

print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
Test audio path: /tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/houspa/XC565683.wav
Original size of the audio data: 520992
Number of windows for inference: 34

You'll loop over all the split audio windows and apply the model to each one of them.

The model you've just trained has 2 outputs: the original YAMNet output and the one you've just trained. This is important because a real-world environment is more complicated than just bird sounds. You can use YAMNet's output to filter out non-relevant audio. For example, in the birds use case, if YAMNet is not classifying Birds or Animals, this might show that the output from your model is an irrelevant classification. A small filtering sketch follows the printed results below.

Below, both outputs are printed to make it easier to understand their relationship. Most of the mistakes your model makes happen when YAMNet's prediction is not related to your domain (e.g.: birds).

print(random_audio)

results = []
print('Result of the window ith:  your model class -> score,  (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
  yamnet_output, inference = serving_model(data)
  results.append(inference[0].numpy())
  result_index = tf.argmax(inference[0])
  spec_result_index = tf.argmax(yamnet_output[0])
  result_str = f'Result of the window {i}: ' \
  f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
  f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
  print(result_str)


results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
/tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset/small_birds_dataset/test/houspa/XC565683.wav
Result of the window ith:  your model class -> score,  (spec class -> score)
Result of the window 0:   redcro -> 0.941,     (Silence -> 1.000)
Result of the window 1:   redcro -> 0.508,     (Animal -> 0.940)
Result of the window 2:   redcro -> 0.994,     (Silence -> 1.000)
Result of the window 3:   redcro -> 0.876,     (Silence -> 1.000)
Result of the window 4:   houspa -> 0.958,     (Bird -> 0.974)
Result of the window 5:   redcro -> 0.991,     (Silence -> 1.000)
Result of the window 6:   redcro -> 0.990,     (Silence -> 1.000)
Result of the window 7:   redcro -> 0.424,     (Bird -> 0.944)
Result of the window 8:   houspa -> 0.744,     (Bird -> 0.974)
Result of the window 9:   chcant2 -> 0.998,    (Silence -> 1.000)
Result of the window 10:  chcant2 -> 0.821,    (Silence -> 1.000)
Result of the window 11:  redcro -> 0.860,     (Silence -> 1.000)
Result of the window 12:  redcro -> 0.569,     (Bird -> 0.970)
Result of the window 13:  redcro -> 0.973,     (Silence -> 1.000)
Result of the window 14:  redcro -> 0.973,     (Wild animals -> 0.930)
Result of the window 15:  chcant2 -> 0.532,    (Silence -> 1.000)
Result of the window 16:  redcro -> 0.960,     (Silence -> 1.000)
Result of the window 17:  redcro -> 0.662,     (Bird -> 0.911)
Result of the window 18:  redcro -> 0.899,     (Silence -> 1.000)
Result of the window 19:  redcro -> 0.959,     (Silence -> 1.000)
Result of the window 20:  redcro -> 0.968,     (Silence -> 1.000)
Result of the window 21:  houspa -> 0.994,     (Bird -> 0.990)
Result of the window 22:  redcro -> 0.967,     (Silence -> 1.000)
Result of the window 23:  redcro -> 0.991,     (Silence -> 1.000)
Result of the window 24:  redcro -> 0.981,     (Silence -> 1.000)
Result of the window 25:  houspa -> 0.604,     (Bird -> 0.987)
Result of the window 26:  redcro -> 0.944,     (Silence -> 1.000)
Result of the window 27:  redcro -> 0.938,     (Silence -> 1.000)
Result of the window 28:  redcro -> 0.907,     (Silence -> 1.000)
Result of the window 29:  redcro -> 0.937,     (Silence -> 1.000)
Result of the window 30:  redcro -> 0.940,     (Silence -> 1.000)
Result of the window 31:  redcro -> 0.566,     (Bird -> 0.990)
Result of the window 32:  redcro -> 0.672,     (Silence -> 1.000)
Result of the window 33:  redcro -> 0.776,     (Silence -> 1.000)
Mean result: redcro -> 0.7189410328865051
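As mentioned above, YAMNet's head can work as a relevance filter. The snippet below is one possible heuristic, a hedged sketch: the RELEVANT set only contains the YAMNet classes that appeared in the run above, and averaging only those windows is an illustrative choice, not part of Model Maker.

# Hedged sketch: average only the windows whose YAMNet top class looks
# bird-related, skipping the ones YAMNet hears as Silence.
RELEVANT = {'Bird', 'Animal', 'Wild animals'}  # illustrative class set

filtered = []
for data in splitted_audio_data:
  yamnet_output, inference = serving_model(data)
  top_yamnet = spec._yamnet_labels()[tf.argmax(yamnet_output[0])]
  if top_yamnet in RELEVANT:
    filtered.append(inference[0].numpy())

if filtered:
  mean_filtered = np.array(filtered).mean(axis=0)
  idx = mean_filtered.argmax()
  print(f'Filtered mean result: {test_data.index_to_label[idx]} -> {mean_filtered[idx]:.3f}')
else:
  print('No relevant windows found')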

Exporting the model

The last step is exporting your model to be used on embedded devices or in the browser.

The export method exports both formats for you.

models_path = './birds_models'
print(f'Exporting the TFLite model to {models_path}')

model.export(models_path, tflite_filename='my_birds_model.tflite')
Exporting the TFLite model to ./birds_models
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2021-10-07 11:56:43.268990: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/tmplmr5go_i/assets
INFO:tensorflow:Assets written to: /tmp/tmplmr5go_i/assets
2021-10-07 11:56:49.292362: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-10-07 11:56:49.292431: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./birds_models/my_birds_model.tflite
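To sanity-check the exported file, the snippet below is a minimal sketch using the standard tf.lite.Interpreter API, reusing splitted_audio_data and input_size from the test section above.

# Hedged sketch: run a single audio window through the exported TFLite model.
interpreter = tf.lite.Interpreter(model_path='./birds_models/my_birds_model.tflite')
input_details = interpreter.get_input_details()

# The model expects a (1, 15600) float32 window; resize in case the exported
# input shape is dynamic.
interpreter.resize_tensor_input(input_details[0]['index'], [1, input_size])
interpreter.allocate_tensors()

window = splitted_audio_data[0].numpy().astype(np.float32)[np.newaxis, :]
interpreter.set_tensor(input_details[0]['index'], window)
interpreter.invoke()

# With keep_yamnet_and_custom_heads=True there are two outputs: the 521-class
# YAMNet head and your 5-class custom head.
for detail in interpreter.get_output_details():
  print(detail['name'], interpreter.get_tensor(detail['index']).shape)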

You can also export the SavedModel version to serve or use in a Python environment.

model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Assets written to: ./birds_models/saved_model/assets
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
INFO:tensorflow:Saving labels in ./birds_models/labels.txt
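As a quick round-trip check, the snippet below is a hedged sketch that reloads the exported SavedModel through its default serving signature; the 'audio' input name comes from the serving model's input printed earlier, and the output tensor names depend on the export.

# Hedged sketch: reload the SavedModel and run one window through it.
reloaded = tf.saved_model.load('./birds_models/saved_model')
serving_fn = reloaded.signatures['serving_default']

window = tf.cast(tf.reshape(splitted_audio_data[0], [1, input_size]), tf.float32)
outputs = serving_fn(audio=window)  # dict of named output tensors
for name, tensor in outputs.items():
  print(name, tensor.shape)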

Next steps

You did it!

Now your new model can be deployed on mobile devices using the TFLite AudioClassifier Task API.

You can also try the same process with your own data with different classes; here is the documentation for Model Maker for Audio Classification.

Also learn from the end-to-end reference apps: Android, iOS.