LM ヘッドを使用した Wav2Vec2 の微調整

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

このノートブックでは、事前にトレーニングされた wav2vec2 モデルを TFHub からロードし、事前にトレーニングされたモデルの上に言語モデリングヘッド(LM)を追加することにより、LibriSpeech データセットで微調整します。基礎となるタスクは、自動音声認識のモデルを構築することです。つまり、音声が与えられた場合、モデルはそれをテキストに変換できる必要があります。

セットアップ

このノートブックを実行する前に、GPU ランタイムを使用していることを確認してください(Runtime > Change runtime type > GPU)。次のセルは、gsoc-wav2vec2 パッケージとその依存関係をインストールします。

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2
  libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2
  libparted-fs-resize0 libxmlb2
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev
0 upgraded, 4 newly installed, 0 to remove and 128 not upgraded.
Need to get 908 kB of archives.
After this operation, 4277 kB of additional disk space will be used.
Get:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libogg-dev amd64 1.3.4-0ubuntu1 [161 kB]
Get:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libflac-dev amd64 1.3.3-1ubuntu0.2 [151 kB]
Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libvorbis-dev amd64 1.3.6-2ubuntu1 [316 kB]
Get:4 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libsndfile1-dev amd64 1.0.28-7ubuntu0.2 [280 kB]
Fetched 908 kB in 0s (6442 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 144105 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.4-0ubuntu1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.3-1ubuntu0.2_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.3-1ubuntu0.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.6-2ubuntu1_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-7ubuntu0.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-7ubuntu0.2) ...
Setting up libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Setting up libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Setting up libflac-dev:amd64 (1.3.3-1ubuntu0.2) ...
Setting up libsndfile1-dev (1.0.28-7ubuntu0.2) ...

TFHub を使用したモデルのセットアップ

いくつかのライブラリ/モジュールをインポートすることから開始します。

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
2024-01-11 18:55:18.967645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:55:18.967689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:55:18.969225: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TF version: 2.15.0

最初に TFHub からモデルをダウンロードし、モデルの署名を hub.KerasLayer でラップして、他の Keras レイヤーと同じようにこのモデルを使用できるようにします。幸い、hub.KerasLayer はたった1行で両方を実行できます。

**注意: **hub.KerasLayer を使用してモデルをロードすると、モデルは少し不透明になりますが、モデルをより細かく制御する必要がある場合は、 tf.keras.models.load_model(...) を使用してモデルをロードできます。

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

モデルエクスポートスクリプトに興味がある場合は、このスクリプトを参照できます。オブジェクト pretrained_layer は、Wav2Vec2Model のフリーズバージョンです。これらの事前トレーニング済みの重みは、このスクリプトを使用して HuggingFacePyTorch 事前トレーニング済みの重みから変換されました。

もともと、wav2vec2 はマスクされた時間ステップの真の量子化された潜在的な音声表現を識別することを目的として、マスクされた言語モデリングアプローチで事前にトレーニングされていました。トレーニングの目的について詳しくは、wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations をお読みください。

ここで、次のセルで役立つ定数とハイパーパラメータをいくつか定義します。モデルのシグネチャは静的シーケンス長 246000 のみを受け入れるため、AUDIO_MAXLEN は意図的に 246000 に設定されています。

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

次のセルでは、pretrained_layer と高密度レイヤー(LM ヘッド)を Keras の Functional API でラップします。

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

各タイムステップで語彙内の各トークンの確率を予測するため、高密度レイヤー(上記で定義)の出力ディメンションは vocab_size です。

トレーニング状態のセットアップ

TensorFlow では、model.call または model.build が初めて呼び出されたときにのみモデルの重みが作成されるため、次のセルでモデルの重みが作成されます。さらに、model.summary() を実行して、トレーニング可能なパラメータの総数を確認します。

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94396320 (360.09 MB)
Trainable params: 94396320 (360.09 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

次に、モデルをトレーニングできるように loss_fn とオプティマイザを定義する必要があります。次のセルで定義されます。簡単にするため、Adam オプティマイザを使用します。CTCLoss は、入力サブパーツを出力サブパーツと簡単に位置合わせできないタスク(ASR など)に使用される一般的な損失タイプです。 CTC-loss の詳細については、このすばらしいブログ投稿をご覧ください。

gsoc-wav2vec2パッケージからの)CTCLossは、configmodel_input_shapedivision_factor の 3 つの引数を受け入れます。division_factor=1 の場合、損失は単純に合計されるため、それに応じて division_factor を渡して、バッチ全体の平均を取得します。

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

