Fine-tuning Wav2Vec2 with an LM head


In this notebook, we will load the pre-trained wav2vec2 model from TFHub and fine-tune it on the LibriSpeech dataset by appending a language modeling (LM) head on top of the pre-trained model. The underlying task is to build a model for Automatic Speech Recognition, i.e. given some speech, the model should be able to transcribe it into text.

Setting Up

Before running this notebook, please ensure that you are on a GPU runtime (Runtime > Change runtime type > GPU). The following cell will install the gsoc-wav2vec2 package & its dependencies.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
WARNING: You are using pip version 21.2.3; however, version 21.2.4 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.



The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-headers-5.4.0-1043-gcp
  linux-headers-5.4.0-1044-gcp linux-image-5.4.0-1044-gcp
  linux-modules-5.4.0-1044-gcp linux-modules-extra-5.4.0-1044-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libflac8 libogg-dev libogg0 libsndfile1 libvorbis-dev
  libvorbis0a libvorbisenc2 libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libflac8 libogg-dev libogg0 libsndfile1 libsndfile1-dev
  libvorbis-dev libvorbis0a libvorbisenc2 libvorbisfile3
0 upgraded, 10 newly installed, 0 to remove and 112 not upgraded.
Need to get 1597 kB of archives.
After this operation, 6532 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg0 amd64 1.3.2-1 [17.2 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac8 amd64 1.3.2-1 [213 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis0a amd64 1.3.5-4.2 [86.4 kB]
Get:6 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisenc2 amd64 1.3.5-4.2 [70.7 kB]
Get:7 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1 amd64 1.0.28-4ubuntu0.18.04.2 [170 kB]
Get:8 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:9 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:10 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1597 kB in 1s (1657 kB/s)
Selecting previously unselected package libogg0:amd64.
(Reading database ... 275832 files and directories currently installed.)
Preparing to unpack .../0-libogg0_1.3.2-1_amd64.deb ...
Unpacking libogg0:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac8:amd64.
Preparing to unpack .../1-libflac8_1.3.2-1_amd64.deb ...
Unpacking libflac8:amd64 (1.3.2-1) ...
Selecting previously unselected package libogg-dev:amd64.
Preparing to unpack .../2-libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../3-libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbis0a:amd64.
Preparing to unpack .../4-libvorbis0a_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis0a:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbisenc2:amd64.
Preparing to unpack .../5-libvorbisenc2_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisenc2:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1:amd64.
Preparing to unpack .../6-libsndfile1_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1:amd64 (1.0.28-4ubuntu0.18.04.2) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../7-libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../8-libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../9-libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libogg0:amd64 (1.3.2-1) ...
Setting up libvorbis0a:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libflac8:amd64 (1.3.2-1) ...
Setting up libvorbisenc2:amd64 (1.3.5-4.2) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1:amd64 (1.0.28-4ubuntu0.18.04.2) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...
WARNING: You are using pip version 21.2.3; however, version 21.2.4 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

Model setup using TFHub

We will start by importing some libraries/modules.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.6.0

First, we will download our model from TFHub & wrap the model signature with hub.KerasLayer so that we can use this model like any other Keras layer. Fortunately, hub.KerasLayer can do both in just one line.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

You can refer to this script in case you are interested in how the model was exported. The pretrained_layer object wraps the exported (SavedModel) version of Wav2Vec2Model. These pre-trained weights were converted from the HuggingFace PyTorch pre-trained weights using this script.

Originally, wav2vec2 was pre-trained with a masked language modelling approach, with the objective of identifying the true quantized latent speech representation for a masked time step. You can read more about the training objective in the paper: wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.

Now, we will define a few constants and hyper-parameters which will be useful in the next few cells. AUDIO_MAXLEN is intentionally set to 246000 as the model signature only accepts a static sequence length of 246000.

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

In the following cell, we will wrap pretrained_layer & a dense layer (the LM head) using the Keras Functional API.

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

The dense layer (defined above) has an output dimension of vocab_size, as we want to predict probabilities for each token in the vocabulary at each time step.

Setting up training state

In TensorFlow, model weights are built only when model.call or model.build is called for the first time, so the following cell will build the model weights for us. Further, we will run model.summary() to check the total number of trainable parameters.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 246000)]          0         
_________________________________________________________________
keras_layer (KerasLayer)     (None, 768, 768)          94371712  
_________________________________________________________________
dense (Dense)                (None, 768, 32)           24608     
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Now, we need to define the loss_fn and optimizer to be able to train the model. The following cell will do that for us. We will use the Adam optimizer for simplicity. CTC loss is commonly used for tasks (like ASR) where input sub-parts can't be easily aligned with output sub-parts. You can read more about CTC loss in this amazing blog post.
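
To make the CTC objective a bit more concrete, here is a minimal, self-contained sketch using tf.nn.ctc_loss on random tensors. This is only an illustration of the loss itself (the shapes and blank_index=0 are assumptions for the toy example); the actual training below uses the CTCLoss class shipped with the gsoc-wav2vec2 package.

# Illustration only: CTC loss on random logits & labels with tf.nn.ctc_loss.
toy_batch, toy_time, toy_vocab, toy_label_len = 2, 50, 32, 10
toy_logits = tf.random.normal((toy_batch, toy_time, toy_vocab))
toy_labels = tf.random.uniform((toy_batch, toy_label_len), minval=1, maxval=toy_vocab, dtype=tf.int32)

toy_loss = tf.nn.ctc_loss(
    labels=toy_labels,
    logits=toy_logits,
    label_length=tf.fill([toy_batch], toy_label_len),
    logit_length=tf.fill([toy_batch], toy_time),
    logits_time_major=False,  # logits are (batch, time, vocab)
    blank_index=0,            # assumed blank token index for this toy example
)
# toy_loss holds one value per example; dividing the sum by the batch size
# gives the mean over the batch (this is what division_factor controls below).
print(tf.reduce_sum(toy_loss) / toy_batch)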

CTCLoss (from the gsoc-wav2vec2 package) accepts 3 arguments: config, model_input_shape & division_factor. If division_factor=1, the loss will simply be summed, so pass division_factor accordingly to get the mean over the batch.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Loading & Pre-processing data

Let's now download the LibriSpeech dataset from the official website and set it up.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-08-26 11:07:41--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.4MB/s    in 30s     

2021-08-26 11:08:12 (10.7 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Our dataset lies in the LibriSpeech directory. Let's explore these files.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0002.flac', '2428-83705-0036.flac', '2428-83705-0014.flac', '2428-83705-0043.flac', '2428-83705-0003.flac', '2428-83705-0016.flac', '2428-83705-0009.flac', '2428-83705-0012.flac', '2428-83705-0004.flac', '2428-83705-0028.flac', '2428-83705-0041.flac', '2428-83705-0039.flac', '2428-83705-0031.flac', '2428-83705-0022.flac', '2428-83705-0011.flac', '2428-83705-0020.flac', '2428-83705-0027.flac', '2428-83705-0018.flac', '2428-83705-0032.flac', '2428-83705-0034.flac', '2428-83705-0017.flac', '2428-83705-0005.flac', '2428-83705-0006.flac', '2428-83705-0010.flac', '2428-83705-0021.flac', '2428-83705-0000.flac', '2428-83705-0023.flac', '2428-83705-0042.flac', '2428-83705-0035.flac', '2428-83705-0038.flac', '2428-83705-0025.flac', '2428-83705-0008.flac', '2428-83705-0030.flac', '2428-83705-0037.flac', '2428-83705-0029.flac', '2428-83705-0033.flac', '2428-83705-0040.flac', '2428-83705-0026.flac', '2428-83705-0007.flac', '2428-83705-0013.flac', '2428-83705-0015.flac', '2428-83705-0001.flac', '2428-83705-0019.flac', '2428-83705-0024.flac']

Alright, so each sub-directory has many .flac files and a .txt file. The .txt file contains text transcriptions for all the speech samples (i.e. .flac files) present in that sub-directory.

We can load this text data as follows:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    # Each non-empty line looks like "<file_id> <transcription>"; skip empty lines.
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 1}
  return samples

Similarly, we will define a function for loading a speech sample from a .flac file.

REQUIRED_SAMPLE_RATE is set to 16000, as wav2vec2 was pre-trained on 16 kHz audio and it is recommended to fine-tune it without any major change in data distribution caused by a different sampling frequency.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}
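
The LibriSpeech files are already at 16 kHz, so read_flac_file only validates the sampling rate. If you later want to feed audio recorded at a different rate, one option is to resample it first. The following is a hypothetical helper (not used anywhere in this notebook), sketched with scipy.signal.resample_poly and assuming SciPy is available:

from scipy import signal

def read_and_resample_flac(file_path, target_sr=REQUIRED_SAMPLE_RATE):
  # Hypothetical helper: read a .flac file and resample it to `target_sr` if needed.
  audio, sample_rate = sf.read(file_path)
  if sample_rate != target_sr:
    audio = signal.resample_poly(audio, up=target_sr, down=sample_rate)
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}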

Now, we will pick a random sample & try to visualize it.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: IT WAS PLAIN THAT TOGETHER WE SHOULD MANAGE MOST COMFORTABLY DELIGHTFULLY IN FACT 
Audio:

Now, we will combine all the speech & text samples; we define a function (in the next cell) for that purpose.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples
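
fetch_sound_text_mapping operates on a single chapter directory. If you want to build samples from the whole dev-clean split (for example, to train on more data later), one possible approach, sketched here as a hypothetical helper that is not used below, is to walk the extracted tree and run the same function on every directory that contains .flac files:

def fetch_all_samples(root="./data/train/LibriSpeech/dev-clean/"):
  # Hypothetical helper: collect samples from every chapter directory under `root`.
  all_samples = []
  for dirpath, _, filenames in os.walk(root):
    if any(f.endswith(".flac") for f in filenames):
      all_samples.extend(fetch_sound_text_mapping(dirpath))
  return all_samples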

It's time to have a look at a few samples ...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([-2.44140625e-04, -1.83105469e-04, -9.15527344e-05, ...,
         -1.22070312e-04, -1.83105469e-04, -1.52587891e-04]),
  'SOMEONE SNIGGERED'),
 (array([-0.00027466, -0.00033569, -0.00036621, ...,  0.00021362,
          0.        ,  0.        ]),
  'THERE SHE OWNS A COTTAGE OR IT MAY BE A PIGSTYE FOR ALL I KNOW'),
 (array([-0.00015259, -0.00012207, -0.00021362, ..., -0.00048828,
         -0.00048828, -0.00045776]),
  "BESIDES WHICH WE CAN ALWAYS SELL THE COUPONS AND RAILWAY PASSES WHICH WE DON'T USE"),
 (array([-1.22070312e-04,  3.05175781e-05,  6.10351562e-05, ...,
         -4.27246094e-04, -6.10351562e-04, -9.15527344e-04]),
  'IT IS MOST DELIGHTFUL'),
 (array([-0.00036621, -0.00015259, -0.00012207, ..., -0.0005188 ,
         -0.00048828, -0.00048828]),
  'THERE WERE NO SIGNS OF FALTERING ABOUT HER FLOW OF LANGUAGE')]

Let's pre-process the data now!

We will first define the tokenizer & processor using the gsoc-wav2vec2 package. Then, we will do some very simple pre-processing. The processor will normalize the raw speech with respect to the frames axis, and the tokenizer will convert our model outputs into a string (using the defined vocabulary) & will take care of removing special tokens (depending on your tokenizer configuration).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE
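
As a quick, optional sanity check (not required for the pipeline), you can round-trip a short transcription string through the tokenizer and normalize a dummy waveform with the processor. The exact ids depend on the downloaded vocabulary.

# Optional sanity check of the objects defined above.
sample_ids = preprocess_text("IT IS MOST DELIGHTFUL")
print("label ids:", sample_ids.numpy().tolist())
print("decoded back:", tokenizer.decode(sample_ids.numpy().tolist(), group_tokens=False))

dummy_speech = tf.random.uniform((AUDIO_MAXLEN,))
print("normalized speech shape:", preprocess_speech(dummy_speech.numpy()).shape)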

Now, we will define a Python generator to call the preprocessing functions we defined in the cells above.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Setting up tf.data.Dataset

The following cell will set up the tf.data.Dataset object using its .from_generator(...) method. We will be using the generator object we defined in the cell above.

You can refer to this script for more details on how to convert LibriSpeech data into tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None,), dtype=tf.float32),  # speech (1-D waveform)
    tf.TensorSpec(shape=(None,), dtype=tf.int32),    # label token ids
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

We will split the dataset into multiple batches, so let's prepare the batches in the following cell. Now, all the sequences in a batch should be padded to a constant length. We will use the .padded_batch(...) method for that purpose.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))
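
If you want to confirm that every example now has a static length, you can inspect the dataset's element spec (an optional check, not part of the original pipeline):

# Speech should now be (None, AUDIO_MAXLEN) and labels (None, LABEL_MAXLEN);
# the batch dimension stays None because drop_remainder is not set.
print(dataset.element_spec)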

Accelerators (like GPUs/TPUs) are very fast, and data loading (& pre-processing) often becomes the bottleneck during training, as the data-loading part happens on CPUs. This can increase training time significantly, especially when there is a lot of online pre-processing involved or data is streamed online from GCS buckets. To handle those issues, tf.data.Dataset offers the .prefetch(...) method. It prepares the next few batches in parallel (on CPUs) while the model is making predictions (on GPUs/TPUs) on the current batch.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Since this notebook is made for demonstration purposes, we will take the first num_train_batches and perform training over only those. You are encouraged to train on the whole dataset though. Similarly, we will evaluate only num_val_batches.

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Model training

To train our model, we will directly call the .fit(...) method after compiling our model with .compile(...).

model.compile(optimizer, loss=loss_fn)

The above cell will set up our training state. Now we can initiate training with the .fit(...) method.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1442: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1442: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1425: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1425: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss.
10/10 [==============================] - 30s 2s/step - loss: 1085.6143 - val_loss: 479.0432
Epoch 2/3
10/10 [==============================] - 16s 2s/step - loss: 724.7572 - val_loss: 740.7762
Epoch 3/3
10/10 [==============================] - 16s 2s/step - loss: 672.7921 - val_loss: 514.6213
{'loss': [1085.6142578125, 724.7572021484375, 672.7920532226562],
 'val_loss': [479.043212890625, 740.7762451171875, 514.6212768554688]}

Let's save our model with the .save(...) method to be able to perform inference later. You can also export this SavedModel to TFHub by following the TFHub documentation.

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-08-26 11:09:20.988517: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Evaluation

Now we will compute the Word Error Rate (WER) over the validation dataset.

Word error rate (WER) is a common metric for measuring the performance of an automatic speech recognition system. The WER is derived from the Levenshtein distance, working at the word level. The word error rate can be computed as: WER = (S + D + I) / N = (S + D + I) / (S + D + C), where S is the number of substitutions, D is the number of deletions, I is the number of insertions, C is the number of correct words, and N is the number of words in the reference (N = S + D + C). This value indicates the proportion of words that were incorrectly predicted.
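
To make the formula concrete, here is a small, self-contained sketch that computes word-level WER directly via edit distance (independent of the HuggingFace metric used below):

def wer(reference, hypothesis):
  ref, hyp = reference.split(), hypothesis.split()
  # dp[i][j] = edit distance between the first i reference words and the first j hypothesis words.
  dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
  for i in range(len(ref) + 1):
    dp[i][0] = i
  for j in range(len(hyp) + 1):
    dp[0][j] = j
  for i in range(1, len(ref) + 1):
    for j in range(1, len(hyp) + 1):
      cost = 0 if ref[i - 1] == hyp[j - 1] else 1
      dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                     dp[i][j - 1] + 1,         # insertion
                     dp[i - 1][j - 1] + cost)  # substitution / correct
  return dp[-1][-1] / len(ref)

print(wer("IT IS MOST DELIGHTFUL", "IT WAS MOST DELIGHTFUL"))  # 0.25 (1 substitution / 4 reference words)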

You can refer to this paper to learn more about WER.

We will use the load_metric(...) function from the HuggingFace datasets library. Let's first install the datasets library using pip and then define the metric object.

pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
WARNING: You are using pip version 21.2.3; however, version 21.2.4 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

It's time to run the evaluation on validation data now.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-08-26 11:09:34.896906: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:54] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

We are using the tokenizer.decode(...) method to decode our predictions and labels back into text and adding them to the metric so that the WER can be computed later.
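
As an optional qualitative check, you can print the references from the last batch next to their predictions (these variables are simply left over from the final iteration of the loop above):

# Optional: compare the last batch's references with the (currently poor) predictions.
for ref, pred in zip(references, predictions):
    print("REF :", ref)
    print("PRED:", pred)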

Now, let's calculate the metric value in the following cell:

metric.compute()
1.0

Inference

Now that we are satisfied with the training process & have saved the model in save_dir, we will see how this model can be used for inference.

First, we will load our model using tf.keras.models.load_model(...).

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Let's download a speech sample for performing inference. You can also replace the following sample with your own speech sample.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-08-26 11:09:50--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-08-26 11:09:51--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-08-26 11:09:51 (5.45 MB/s) - ‘SA2.wav’ saved [94252/94252]

Now, we will read the speech sample using soundfile.read(...) and pad it to AUDIO_MAXLEN to satisfy the model signature. Then we will normalize that speech sample using the Wav2Vec2Processor instance & feed it into the model.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 2.9764342 , -0.26707956, -0.5779591 , ..., -0.34974045,
         -0.5029342 , -1.0964708 ],
        [ 2.9764307 , -0.26708567, -0.57792413, ..., -0.3497205 ,
         -0.5029489 , -1.0964632 ],
        [ 2.976445  , -0.26707464, -0.57796115, ..., -0.34973744,
         -0.5029509 , -1.09643   ],
        ...,
        [ 2.9764574 , -0.26693058, -0.5780056 , ..., -0.34982246,
         -0.5031293 , -1.0964078 ],
        [ 2.9764593 , -0.26692826, -0.5780005 , ..., -0.34982204,
         -0.5031268 , -1.0964111 ],
        [ 2.9764667 , -0.26693505, -0.5779976 , ..., -0.34982124,
         -0.50311416, -1.0964153 ]]], dtype=float32)>

Let's decode the numbers back into a text sequence using the tokenizer instance we defined above.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['R']

This prediction is quite random, as the model was never trained on a large amount of data in this notebook (this notebook is not meant for complete training). You will get good predictions if you train this model on the complete LibriSpeech dataset.

Finally, we have reached the end of this notebook. But it's not the end of learning TensorFlow for speech-related tasks; this repository contains some more amazing tutorials. In case you encounter any bug in this notebook, please create an issue here.