YAMNet によるサウンドの分類

TensorFlow.org で表示 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

YAMNet は、521 個のオーディオイベントクラスを、YAMNet がトレーニングに使用した AudioSet-YouTube コーパスから予測するディープネットです。Mobilenet_v1 という Depthwise-Separable Convolution(深さ方向に分離可能な畳み込み)アーキテクチャを使用しています。

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import csv

import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io import wavfile
2024-01-11 17:42:53.303800: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:42:53.303846: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:42:53.305309: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Hub からモデルを読み込みます。

注意: ドキュメントを読むには、モデルの url に従ってください。

# Load the model.
model = hub.load('https://tfhub.dev/google/yamnet/1')

models アセットから labels ファイルが読み込まれます。これは model.class_map_path() にあります。class_names 変数で読み込みます。

# Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_csv_text):
  """Returns list of class names corresponding to score vector."""
  class_names = []
  with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
      class_names.append(row['display_name'])

  return class_names

class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)

読み込まれたオーディオが適切な sample_rate(16K)であることを確認して変換するメソッドを追加します。これがなければ、モデルの結果に影響があります。

def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
  """Resample waveform if required."""
  if original_sample_rate != desired_sample_rate:
    desired_length = int(round(float(len(waveform)) /
                               original_sample_rate * desired_sample_rate))
    waveform = scipy.signal.resample(waveform, desired_length)
  return desired_sample_rate, waveform

サウンドファイルのダウンロードと準備

ここでは、wav ファイルをダウンロードして聴くことができるようにします。利用できるファイルがある場合は、Colab にアップロードしてそれを使用してください。

注意: 期待されるオーディオファイルは、サンプリングレートが 16kHz の mono wav ファイルである必要があります。

curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  153k  100  153k    0     0  1690k      0 --:--:-- --:--:-- --:--:-- 1671k
curl -O https://storage.googleapis.com/audioset/miaow_16k.wav
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  210k  100  210k    0     0  2844k      0 --:--:-- --:--:-- --:--:-- 2844k
# wav_file_name = 'speech_whistling2.wav'
wav_file_name = 'miaow_16k.wav'
sample_rate, wav_data = wavfile.read(wav_file_name, 'rb')
sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)

# Show some basic information about the audio.
duration = len(wav_data)/sample_rate
print(f'Sample rate: {sample_rate} Hz')
print(f'Total duration: {duration:.2f}s')
print(f'Size of the input: {len(wav_data)}')

# Listening to the wav file.
Audio(wav_data, rate=sample_rate)
Sample rate: 16000 Hz
Total duration: 6.73s
Size of the input: 107698
/tmpfs/tmp/ipykernel_10434/2211628228.py:3: WavFileWarning: Chunk (non-data) not understood, skipping it.
  sample_rate, wav_data = wavfile.read(wav_file_name, 'rb')

wav_data を、[-1.0, 1.0] の値に正規化する必要があります(モデルのドキュメントで指示されています)。

waveform = wav_data / tf.int16.max

モデルを実行する

これは簡単なステップです。準備済みのデータを使用して、モデルを呼び出し、スコア、埋め込み、およびスペクトログラムを取得します。

使用するメインの結果は、スコアです。スペクトログラムについては、後で視覚化を行うために使用します。

# Run the model, check the output.
scores, embeddings, spectrogram = model(waveform)
scores_np = scores.numpy()
spectrogram_np = spectrogram.numpy()
infered_class = class_names[scores_np.mean(axis=0).argmax()]
print(f'The main sound is: {infered_class}')
The main sound is: Animal

視覚化

YAMNet は、視覚化に使用できる追加情報も返します。波形、スペクトログラム、および推論された上位クラスを確認してみましょう。

plt.figure(figsize=(10, 6))

# Plot the waveform.
plt.subplot(3, 1, 1)
plt.plot(waveform)
plt.xlim([0, len(waveform)])

# Plot the log-mel spectrogram (returned by the model).
plt.subplot(3, 1, 2)
plt.imshow(spectrogram_np.T, aspect='auto', interpolation='nearest', origin='lower')

# Plot and label the model output scores for the top-scoring classes.
mean_scores = np.mean(scores, axis=0)
top_n = 10
top_class_indices = np.argsort(mean_scores)[::-1][:top_n]
plt.subplot(3, 1, 3)
plt.imshow(scores_np[:, top_class_indices].T, aspect='auto', interpolation='nearest', cmap='gray_r')

# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS
# values from the model documentation
patch_padding = (0.025 / 2) / 0.01
plt.xlim([-patch_padding-0.5, scores.shape[0] + patch_padding-0.5])
# Label the top_N classes.
yticks = range(0, top_n, 1)
plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])
_ = plt.ylim(-0.5 + np.array([top_n, 0]))

png