データの読み込みと前処理

LibriSpeech データセットを公式ウェブサイトからダウンロードして設定しましょう。

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2024-01-11 18:55:41--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: http://us.openslr.org/resources/12/dev-clean.tar.gz [following]
--2024-01-11 18:55:42--  http://us.openslr.org/resources/12/dev-clean.tar.gz
Resolving us.openslr.org (us.openslr.org)... 46.101.158.64
Connecting to us.openslr.org (us.openslr.org)|46.101.158.64|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  25.8MB/s    in 13s     

2024-01-11 18:55:56 (24.1 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]

**注意: **このノートブックはデモンストレーションのみを目的としているため、dev-clean 構成を使用しており、少量のデータが必要です。完全なトレーニングデータは、LibriSpeech Webサイトから簡単にダウンロードできます。

ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

データセットは LibriSpeech ディレクトリにあります。これらのファイルを調べてみましょう。

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0007.flac', '2428-83705-0037.flac', '2428-83705-0039.flac', '2428-83705-0033.flac', '2428-83705-0001.flac', '2428-83705-0036.flac', '2428-83705-0027.flac', '2428-83705-0032.flac', '2428-83705-0024.flac', '2428-83705-0023.flac', '2428-83705-0005.flac', '2428-83705-0019.flac', '2428-83705-0030.flac', '2428-83705-0035.flac', '2428-83705-0025.flac', '2428-83705-0041.flac', '2428-83705-0013.flac', '2428-83705-0009.flac', '2428-83705-0026.flac', '2428-83705-0006.flac', '2428-83705-0043.flac', '2428-83705-0004.flac', '2428-83705-0029.flac', '2428-83705-0014.flac', '2428-83705-0010.flac', '2428-83705-0000.flac', '2428-83705-0034.flac', '2428-83705-0017.flac', '2428-83705-0002.flac', '2428-83705-0038.flac', '2428-83705-0012.flac', '2428-83705-0042.flac', '2428-83705-0022.flac', '2428-83705-0015.flac', '2428-83705-0028.flac', '2428-83705-0040.flac', '2428-83705-0011.flac', '2428-83705-0020.flac', '2428-83705-0021.flac', '2428-83705-0008.flac', '2428-83705-0016.flac', '2428-83705-0031.flac', '2428-83705-0018.flac', '2428-83705-0003.flac']

なるほど。各サブディレクトリには、多くの .flac ファイルと .txt ファイルがあります。.txt ファイルには、そのサブディレクトリに存在するすべての音声サンプル(つまり、.flac ファイル)のテキスト文字起こしが含まれています。

このテキストデータは次のようにロードできます。

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

同様に、.flac ファイルから音声サンプルをロードするための関数を定義します。

wav2vec2 は 16K の頻度で事前トレーニングされているため、REQUIRED_SAMPLE_RATE16000 に設定されます。頻度によるデータ分布の大きな変化なしに、微調整することをお勧めします。

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

次に、ランダムなサンプルをいくつか選び、視覚化してみましょう。

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: BUT WHY ON THAT ACCOUNT THEY SHOULD PITY ME I ALTOGETHER FAIL TO UNDERSTAND 
Audio:

次に、すべての音声とテキストのサンプルを組み合わせて、その目的のために(次のセルで)関数を定義します。

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

いくつかのサンプルを見てみましょう...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([6.71386719e-04, 6.71386719e-04, 5.49316406e-04, ...,
         2.44140625e-04, 2.44140625e-04, 3.05175781e-05]),
  "THE GIRL IS FRETTING BUT YOU DON'T SEEM TO NOTICE IT"),
 (array([-6.10351562e-05, -6.10351562e-05, -3.05175781e-05, ...,
         -2.13623047e-04, -9.15527344e-05, -3.05175781e-05]),
  'I CANNOT PRETEND TO EXPLAIN WHY EXCEPT ON THE SUPPOSITION THAT ROMANCE IS DEAD AT LEAST IN THAT CIRCLE OF SOCIETY IN WHICH THE SNELLINGS MOVE'),
 (array([1.22070312e-04, 1.22070312e-04, 9.15527344e-05, ...,
         2.44140625e-04, 2.44140625e-04, 1.83105469e-04]),
  'P S THE CARDS ARE OUT FOR THE WEDDING'),
 (array([ 2.74658203e-04,  2.74658203e-04,  1.22070312e-04, ...,
         -1.83105469e-04, -1.22070312e-04, -9.15527344e-05]),
  'WE HAVE ALL BEEN GIVING MARY ANN PRESENTS AND I SUPPOSE YOU MISTER WHITING HAVE BEEN GIVING HER SOMETHING TOO'),
 (array([0.00054932, 0.00033569, 0.00021362, ..., 0.00061035, 0.00054932,
         0.00048828]),
  'BUT IT IS QUITE PLAIN TO ME THAT ALL THE ARRANGEMENTS FOR MY WEDDING ARE GOING TO BE MADE BY THE SNELLINGS')]

