LM 헤드로 Wav2Vec2 미세 조정

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기 TF Hub 모델 보기

이 노트북에서는 TFHub에서 사전 훈련된 wav2vec2 모델을 로드하고 사전 훈련된 모델 위에 LM(Language Modeling) 헤드를 추가하여 LibriSpeech 데이터세트에서 미세 조정합니다. 기본 작업은 자동 음성 인식을 위한 모델을 구축하는 것입니다. 즉, 일부 음성이 주어지면 해당 모델이 이를 텍스트로 변환할 수 있어야 합니다.

설정

이 노트북을 실행하기 전에 GPU 런타임(Runtime > Change runtime type > GPU)에 있는지 확인하십시오. 다음 셀은 gsoc-wav2vec2 패키지 및 종속성을 설치합니다.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2
  libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2
  libparted-fs-resize0
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libflac8 libogg-dev libvorbis-dev
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev
The following packages will be upgraded:
  libflac8
1 upgraded, 4 newly installed, 0 to remove and 169 not upgraded.
Need to get 1012 kB of archives.
After this operation, 4279 kB of additional disk space will be used.
Get:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libflac8 amd64 1.3.3-1ubuntu0.1 [103 kB]
Get:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libogg-dev amd64 1.3.4-0ubuntu1 [161 kB]
Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libflac-dev amd64 1.3.3-1ubuntu0.1 [151 kB]
Get:4 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libvorbis-dev amd64 1.3.6-2ubuntu1 [316 kB]
Get:5 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libsndfile1-dev amd64 1.0.28-7ubuntu0.1 [280 kB]
Fetched 1012 kB in 0s (23.0 MB/s)
(Reading database ... 140478 files and directories currently installed.)
Preparing to unpack .../libflac8_1.3.3-1ubuntu0.1_amd64.deb ...
Unpacking libflac8:amd64 (1.3.3-1ubuntu0.1) over (1.3.3-1build1) ...
Selecting previously unselected package libogg-dev:amd64.
Preparing to unpack .../libogg-dev_1.3.4-0ubuntu1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.3-1ubuntu0.1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.3-1ubuntu0.1) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.6-2ubuntu1_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-7ubuntu0.1_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-7ubuntu0.1) ...
Setting up libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Setting up libflac8:amd64 (1.3.3-1ubuntu0.1) ...
Setting up libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Setting up libflac-dev:amd64 (1.3.3-1ubuntu0.1) ...
Setting up libsndfile1-dev (1.0.28-7ubuntu0.1) ...
Processing triggers for libc-bin (2.31-0ubuntu9.9) ...

TFHub를 사용한 모델 설정

먼저 일부 라이브러리/모듈을 가져와 설정을 시작하겠습니다.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
2022-12-14 21:44:58.424149: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:44:58.424252: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:44:58.424262: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TF version: 2.11.0

먼저 TFHub에서 모델을 다운로드하고 다른 Keras 레이어처럼 이 모델을 사용할 수 있도록 hub.KerasLayer로 모델 서명을 래핑합니다. 다행히 hub.KerasLayer는 단 한 줄에서 두 작업을 모두 수행할 수 있습니다.

참고: hub.KerasLayer를 사용하여 모델을 로드할 때 모델이 약간 불투명해 지지만 때때로 모델에 대한 더 세밀한 제어가 필요한 경우 tf.keras.models.load_model(...)을 사용하여 모델을 로드할 수 있습니다.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.

모델 내보내기 스크립트에 관심이 있는 경우 이 스크립트를 참조할 수 있습니다. pretrained_layer 객체는 Wav2Vec2Model의 고정 버전입니다. 이 사전 훈련된 가중치는 이 스크립트를 사용하여 HuggingFace PyTorch 사전 훈련된 가중치에서 변환되었습니다.

원래 wav2vec2는 마스킹된 시간 단계에 대한 진정한 양자화된 잠재 음성 표현을 식별하기 위한 목적으로 마스킹된 언어 모델링 접근 방식으로 사전 훈련되었습니다. 훈련 목표에 대한 자세한 내용은 wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations에서 확인할 수 있습니다.

