
Retrain a speech recognition model with TensorFlow Lite Model Maker


In this Colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a speech recognition model that can classify spoken words or short phrases using one-second sound samples. The Model Maker library uses transfer learning to retrain an existing TensorFlow model with a new dataset, which reduces the amount of sample data and time required for training.

By default, this notebook retrains the model (BrowserFft, from the TFJS Speech Command Recognizer) using a subset of words from the speech commands dataset (such as "up," "down," "left," and "right"). Then it exports a TFLite model that you can run on a mobile device or embedded system (such as a Raspberry Pi). It also exports the trained model as a TensorFlow SavedModel.

This notebook is also designed to accept a custom dataset of WAV files, uploaded to Colab in a ZIP file. The more samples you have for each class, the better your accuracy will be, but because the transfer learning process uses feature embeddings from the pre-trained model, you can still get a fairly accurate model with only a few dozen samples in each of your classes.

If you want to run the notebook with the default speech dataset, you can run the whole thing now by clicking Runtime > Run all in the Colab toolbar. However, if you want to use your own dataset, then continue down to Prepare the dataset and follow the instructions there.

Import the required packages

You'll need TensorFlow, TFLite Model Maker, and some modules for audio manipulation, playback, and visualizations.

!sudo apt -y install libportaudio2
!pip install tflite-model-maker
import os
import glob
import random
import shutil

import librosa
import soundfile as sf
from IPython.display import Audio
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
from tflite_model_maker.config import ExportFormat

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
TensorFlow Version: 2.8.0
Model Maker Version: 0.4.0

Prepare the dataset

To train with the default speech dataset, just run all the code below as-is.

But if you want to train with your own speech dataset, follow these steps:

  1. Be sure each sample in your dataset is in WAV file format, about one second long. Then create a ZIP file with all your WAV files, organized into separate subfolders for each classification. For example, each sample for a speech command "yes" should be in a subfolder named "yes". Even if you have only one class, the samples must be saved in a subdirectory with the class name as the directory name. (This script assumes your dataset is not split into train/validation/test sets and performs that split for you.)
  2. Click the Files tab in the left panel and just drag-drop your ZIP file there to upload it.
  3. In the following cell, set use_custom_dataset to True (in Colab, this is a drop-down form field).
  4. Then skip to Prepare a custom dataset to specify your ZIP filename and dataset directory name.

use_custom_dataset = False  # Set to True to train with your own dataset
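
Step 1's format requirements can be checked before you zip the dataset. The sketch below is a hypothetical helper, check_wav_dataset, built only on Python's standard-library wave and pathlib modules. The 0.5-1.5 second tolerance is an assumption standing in for "about one second". Note that wave reads only uncompressed PCM WAV files, so other encodings are reported as unreadable rather than validated.

```python
import wave
from pathlib import Path

def check_wav_dataset(dataset_dir, min_seconds=0.5, max_seconds=1.5):
  """Returns {folder: [problem, ...]} for a folder-per-class WAV dataset.

  The 0.5-1.5 s window is an illustrative stand-in for "about one second".
  """
  root = Path(dataset_dir)
  problems = {}
  # WAV files sitting directly in the root have no class folder
  loose = sorted(p.name for p in root.glob('*.wav'))
  if loose:
    problems['(top level)'] = [f'{name}: not in a class subfolder' for name in loose]
  for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
    issues = []
    wav_files = sorted(class_dir.glob('*.wav'))
    if not wav_files:
      issues.append('no .wav files')
    for wav_path in wav_files:
      try:
        with wave.open(str(wav_path), 'rb') as w:
          seconds = w.getnframes() / w.getframerate()
      except (wave.Error, EOFError) as err:
        # The stdlib wave module only reads uncompressed PCM WAV files
        issues.append(f'{wav_path.name}: unreadable by wave ({err})')
        continue
      if not min_seconds <= seconds <= max_seconds:
        issues.append(f'{wav_path.name}: {seconds:.2f} s long')
    if issues:
      problems[class_dir.name] = issues
  return problems
```

Running print(check_wav_dataset('./YOUR-DIRNAME')) on an unzipped copy of your dataset should print an empty dict if every class folder contains only readable, roughly one-second WAV files.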

Generate a background noise dataset

Whether you're using the default speech dataset or a custom dataset, you should have a good set of background noises so your model can distinguish speech from other noises (including silence).

Because the following background samples are provided in WAV files that are a minute long or longer, we need to split them into smaller one-second samples so we can reserve some for our test dataset. We'll also combine a couple of different sample sources to build a comprehensive set of background noises and silence:

Downloading data from
1489100800/1489096277 [==============================] - 7s 0us/step
1489108992/1489096277 [==============================] - 7s 0us/step
Downloading data from
1073152/1072437 [==============================] - 0s 0us/step
1081344/1072437 [==============================] - 0s 0us/step
# Create a list of all the background wav files
files = glob.glob(os.path.join('./dataset-speech/_background_noise_', '*.wav'))
files = files + glob.glob(os.path.join('./dataset-background', '*.wav'))

background_dir = './background'
os.makedirs(background_dir, exist_ok=True)

# Loop through all files and split each into several one-second wav files
for file in files:
  filename = os.path.basename(os.path.normpath(file))
  print('Splitting', filename)
  name = os.path.splitext(filename)[0]
  rate = librosa.get_samplerate(file)
  length = round(librosa.get_duration(filename=file))
  for i in range(length - 1):
    start = i * rate
    stop = (i * rate) + rate
    # Read only the samples for second i, then write them as a separate clip
    data, _ =, start=start, stop=stop)
    sf.write(os.path.join(background_dir, name + str(i) + '.wav'), data, rate)
Splitting dude_miaowing.wav
Splitting white_noise.wav
Splitting exercise_bike.wav
Splitting pink_noise.wav
Splitting doing_the_dishes.wav
Splitting running_tap.wav
Splitting throat_clearing.wav
Splitting silence.wav
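
The splitting loop above maps each second index i to the sample range [i * rate, i * rate + rate) and reads only that slice. The same slicing can be sketched on an in-memory NumPy array; split_into_clips is a hypothetical helper, not part of librosa or soundfile. One difference worth noting: the loop above iterates over range(length - 1) with a rounded duration, so it can skip the last full second, while this sketch keeps every complete second and drops only a trailing partial clip.

```python
import numpy as np

def split_into_clips(signal, rate, clip_seconds=1.0):
  """Splits a 1-D signal into equal, non-overlapping clips.

  A trailing partial clip is dropped so every clip has the same length.
  """
  clip_len = int(rate * clip_seconds)
  n_clips = len(signal) // clip_len
  return [signal[i * clip_len:(i + 1) * clip_len] for i in range(n_clips)]

# Example: 60.5 seconds of silence at 16 kHz yields 60 one-second clips
clips = split_into_clips(np.zeros(int(16000 * 60.5), dtype=np.float32), 16000)
print(len(clips), len(clips[0]))  # 60 16000
```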

Prepare the speech commands dataset

We already downloaded the speech commands dataset, so now we just need to prune the number of classes for our model.

This dataset includes over 30 speech command classifications, and most of them have over 2,000 samples. But because we're using transfer learning, we don't need that many samples. So the following code does a few things:

  • Specify which classifications we want to use, and delete the rest.
  • Keep only 150 samples of each class for training and validation (to demonstrate that transfer learning works well with smaller datasets, and to reduce the training time).
  • Move another 30 samples of each class (20% of 150) into a separate test directory so we can easily run inference with them later, and delete the remaining samples.
if not use_custom_dataset:
  commands = [ "up", "down", "left", "right", "go", "stop", "on", "off", "background"]
  dataset_dir = './dataset-speech'
  test_dir = './dataset-test'

  # Move the processed background samples
  shutil.move(background_dir, os.path.join(dataset_dir, 'background'))   

  # Delete all directories that are not in our commands list
  dirs = glob.glob(os.path.join(dataset_dir, '*/'))
  for dir in dirs:
    name = os.path.basename(os.path.normpath(dir))
    if name not in commands:
      shutil.rmtree(dir)

  # Count is per class
  sample_count = 150
  test_data_ratio = 0.2
  test_count = round(sample_count * test_data_ratio)

  # Loop through child directories (each class of wav files)
  dirs = glob.glob(os.path.join(dataset_dir, '*/'))
  for dir in dirs:
    files = glob.glob(os.path.join(dir, '*.wav'))
    # Move test samples:
    for file in files[sample_count:sample_count + test_count]:
      class_dir = os.path.basename(os.path.normpath(dir))
      os.makedirs(os.path.join(test_dir, class_dir), exist_ok=True)
      os.rename(file, os.path.join(test_dir, class_dir, os.path.basename(file)))
    # Delete remaining samples
    for file in files[sample_count + test_count:]:
      os.remove(file)
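
Per class, the cell above keeps the first sample_count = 150 files, moves the next round(150 * 0.2) = 30 into the test directory, and deletes everything after that. The hypothetical helper below reproduces that slicing as a pure function so the counts can be checked without touching the filesystem. It sorts the list first, because glob returns files in an arbitrary, filesystem-dependent order, and it adds an optional seeded shuffle, which is an assumption of this sketch rather than something the original cell does.

```python
import random

def partition_class_files(files, sample_count=150, test_data_ratio=0.2, seed=None):
  """Returns (kept, test, deleted) lists for one class, mirroring the slicing above.

  The optional seeded shuffle is an addition, not part of the original cell.
  """
  files = sorted(files)  # glob order is arbitrary; sort for repeatability
  if seed is not None:
    random.Random(seed).shuffle(files)
  test_count = round(sample_count * test_data_ratio)
  kept = files[:sample_count]
  test = files[sample_count:sample_count + test_count]
  deleted = files[sample_count + test_count:]
  return kept, test, deleted

kept, test, deleted = partition_class_files([f'{i:04d}.wav' for i in range(2000)])
print(len(kept), len(test), len(deleted))  # 150 30 1820
```

A class with fewer than 180 files (150 + 30) simply gets a smaller test split, because Python slicing stops at the end of the list.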

Prepare a custom dataset

If you want to train the model with your own speech dataset, upload your samples as WAV files in a ZIP (as described above) and modify the following variables to specify your dataset:

if use_custom_dataset:
  # Specify the ZIP file you uploaded:
  !unzip YOUR-FILENAME.zip
  # Specify the unzipped path to your custom dataset
  # (this path contains all the subfolders with classification names):
  dataset_dir = './YOUR-DIRNAME'

After changing the filename and path name above, you're ready to train the model with your custom dataset. In the Colab toolbar, select Runtime > Run all to run the whole notebook.

The following code integrates our new background noise samples into your dataset and then separates a portion of all samples to create a test set.

def move_background_dataset(dataset_dir):
  dest_dir = os.path.join(dataset_dir, 'background')
  if os.path.exists(dest_dir):
    files = glob.glob(os.path.join(background_dir, '*.wav'))
    for file in files:
      shutil.move(file, dest_dir)
  else:
    shutil.move(background_dir, dest_dir)

if use_custom_dataset:
  # Move background samples into custom dataset
  move_background_dataset(dataset_dir)

  # Now we separate some of the files that we'll use for testing:
  test_dir = './dataset-test'
  test_data_ratio = 0.2
  dirs = glob.glob(os.path.join(dataset_dir, '*/'))
  for dir in dirs:
    files = glob.glob(os.path.join(dir, '*.wav'))
    test_count = round(len(files) * test_data_ratio)
    class_dir = os.path.basename(os.path.normpath(dir))
    # Move test samples:
    for file in files[:test_count]:
      os.makedirs(os.path.join(test_dir, class_dir), exist_ok=True)
      os.rename(file, os.path.join(test_dir, class_dir, os.path.basename(file)))
    print('Moved', test_count, 'samples from', class_dir)

Play a sample

To be sure the dataset looks correct, let's play a random sample from the test set:

def get_random_audio_file(samples_dir):
  files = os.path.abspath(os.path.join(samples_dir, '*/*.wav'))
  files_list = glob.glob(files)
  random_audio_path = random.choice(files_list)
  return random_audio_path

def show_sample(audio_path):
  audio_data, sample_rate =
  class_name = os.path.basename(os.path.dirname(audio_path))
  print(f'Class: {class_name}')
  print(f'File: {audio_path}')
  print(f'Sample rate: {sample_rate}')
  print(f'Sample length: {len(audio_data)}')

  display(Audio(audio_data, rate=sample_rate))
random_audio = get_random_audio_file(test_dir)
show_sample(random_audio)
Class: off
File: /tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset-test/off/b4bef564_nohash_0.wav
Sample rate: 16000
Sample length: 16000


Define the model

When using Model Maker to retrain any model, you have to start by defining a model spec. The spec defines the base model from which your new model will extract feature embeddings to begin learning new classes. The spec for this speech recognizer is based on the pre-trained BrowserFft model from TFJS.

The model expects input as an audio sample at 44.1 kHz that is just under a second long: the input must be exactly 44,034 samples (about 0.9985 seconds).

You don't need to do any resampling with your training dataset. Model Maker takes care of that for you. But when you later run inference, you must be sure that your input matches that expected format.
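
As a quick check of the numbers: 44,034 samples at 44,100 Hz is 44034 / 44100 ≈ 0.9985 seconds. Audio that has already been resampled to 44.1 kHz (for example with librosa.load(path, sr=44100)) still has to be padded or trimmed to exactly that length. The helper below, fit_to_model_input, is a hypothetical sketch of that step; zero-padding at the end is an assumption, not something the model spec prescribes.

```python
import numpy as np

EXPECTED_RATE = 44100    # Hz, as described for the BrowserFft spec
EXPECTED_LENGTH = 44034  # samples per input clip

def fit_to_model_input(audio, expected_length=EXPECTED_LENGTH):
  """Zero-pads or truncates a 1-D signal to expected_length, then adds a batch axis."""
  audio = np.asarray(audio, dtype=np.float32)
  if len(audio) < expected_length:
    # Append zeros (silence) at the end of short clips
    audio = np.pad(audio, (0, expected_length - len(audio)))
  return np.expand_dims(audio[:expected_length], axis=0)

print(f'{EXPECTED_LENGTH / EXPECTED_RATE:.4f} s')  # 0.9985 s
print(fit_to_model_input(np.ones(16000)).shape)    # (1, 44034)
```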

All you need to do here is instantiate the BrowserFftSpec:

spec = audio_classifier.BrowserFftSpec()
INFO:tensorflow:Checkpoints are stored in /tmpfs/tmp/tmp0wkfe575
Downloading data from
24576/18467 [=======================================] - 0s 0us/step
32768/18467 [=====================================================] - 0s 0us/step
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with or tf.keras.models.save_model(), *NOT* To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Downloading data from
16384/5466 [=========================================================================================] - 0s 0us/step
Downloading data from
4194304/4194304 [==============================] - 0s 0us/step
4202496/4194304 [==============================] - 0s 0us/step
Downloading data from
1687552/1680432 [==============================] - 0s 0us/step
1695744/1680432 [==============================] - 0s 0us/step

Load your dataset

Now you need to load your dataset according to the model specifications. Model Maker includes the DataLoader API, which will load your dataset from a folder and ensure it's in the expected format for the model spec.

We already reserved some test files by moving them to a separate directory, which makes it easier to run inference with them later. Now we'll create a DataLoader for each split: the training set, the validation set, and the test set.

Load the speech commands dataset

if not use_custom_dataset:
  train_data_ratio = 0.8
  train_data = audio_classifier.DataLoader.from_folder(
      spec, dataset_dir, cache=True)
  train_data, validation_data = train_data.split(train_data_ratio)
  test_data = audio_classifier.DataLoader.from_folder(
      spec, test_dir, cache=True)

Load a custom dataset

if use_custom_dataset:
  train_data_ratio = 0.8
  train_data = audio_classifier.DataLoader.from_folder(
      spec, dataset_dir, cache=True)
  train_data, validation_data = train_data.split(train_data_ratio)
  test_data = audio_classifier.DataLoader.from_folder(
      spec, test_dir, cache=True)

Train the model

Now we'll use the Model Maker create() function to create a model based on our model spec and training dataset, and begin training.

If you're using a custom dataset, you might want to change the batch size as appropriate for the number of samples in your train set.

# If your dataset has fewer than 100 samples per class,
# you might want to try a smaller batch size
batch_size = 25
epochs = 25
model = audio_classifier.create(
    train_data, spec, validation_data, batch_size=batch_size, epochs=epochs)
Model: "sequential_1"
 Layer (type)                Output Shape              Param #   
 conv2d_1 (Conv2D)           (None, 42, 225, 8)        136       
 max_pooling2d_1 (MaxPooling  (None, 21, 112, 8)       0         
 conv2d_2 (Conv2D)           (None, 20, 109, 32)       2080      
 max_pooling2d_2 (MaxPooling  (None, 10, 54, 32)       0         
 conv2d_3 (Conv2D)           (None, 9, 51, 32)         8224      
 max_pooling2d_3 (MaxPooling  (None, 4, 25, 32)        0         
 conv2d_4 (Conv2D)           (None, 3, 22, 32)         8224      
 max_pooling2d_4 (MaxPooling  (None, 2, 11, 32)        0         
 flatten_1 (Flatten)         (None, 704)               0         
 dropout_1 (Dropout)         (None, 704)               0         
 dense_1 (Dense)             (None, 2000)              1410000   
 dropout_2 (Dropout)         (None, 2000)              0         
 classification_head (Dense)  (None, 9)                18009     
Total params: 1,446,673
Trainable params: 18,009
Non-trainable params: 1,428,664
Epoch 1/25
40/40 [==============================] - 13s 251ms/step - loss: 2.0813 - acc: 0.4690 - val_loss: 0.4588 - val_acc: 0.8542
Epoch 2/25
40/40 [==============================] - 0s 11ms/step - loss: 0.7247 - acc: 0.7731 - val_loss: 0.3289 - val_acc: 0.8958
Epoch 3/25
40/40 [==============================] - 0s 10ms/step - loss: 0.5002 - acc: 0.8352 - val_loss: 0.2854 - val_acc: 0.9083
Epoch 4/25
40/40 [==============================] - 0s 11ms/step - loss: 0.3794 - acc: 0.8881 - val_loss: 0.2456 - val_acc: 0.9250
Epoch 5/25
40/40 [==============================] - 0s 10ms/step - loss: 0.3247 - acc: 0.9023 - val_loss: 0.2621 - val_acc: 0.9208
Epoch 6/25
40/40 [==============================] - 0s 10ms/step - loss: 0.3166 - acc: 0.9013 - val_loss: 0.2425 - val_acc: 0.9208
Epoch 7/25
40/40 [==============================] - 0s 10ms/step - loss: 0.2929 - acc: 0.9044 - val_loss: 0.2398 - val_acc: 0.9208
Epoch 8/25
40/40 [==============================] - 0s 10ms/step - loss: 0.2074 - acc: 0.9329 - val_loss: 0.2221 - val_acc: 0.9292
Epoch 9/25
40/40 [==============================] - 0s 10ms/step - loss: 0.2236 - acc: 0.9318 - val_loss: 0.2228 - val_acc: 0.9417
Epoch 10/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1996 - acc: 0.9339 - val_loss: 0.2083 - val_acc: 0.9333
Epoch 11/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1699 - acc: 0.9420 - val_loss: 0.2212 - val_acc: 0.9417
Epoch 12/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1520 - acc: 0.9410 - val_loss: 0.2145 - val_acc: 0.9333
Epoch 13/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1410 - acc: 0.9583 - val_loss: 0.2163 - val_acc: 0.9333
Epoch 14/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1507 - acc: 0.9451 - val_loss: 0.2174 - val_acc: 0.9292
Epoch 15/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1243 - acc: 0.9563 - val_loss: 0.2225 - val_acc: 0.9375
Epoch 16/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1181 - acc: 0.9593 - val_loss: 0.2064 - val_acc: 0.9375
Epoch 17/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1074 - acc: 0.9654 - val_loss: 0.2097 - val_acc: 0.9333
Epoch 18/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1087 - acc: 0.9634 - val_loss: 0.2067 - val_acc: 0.9458
Epoch 19/25
40/40 [==============================] - 0s 11ms/step - loss: 0.0966 - acc: 0.9654 - val_loss: 0.2079 - val_acc: 0.9375
Epoch 20/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1000 - acc: 0.9644 - val_loss: 0.1954 - val_acc: 0.9458
Epoch 21/25
40/40 [==============================] - 0s 10ms/step - loss: 0.1034 - acc: 0.9624 - val_loss: 0.2080 - val_acc: 0.9375
Epoch 22/25
40/40 [==============================] - 0s 10ms/step - loss: 0.0809 - acc: 0.9705 - val_loss: 0.2114 - val_acc: 0.9375
Epoch 23/25
40/40 [==============================] - 0s 10ms/step - loss: 0.0721 - acc: 0.9715 - val_loss: 0.2159 - val_acc: 0.9333
Epoch 24/25
40/40 [==============================] - 0s 11ms/step - loss: 0.0804 - acc: 0.9756 - val_loss: 0.2218 - val_acc: 0.9375
Epoch 25/25
40/40 [==============================] - 0s 10ms/step - loss: 0.0892 - acc: 0.9746 - val_loss: 0.2217 - val_acc: 0.9417
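
The parameter counts in the model summary follow directly from the layer shapes, and they show the transfer-learning setup: only classification_head (18,009 parameters) is trainable, while the 1,428,664 parameters of the base layers are not. The arithmetic below reproduces those numbers. The kernel sizes are not printed in the summary; they are inferred from the output shapes under the assumption of a 43x232x1 spectrogram input and 'valid' padding.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
  # One weight per kernel cell, per input channel, per filter, plus one bias per filter
  return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
  # A full weight matrix plus one bias per unit
  return in_features * units + units

# Kernel sizes inferred from the printed output shapes, assuming a 43x232x1
# input and 'valid' padding (e.g. 43x232 -> 42x225 implies a 2x8 kernel)
frozen = [
    conv2d_params(2, 8, 1, 8),    # conv2d_1
    conv2d_params(2, 4, 8, 32),   # conv2d_2
    conv2d_params(2, 4, 32, 32),  # conv2d_3
    conv2d_params(2, 4, 32, 32),  # conv2d_4
    dense_params(704, 2000),      # dense_1 (704 = 2 * 11 * 32 after flatten)
]
head = dense_params(2000, 9)      # classification_head: 9 classes
print(sum(frozen), head, sum(frozen) + head)  # 1428664 18009 1446673
```

Because only the head depends on the number of classes, retraining with N classes gives 2000 * N + N trainable parameters; for example, 4 classes would give 8,004.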

Review the model performance

Even if the accuracy/loss looks good from the training output above, it's important to also run the model using test data that the model has not seen yet, which is what the evaluate() method does here:

model.evaluate(test_data)

8/8 [==============================] - 2s 223ms/step - loss: 0.3017 - acc: 0.9184
[0.301722913980484, 0.918367326259613]

View the confusion matrix

When training a classification model such as this one, it's also useful to inspect the confusion matrix. The confusion matrix gives you a detailed visual representation of how well your classifier performs for each class in your test data.

def show_confusion_matrix(confusion, test_labels):
  """Normalize each row of the confusion matrix and plot it as a heatmap."""
  # Divide each row by its total so row i shows how samples of class i were predicted
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1, keepdims=True)
  sns.set(rc = {'figure.figsize':(6,6)})
  sns.heatmap(
      confusion_normalized, xticklabels=test_labels, yticklabels=test_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True, cbar=False)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
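
Row normalization depends on how the row totals broadcast against the matrix. A small NumPy example with made-up counts (rows are true labels, columns are predicted labels) shows the difference between dividing by confusion.sum(axis=1, keepdims=True) and by plain confusion.sum(axis=1):

```python
import numpy as np

# Made-up counts: rows are true labels, columns are predicted labels
confusion = np.array([[8, 2, 0],
                      [1, 4, 0],
                      [0, 0, 5]])

# keepdims=True gives row totals of shape (3, 1), so each row is divided by its own total
row_normalized = confusion / confusion.sum(axis=1, keepdims=True)
print(row_normalized)  # rows: [0.8, 0.2, 0.0], [0.2, 0.8, 0.0], [0.0, 0.0, 1.0]

# Without keepdims the totals have shape (3,) and broadcast across columns,
# so column j is divided by row j's total and rows no longer sum to 1
column_broadcast = confusion / confusion.sum(axis=1)
print(column_broadcast.sum(axis=1))  # approximately [1.2, 0.9, 1.0]
```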


Export the model

The last step is exporting your model into the TensorFlow Lite format for execution on mobile/embedded devices and into the SavedModel format for execution elsewhere.

When exporting a .tflite file from Model Maker, it includes model metadata that describes various details that can later help during inference. It even includes a copy of the classification labels file, so you don't need a separate labels.txt file. (In the next section, we show how to use this metadata to run inference.)

TFLITE_FILENAME = 'browserfft-speech.tflite'
SAVE_PATH = './models'
print(f'Exporting the model to {SAVE_PATH}')
model.export(SAVE_PATH, tflite_filename=TFLITE_FILENAME)
model.export(SAVE_PATH, export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL])
Exporting the model to ./models
2022-05-11 00:27:22.800934: W tensorflow/python/util/] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp4pga14bh/assets
INFO:tensorflow:TensorFlow Lite model exported successfully: ./models/browserfft-speech.tflite
2022-05-11 00:27:25.077395: W tensorflow/compiler/mlir/lite/python/] Ignored output_format.
2022-05-11 00:27:25.077446: W tensorflow/compiler/mlir/lite/python/] Ignored drop_control_dependency.
INFO:tensorflow:TensorFlow Lite model exported successfully: ./models/browserfft-speech.tflite
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./models/saved_model/assets
INFO:tensorflow:Saving labels in ./models/labels.txt

Run inference with TF Lite model

Now your TFLite model can be deployed and run using any of the supported inference libraries or with the new TFLite AudioClassifier Task API. The following code shows how you can run inference with the .tflite model in Python.

# This library provides the TFLite metadata API
!pip install -q tflite_support
from tflite_support import metadata
import json

def get_labels(model):
  """Returns a list of labels, extracted from the model metadata."""
  displayer = metadata.MetadataDisplayer.with_model_file(model)
  labels_file = displayer.get_packed_associated_file_list()[0]
  labels = displayer.get_associated_file_buffer(labels_file).decode()
  return [line for line in labels.split('\n')]

def get_input_sample_rate(model):
  """Returns the model's expected sample rate, from the model metadata."""
  displayer = metadata.MetadataDisplayer.with_model_file(model)
  metadata_json = json.loads(displayer.get_metadata_json())
  input_tensor_metadata = metadata_json['subgraph_metadata'][0][
      'input_tensor_metadata'][0]
  input_content_props = input_tensor_metadata['content']['content_properties']
  return input_content_props['sample_rate']
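
get_input_sample_rate walks a fixed path through the metadata JSON: first subgraph, first input tensor, then content, content_properties, and sample_rate. The sketch below runs the same lookup on a hand-written, heavily trimmed JSON string. That string is hypothetical and only mimics the shape of real metadata; the actual output of get_metadata_json() contains many more fields.

```python
import json

# Hypothetical, trimmed-down metadata JSON for illustration only
example_metadata_json = json.dumps({
    "name": "example_audio_model",
    "subgraph_metadata": [{
        "input_tensor_metadata": [{
            "content": {
                "content_properties_type": "AudioProperties",
                "content_properties": {"sample_rate": 44100, "channels": 1},
            }
        }]
    }]
})

def sample_rate_from_metadata_json(metadata_json_str):
  """Same lookup path as get_input_sample_rate, applied to a JSON string."""
  metadata = json.loads(metadata_json_str)
  # First subgraph -> first input tensor -> content -> content_properties
  input_tensor_metadata = metadata['subgraph_metadata'][0]['input_tensor_metadata'][0]
  return input_tensor_metadata['content']['content_properties']['sample_rate']

print(sample_rate_from_metadata_json(example_metadata_json))  # 44100
```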

To observe how well the model performs with real samples, run the following code block over and over. Each time, it will fetch a new test sample and run inference with it, and you can listen to the audio sample below.

# Get a WAV file for inference and list of labels from the model
tflite_file = os.path.join(SAVE_PATH, TFLITE_FILENAME)
labels = get_labels(tflite_file)
random_audio = get_random_audio_file(test_dir)

# Ensure the audio sample fits the model input
interpreter = tf.lite.Interpreter(tflite_file)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_size = input_details[0]['shape'][1]
sample_rate = get_input_sample_rate(tflite_file)
audio_data, _ = librosa.load(random_audio, sr=sample_rate)
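if len(audio_data) < input_size:
  # Zero-pad short clips up to the model's input length
  audio_data = np.pad(audio_data, (0, input_size - len(audio_data)))
audio_data = np.expand_dims(audio_data[:input_size], axis=0)

# Run inference
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], audio_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])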

# Display prediction and ground truth
top_index = np.argmax(output_data[0])
label = labels[top_index]
score = output_data[0][top_index]
print(f'Class: {label}\nScore: {score}')
show_sample(random_audio)
Class: stop
Score: 0.7824302911758423
Class: stop
File: /tmpfs/src/temp/tensorflow/lite/g3doc/tutorials/dataset-test/stop/c08e5058_nohash_0.wav
Sample rate: 16000
Sample length: 16000
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
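
The cell above reports only the single highest-scoring class. When the top score is not close to 1 (like the 0.78 printed above), it can help to look at the runners-up as well. top_k_predictions is a hypothetical helper that sorts the output vector with NumPy; the scores and labels in the example are made up.

```python
import numpy as np

def top_k_predictions(scores, labels, k=3):
  """Returns up to k (label, score) pairs, highest score first."""
  scores = np.asarray(scores).ravel()  # accepts shape (num_classes,) or (1, num_classes)
  top = np.argsort(scores)[::-1][:k]
  return [(labels[i], float(scores[i])) for i in top]

# Made-up scores and labels, for illustration only
example_scores = np.array([[0.05, 0.15, 0.78, 0.02]])
print(top_k_predictions(example_scores, ['down', 'go', 'stop', 'up'], k=2))
# [('stop', 0.78), ('go', 0.15)]
```

In the notebook, it could be called as top_k_predictions(output_data, labels) after the interpreter has run.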


Download the TF Lite model

Now you can deploy the TF Lite model to your mobile or embedded device. You don't need to download the labels file because you can instead retrieve the labels from the .tflite file metadata, as shown in the previous inference example.

try:
  from google.colab import files
except ImportError:
  pass
else:
  files.download(tflite_file)

Check out our end-to-end example apps that run inference with TFLite audio models on Android and iOS.