注意: このノートブックで少量のデータセットを操作する際に、このデータはメモリに読み込まれます。ただし、完全なデータセット(〜300 GB)でトレーニングするには、データを遅延ロードする必要があります。このスクリプトを参照して、詳細を確認できます。

今すぐデータを前処理しましょう!!!

まず、gsoc-wav2vec2 パッケージを使用してトークナイザーとプロセッサーを定義します。次に、非常に簡単な前処理を行います。processor はフレーム軸に対して生の音声を正規化し、tokenizer はモデル出力を文字列に変換し(定義された語彙を使用)、特別なトークンを削除します(トークナイザーの構成によって異なります)。

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

次に、上記のセルで定義した前処理関数を呼び出す Python ジェネレーターを定義します。

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

tf.data.Dataset のセットアップ

次のセルは、.from_generator(...) メソッドを使用して tf.data.Dataset オブジェクトをセットアップします。上記のセルで定義した generator オブジェクトを使用します。

**注意: **分散トレーニング(特に TPU)の場合、.from_generator(...) は現在機能しないため、.tfrecord 形式で保存されたデータでトレーニングすることをお勧めします(注意: TFRecord は、TPU が最大限に機能するように、理想的には GCS バケット内に保存する必要があります)。

LibriSpeech データを tfrecords に変換する方法の詳細については、このスクリプト を参照してください。

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

データセットを複数のバッチに渡すため、次のセルでバッチを準備しましょう。ここで、バッチ内のすべてのシーケンスを一定の長さにパディングする必要があります。そのために .padded_batch(...) メソッドを使用します。

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

アクセラレータ(GPU/TPU など)は非常に高速であり、データの読み込み部分が CPU で発生するため、トレーニング中にデータの読み込み(および前処理)がボトルネックになることがよくあります。これにより、特に多くのオンライン前処理が含まれる場合や、データが GCS バケットからオンラインでストリーミングされる場合に、トレーニング時間が大幅に増加する可能性があります。これらの問題に対応するために、tf.data.Dataset.prefetch(...) メソッドを提供します。この方法は、モデルが現在のバッチで(GPU/TPU で)予測を行っている間に、次のいくつかのバッチを(CPU で)並行して準備するのに役立ちます。

dataset = dataset.prefetch(tf.data.AUTOTUNE)

このノートブックはデモンストレーションを目的に作成されているため、最初に num_train_batches を取得し、それだけでトレーニングを実行します。ただし、データセット全体でトレーニングすることをお勧めします。同様に、num_val_batches のみを評価します。

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

モデルのトレーニング

モデルをトレーニングするため、モデルを .compile(...) でコンパイルした後に、.fit(...) メソッドを直接呼び出します。

model.compile(optimizer, loss=loss_fn)

上記のセルは、トレーニング状態を設定します。これで、.fit(...) メソッドを使用してトレーニングを開始できます。

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1531: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1531: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1508: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1508: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704999382.359793   73622 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
10/10 [==============================] - 58s 2s/step - loss: 900.0211 - val_loss: 538.8968
Epoch 2/3
10/10 [==============================] - 19s 2s/step - loss: 429.8926 - val_loss: 375.9716
Epoch 3/3
10/10 [==============================] - 19s 2s/step - loss: 403.2861 - val_loss: 401.6767
{'loss': [900.0211181640625, 429.892578125, 403.2861022949219],
 'val_loss': [538.8967895507812, 375.9715576171875, 401.6766662597656]}

後で推論を実行できるように、.save(...) メソッドを使用してモデルを保存しましょう。TFHub のドキュメントに従って、この SavedModel を TFHub にエクスポートすることもできます。

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