이제 우리는 다음 몇 개의 셀에서 유용한 몇 가지 상수와 하이퍼 매개변수를 정의할 것입니다. 모델 서명이 246000의 정적 시퀀스 길이만 수용할 때 AUDIO_MAXLEN246000으로 의도적으로 설정됩니다.

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

다음 셀에서 우리는 pretrained_layer와 고밀도 레이어(LM 헤드)를 Keras의 함수형 API로 래핑할 것입니다.

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

밀집 레이어(위에 정의됨)은 각 시간 단계에서 어휘의 각 토큰에 대한 확률을 예측하기 원할 때 vocab_size의 출력 차원을 갖습니다.

훈련 상태 설정

TensorFlow에서 모델 가중치는 model.call 또는 model.build가 처음 호출될 때만 빌드되므로 다음 셀이 우리를 위해 모델 가중치를 빌드합니다. 또한 model.summary()를 실행하여 훈련 가능한 총 매개변수 수를 확인합니다.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

이제 모델을 훈련할 수 있도록 loss_fn과 옵티마이저를 정의해야 합니다. 다음 셀은 우리를 위해 그러한 작업을 수행할 것입니다. 단순화를 위해 Adam 옵티마이저를 사용할 것입니다. CTCLoss는 입력 하위 부분을 출력 하위 부분과 쉽게 정렬할 수 없는 작업(예: ASR)에 사용되는 일반적인 손실 유형입니다. 이 놀라운 블로그 게시물에서 CTC 손실에 대해 자세히 알아볼 수 있습니다.

CTCLoss(gsoc-wav2vec2 패키지에서)는 config, model_input_shapedivision_factor 등 3가지 인수를 허용합니다. division_factor=1이면 손실이 단순히 합산되므로 그에 따라 division_factor를 전달하여 배치에 대한 평균을 구할 수 있습니다.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

데이터 로드 및 전처리

이제 공식 웹 사이트에서 LibriSpeech 데이터세트를 다운로드하여 설정해 보겠습니다.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2022-12-14 21:45:19--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: http://us.openslr.org/resources/12/dev-clean.tar.gz [following]
--2022-12-14 21:45:19--  http://us.openslr.org/resources/12/dev-clean.tar.gz
Resolving us.openslr.org (us.openslr.org)... 46.101.158.64
Connecting to us.openslr.org (us.openslr.org)|46.101.158.64|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  26.3MB/s    in 13s     

2022-12-14 21:45:33 (24.8 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]

참고: 이 노트북은 데모 전용일 때 dev-clean 구성을 사용하고 있으므로 소량의 데이터만 필요합니다. 전체 훈련 데이터는 LibriSpeech 웹사이트에서 쉽게 다운로드할 수 있습니다.

ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

데이터세트는 LibriSpeech 디렉터리에 있습니다. 이 파일들을 살펴보겠습니다.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0014.flac', '2428-83705-0027.flac', '2428-83705-0040.flac', '2428-83705-0021.flac', '2428-83705-0030.flac', '2428-83705-0038.flac', '2428-83705-0029.flac', '2428-83705-0042.flac', '2428-83705-0002.flac', '2428-83705-0031.flac', '2428-83705-0034.flac', '2428-83705-0036.flac', '2428-83705-0024.flac', '2428-83705-0017.flac', '2428-83705-0033.flac', '2428-83705-0026.flac', '2428-83705-0001.flac', '2428-83705-0032.flac', '2428-83705-0028.flac', '2428-83705-0015.flac', '2428-83705-0022.flac', '2428-83705-0010.flac', '2428-83705-0035.flac', '2428-83705-0039.flac', '2428-83705-0020.flac', '2428-83705-0006.flac', '2428-83705-0003.flac', '2428-83705-0043.flac', '2428-83705-0000.flac', '2428-83705-0018.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0025.flac', '2428-83705-0019.flac', '2428-83705-0008.flac', '2428-83705-0037.flac', '2428-83705-0041.flac', '2428-83705-0011.flac', '2428-83705-0007.flac', '2428-83705-0016.flac', '2428-83705-0005.flac', '2428-83705-0004.flac', '2428-83705-0023.flac', '2428-83705-0012.flac']

좋습니다. 각 하위 디렉터리에는 많은 .flac 파일과 .txt 파일이 있습니다. .txt 파일에는 해당 하위 디렉터리에 있는 모든 음성 샘플(예: .flac 파일)에 대한 텍스트 기록이 포함되어 있습니다.

이 텍스트 데이터를 다음과 같이 로드할 수 있습니다.

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

마찬가지로 .flac 파일에서 음성 샘플을 로드하는 함수를 정의합니다.

REQUIRED_SAMPLE_RATE는 wav2vec2가 16K 빈도로 사전 훈련될 때 16000으로 설정됩니다. 따라서 빈도로 인한 데이터 분포의 큰 변화 없이 미세 조정하는 것이 좋습니다.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

이제 무작위 샘플을 선택하고 시각화를 시도해보겠습니다.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: I NEVER SAW PEOPLE LIKE THE SNELLINGS FOR POSSESSING RELATIVES IN ALL SORTS OF LINES 
Audio:

이제 모든 음성 및 텍스트 샘플을 결합하고 다음 셀에서 해당 목적을 위한 함수를 정의하겠습니다.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

몇 가지 샘플을 살펴보겠습니다...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([-0.00027466, -0.00033569, -0.00036621, ...,  0.00021362,

          0.        ,  0.        ]),
  'THERE SHE OWNS A COTTAGE OR IT MAY BE A PIGSTYE FOR ALL I KNOW'),
 (array([ 0.00018311,  0.00067139, -0.00033569, ..., -0.00033569,
         -0.00027466, -0.00027466]),
  'HERS HAS BEEN PRODIGIOUS'),
 (array([-0.00015259,  0.00021362,  0.00042725, ..., -0.00048828,
         -0.00094604, -0.00115967]),
  'WE ARE GOING FOR OUR HONEYMOON TO ITALY AND THE SOUTH OF FRANCE'),
 (array([0.00036621, 0.00027466, 0.00015259, ..., 0.00039673, 0.00048828,
         0.00079346]),
  'I SHALL MAKE PAPA GIVE ME FIVE HUNDRED POUNDS AT LEAST'),
 (array([ 3.66210938e-04,  1.52587891e-04,  4.27246094e-04, ...,
         -3.05175781e-05,  3.05175781e-05,  9.15527344e-05]),
  'THE FACT OF HAVING GIVEN MARY ANN A WEDDING PRESENT SEEMS TO FILL THEM WITH A FEELING OF RANCOROUS ACIDITY WHICH TO ME IS INEXPLICABLE')]

참고: 이 노트북에서 소량의 데이터 세트로 작업할 때 이 데이터를 메모리에 로드합니다. 그러나 전체 데이터세트(약 300GB)를 훈련하기 위해서는 데이터를 느리게 로드해야 합니다. 이에 대한 자세한 내용은 이 스크립트를 참조하십시오.

이제 데이터를 전처리하도록 하겠습니다!!!

먼저 gsoc-wav2vec2 패키지를 사용하여 토크나이저와 프로세서를 정의하겠습니다. 그런 다음, 아주 간단한 전처리 작업을 수행하겠습니다. processor는 프레임 축과 관련된 원시 음성을 정규화하고 tokenizer는 모델 출력을 문자열로 변환(정의된 어휘 사용)하고 특수 토큰 제거를 처리(토큰나이저 구성에 따라 다름)합니다.

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

이제 위의 셀에서 정의한 전처리 함수를 호출하기 위해 파이썬 생성기를 정의하겠습니다.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

tf.data.Dataset 설정