注意: このモデルは推論のみに使用するため、include_optimizer=False を設定しています。

評価

次に、検証データセットの単語誤り率を計算します

単語誤り率(WER)は、自動音声認識システムのパフォーマンスを測定するための一般的な指標です。WER は、単語レベルで機能するレーベンシュタイン距離から導出されます。単語誤り率は WER = (S + D + I) / N = (S + D + I) / (S + D + C) で計算でき、S は置換の数、D は削除の数、I は挿入の数、C は正しい単語の数、N は参照内の単語の数です(N=S+D+C)。この値は、誤って予測された単語の割合を示します。

WER の詳細については、この論文を参照してください。

HuggingFace データセットライブラリの load_metric(...) 関数を使用します。最初に pip を使用して datasets ライブラリをインストールしてから、metricオブジェクトを定義しましょう。

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
/tmpfs/tmp/ipykernel_73282/1786190190.py:4: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  metric = load_metric("wer")
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/datasets/load.py:752: FutureWarning: The repository for wer contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.16.1/metrics/wer/wer.py
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  warnings.warn(
Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

次に、検証データの評価を実行します。

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2024-01-11 18:57:54.172941: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:59] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform
W0000 00:00:1704999485.261677   73282 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update

tokenizer.decode(...) メソッドを使用して、予測とラベルをテキストにデコードし直し、後で WER 計算用のメトリックに追加します。

では、次のセルでメトリック値を計算しましょう。

metric.compute()
1.0

**注意: **モデルは非常に小さなデータでトレーニングされており、ASR のようなタスクでは、音声からテキストへのマッピングを学習するために大量のデータが必要になることが多いため、ここでのメトリック値は意味がありません。良い結果を得るには、おそらく大きなデータでトレーニングする必要があります。このノートブックは、事前にトレーニングされた音声モデルを微調整するためのテンプレートを提供します。

推論

トレーニングプロセスに満足し、モデルを save_dir に保存したので、このモデルを推論に使用する方法を確認します。

まず、tf.keras.models.load_model(...) を使用してモデルをロードします。

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

推論を実行するための音声サンプルをダウンロードしましょう。次のサンプルを音声サンプルに置き換えることもできます。

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2024-01-11 18:58:15--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav [following]
--2024-01-11 18:58:15--  https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav [following]
--2024-01-11 18:58:15--  https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.01s   

2024-01-11 18:58:16 (6.16 MB/s) - ‘SA2.wav’ saved [94252/94252]

soundfile.read(...) を使用して音声サンプルを読み取り、モデルのシグネチャを満たすために AUDIO_MAXLEN にパディングします。次に、
Wav2Vec2Processor インスタンスを使用してその音声サンプルを正規化し、モデルにフィードします。

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 0.26906326, -1.0434859 , -1.4027557 , ..., -0.45367476,
         -0.4187361 , -0.40499282],
        [ 0.31648695, -1.0879228 , -1.4781278 , ..., -0.3010028 ,
         -0.43490997, -0.43344948],
        [ 0.36363643, -1.1803617 , -1.4061254 , ..., -0.37722936,
         -0.42821348, -0.5352294 ],
        ...,
        [ 0.14183159, -0.62208396, -1.3148222 , ..., -0.96311665,
         -0.75722384, -0.46243808],
        [ 0.14848918, -0.61940384, -1.3114994 , ..., -0.9643733 ,
         -0.7576248 , -0.46728024],
        [ 0.15444012, -0.6074123 , -1.3052613 , ..., -0.95263773,
         -0.7675285 , -0.4688016 ]]], dtype=float32)>

上で定義した Wav2Vec2tokenizer インスタンスを使用して、数値をデコードしてテキストシーケンスに戻しましょう。

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['O']

モデルはこのノートブックの大きなデータでトレーニングされたことがないため、この予測は非常にランダムです(このノートブックは完全なトレーニングを行うためのものではないためです)。このモデルを完全な LibriSpeech データセットでトレーニングすると、適切な予測が得られます。

ようやくこのノートブックの最後にたどり着きました。しかし、音声関連のタスクについての TensorFlow の学習はこれで終わりではありません。このリポジトリには、さらにすばらしいチュートリアルがいくつか含まれています。このノートブックでバグが発生した場合は、ここで問題を作成してください。