다음 셀은 .from_generator(...) 메서드를 사용하여 tf.data.Dataset 객체를 설정합니다. 우리는 위의 셀에서 정의한 generator 객체를 사용할 것입니다.

참고: 분산 훈련(특히 TPU에서)의 경우 .from_generator(...)는 현재 작동하지 않으며 .tfrecord 형식으로 저장된 데이터에 대해 훈련하는 것이 좋습니다(참고: TPU가 최대한 작동하려면 TFRecord를 GCS 버킷 안에 이상적으로 저장해야 합니다.).

LibriSpeech 데이터를 tfrecord로 변환하는 방법에 대한 자세한 내용은 이 스크립트를 참조하세요.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

데이터세트를 여러 배치로 전달할 것이므로 다음 셀에서 배치를 준비하겠습니다. 이제 배치의 모든 시퀀스를 일정한 길이로 채워야 합니다. 이를 위해 .padded_batch(...) 메서드를 사용할 것입니다.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

가속기(예: GPU/TPU)는 매우 빠르며 데이터 로딩 부분이 CPU에서 발생할 때 훈련 중에 데이터 로딩(및 사전 처리)이 병목 현상이 되는 경우가 많습니다. 이는 특히 온라인 전처리가 많이 포함되거나 데이터가 GCS 버킷에서 온라인으로 스트리밍되는 경우 훈련 시간을 크게 늘릴 수 있습니다. 이러한 문제를 처리하기 위해 tf.data.Dataset.prefetch(...) 메서드를 제공합니다. 이 메서드는 모델이 현재 배치에서 GPU/TPU를 예측하는 동안 CPU에서 병렬로 다음 몇 개의 배치를 준비하는 데 도움이 됩니다.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

이 노트북은 데모용으로 만들어졌기 때문에 먼저 num_train_batches를 가져와 그것에 대해 훈련할 것입니다. 하지만 전체 데이터세트를 훈련하는 것이 좋습니다. 마찬가지로 num_val_batches만 평가할 것입니다.

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

모델 훈련

모델을 훈련하기 위해 .compile(...)으로 모델을 컴파일한 후 .fit(...) 메서드를 직접 호출하겠습니다.

model.compile(optimizer, loss=loss_fn)

위의 셀은 훈련 상태를 설정합니다. 이제 .fit(...) 메서드로 훈련을 시작할 수 있습니다.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1467: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1467: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1450: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1450: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
10/10 [==============================] - 59s 2s/step - loss: 1081.9149 - val_loss: 495.7457
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 815.8112 - val_loss: 520.3546
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 632.2935 - val_loss: 433.4796
{'loss': [1081.9149169921875, 815.8111572265625, 632.2935180664062],
 'val_loss': [495.74566650390625, 520.3545532226562, 433.4796142578125]}

나중에 추론을 수행할 수 있도록 .save(...) 메서드로 모델을 저장하겠습니다. TFHub 설명서에 따라 이 저장된 모델을 TFHub로 내보낼 수도 있습니다.

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 342). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

참고: 이 모델을 추론용으로만 사용하기 원할 때 include_optimizer=False를 설정합니다.

평가

이제 검증 데이터세트에 대해 단어 오류율을 계산히겠습니다.

WER(단어 오류율)은 자동 음성 인식 시스템의 성능을 측정하기 위한 일반적인 메트릭입니다. WER은 Levenshtein 거리에서 파생되며 단어 수준에서 작용합니다. 단어 오류율은 WER = (S + D + I) / N = (S + D + I) / (S + D + C) 공식으로 계산할 수 있습니다. 여기에서 S는 대체 횟수, D는 삭제 횟수 , I는 삽입 횟수, C는 올바른 단어 수, N은 참조 단어 수(N=S+D+C)입니다. 이 값은 잘못 예측된 단어의 백분율을 나타냅니다.

이 문서를 참조하여 WER에 대해 자세히 알아볼 수 있습니다.

HuggingFace 데이터세트 라이브러리의 load_metric(...) 함수를 사용합니다. 먼저 pip를 사용하여 datasets 라이브러리를 설치한 다음 metric 객체를 정의하겠습니다.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
/tmpfs/tmp/ipykernel_89208/1786190190.py:4: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  metric = load_metric("wer")
Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

이제 검증 데이터를 평가하겠습니다.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2022-12-14 21:47:30.289046: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

우리는 예측과 레이블을 다시 텍스트로 디코딩하기 위해 tokenizer.decode(...) 메서드를 사용하고 있으며 나중에 WER 계산을 위해 메트릭에 추가할 것입니다.

이제 다음 셀에서 메트릭 값을 계산하겠습니다.

metric.compute()
1.0

참고: 여기에서 메트릭 값은 모델이 매우 작은 데이터에 대해 훈련되고 ASR과 같은 작업은 음성에서 텍스트로의 매핑을 학습하기 위해 종종 많은 양의 데이터가 필요하기 때문에 의미가 없습니다. 좋은 결과를 얻으려면 대용량 데이터에 대해 훈련해야 합니다. 이 노트북은 사전 훈련된 음성 모델을 미세 조정할 수 있는 템플릿을 제공합니다.

추론

이제 훈련 과정에 만족하고 모델을 save_dir에 저장했으므로 이 모델을 추론에 사용할 수 있는 방법을 살펴보겠습니다.

먼저 tf.keras.models.load_model(...)을 사용하여 모델을 로드하겠습니다.

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

추론을 수행하기 위한 몇 가지 음성 샘플을 다운로드하겠습니다. 다음 샘플을 음성 샘플로 대체할 수도 있습니다.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2022-12-14 21:47:47--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav [following]
--2022-12-14 21:47:47--  https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav [following]
--2022-12-14 21:47:47--  https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.01s   

2022-12-14 21:47:48 (6.26 MB/s) - ‘SA2.wav’ saved [94252/94252]

이제 soundfile.read(...)를 사용하여 음성 샘플을 읽고 모델 서명을 충족하기 위해 AUDIO_MAXLEN 으로 채우겠습니다. 그런 다음 Wav2Vec2Processor 인스턴스를 사용하여 음성 샘플을 정규화하고 이를 모델에 제공하겠습니다.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 3.1169512 ,  0.262896  , -0.9379    , ..., -0.54652435,
          0.11307168, -1.1297362 ],
        [ 3.1180236 ,  0.26138636, -0.937152  , ..., -0.54561067,
          0.11075146, -1.1301531 ],
        [ 3.1179123 ,  0.26053834, -0.9376069 , ..., -0.5464175 ,
          0.11459835, -1.1299919 ],
        ...,
        [ 3.12473   ,  0.2649066 , -0.94297856, ..., -0.5347064 ,
          0.10750955, -1.1382717 ],
        [ 3.125088  ,  0.26431754, -0.94275814, ..., -0.53419125,
          0.10702407, -1.1383224 ],
        [ 3.12526   ,  0.2634962 , -0.9422261 , ..., -0.53358746,
          0.10673241, -1.1382558 ]]], dtype=float32)>

위에서 정의한 Wav2Vec2tokenizer 인스턴스를 사용하여 숫자를 텍스트 시퀀스로 다시 디코딩하겠습니다.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['H']

이 예측은 모델이 이 노트북의 대용량 데이터에 대해 훈련된 적이 없기 때문에 상당히 무작위적입니다(이 노트북은 완전한 훈련을 수행하기 위한 것이 아니기 때문). 완전한 LibriSpeech 데이터세트에서 이 모델을 훈련하면 좋은 예측을 얻을 수 있습니다.

마침내, 우리는 이 노트북의 마지막 부분에 도달했습니다. 그러나 이것으로 음성 관련 작업을 위한 TensorFlow 학습이 끝난 것은 아닙니다. 이 리포지토리에는 더 놀라운 학습 내용이 포함되어 있습니다. 이 노트북에 버그가 있음을 발견한 경우 여기에 문제를 보고하세요.