tflite_model_maker.audio_classifier.AudioClassifier

Audio classifier for training, inference, and exporting.

model_spec Specification for the model.
index_to_label A list that maps each index to its label class name.
shuffle Whether the data should be shuffled.
train_whole_model If true, the Hub module is trained together with the classification layer on top. Otherwise, only train the top classification layer.

Methods

confusion_matrix

create

Loads data and retrains the model.

Args
train_data An instance of the audio_dataloader.DataLoader class.
model_spec Specification for the model.
validation_data Validation DataLoader. If None, skips validation process.
batch_size Number of samples per training step.
epochs Number of epochs for training.
model_dir The location of the model checkpoint files.
do_train Whether to run training.
train_whole_model Boolean. By default, only the classification head is trained. When True, the base model is also trained.

Returns
An instance based on AudioClassifier.
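
A typical call to create might look like the sketch below. It assumes the tflite_model_maker package is installed, that YamNetSpec is the chosen model spec, and that data_dir is a hypothetical folder containing one subfolder of audio files per label. The 80/20 split and the hyperparameter values are illustrative, not recommendations.

```python
def train_audio_classifier(data_dir, epochs=10, batch_size=32):
    """Train an audio classifier from a folder of labeled audio clips.

    Assumes `data_dir` (a hypothetical path) holds one subfolder per label.
    """
    # Imported lazily so that defining this function does not require
    # the package to be installed.
    from tflite_model_maker import audio_classifier

    spec = audio_classifier.YamNetSpec()
    data = audio_classifier.DataLoader.from_folder(spec, data_dir)
    # Hold out 20% of the data for validation (an arbitrary choice).
    train_data, validation_data = data.split(0.8)

    return audio_classifier.AudioClassifier.create(
        train_data,
        spec,
        validation_data=validation_data,
        batch_size=batch_size,
        epochs=epochs,
        train_whole_model=False,  # train only the classification head
    )
```

Passing validation_data=None would skip validation, as described in the argument list above.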

create_model

create_serving_model

Returns the underlying Keras model for serving.

evaluate

Evaluates the model.

Args
data Data to be evaluated.
batch_size Number of samples per evaluation step.

Returns
The loss value and accuracy.
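
As a small sketch of consuming the result, the helper below unpacks the two returned values into a dictionary. It assumes the return value is a two-element sequence ordered as loss then accuracy, which matches the description above but is not spelled out in detail there. The helper name report_metrics is hypothetical.

```python
def report_metrics(model, test_data, batch_size=32):
    """Evaluate `model` on `test_data` and return the metrics by name.

    Assumes model.evaluate returns (loss, accuracy) in that order.
    """
    loss, accuracy = model.evaluate(test_data, batch_size=batch_size)
    return {"loss": float(loss), "accuracy": float(accuracy)}
```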

evaluate_tflite

Evaluates the TFLite model.

Args
tflite_filepath File path to the TFLite model.
data Data to be evaluated.
postprocess_fn Postprocessing function that will be applied to the output of lite_runner.run before calculating the probabilities.

Returns
The accuracy of the TFLite model on the given data.
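
The role of postprocess_fn can be illustrated with a pure-Python sketch. This is not the library's implementation: it assumes accuracy is the fraction of examples whose highest-scoring class matches the label, and it uses a hypothetical postprocessing function, mean_over_frames, that averages per-frame scores into one score per class.

```python
def mean_over_frames(frames):
    """Hypothetical postprocess_fn: average per-class scores across frames."""
    n = len(frames)
    return [sum(column) / n for column in zip(*frames)]


def accuracy_with_postprocess(raw_outputs, labels, postprocess_fn=None):
    """Fraction of examples whose top-scoring class equals the label."""
    correct = 0
    for output, label in zip(raw_outputs, labels):
        # Apply the postprocessing step to the raw model output, if given.
        scores = postprocess_fn(output) if postprocess_fn else output
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)


# Three clips, two frames each, two classes (made-up numbers).
raw_outputs = [
    [[0.9, 0.1], [0.6, 0.4]],
    [[0.2, 0.8], [0.4, 0.6]],
    [[0.7, 0.3], [0.1, 0.9]],
]
labels = [0, 1, 0]
accuracy = accuracy_with_postprocess(raw_outputs, labels, mean_over_frames)
```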

export

Converts the retrained model based on export_format.

Args
export_dir The directory to save exported files.
tflite_filename File name to save tflite model. The full export path is {export_dir}/{tflite_filename}.
label_filename File name to save labels. The full export path is {export_dir}/{label_filename}.
vocab_filename File name to save vocabulary. The full export path is {export_dir}/{vocab_filename}.
saved_model_filename Path to SavedModel or H5 file to save the model. The full export path is {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name Folder name to save tfjs model. The full export path is {export_dir}/{tfjs_folder_name}.
export_format List of export formats, which may include saved_model, tflite, label, and vocab.
**kwargs Other parameters, such as quantized_config for the TFLite model.
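
The path conventions above can be sketched as follows. The helper planned_export_paths is hypothetical and only mirrors the {export_dir}/{filename} pattern; the default file names in it are assumptions, not values confirmed by this page. The second function shows one plausible export call, assuming ExportFormat is importable from tflite_model_maker.config.

```python
import posixpath


def planned_export_paths(export_dir,
                         tflite_filename="model.tflite",
                         label_filename="labels.txt",
                         saved_model_filename="saved_model"):
    """Return where each artifact would land under `export_dir`.

    The default file names are illustrative assumptions.
    """
    return {
        "tflite": posixpath.join(export_dir, tflite_filename),
        "label": posixpath.join(export_dir, label_filename),
        "saved_model": posixpath.join(export_dir, saved_model_filename),
    }


def export_tflite_and_labels(model, export_dir):
    """Export a trained model as a TFLite file plus a label file."""
    # Imported lazily so this sketch loads without the package installed.
    from tflite_model_maker.config import ExportFormat

    model.export(
        export_dir=export_dir,
        export_format=[ExportFormat.TFLITE, ExportFormat.LABEL],
    )


paths = planned_export_paths("out")
```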

predict_top_k

Predicts the top-k predictions.

Args
data Data to be evaluated. Either an instance of DataLoader or raw data entries such as a TF tensor or numpy array.
k Number of top results to be predicted.
batch_size Number of samples per evaluation step.

Returns
The top k results. Each one is a (label, probability) pair.
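
To make the shape of the result concrete, the sketch below builds (label, probability) pairs for a single example from a list of class probabilities and an index_to_label list. It illustrates top-k selection in general and is not the library's code; the function name top_k_labels and the sample labels are hypothetical.

```python
def top_k_labels(probabilities, index_to_label, k=1):
    """Return the k most probable (label, probability) pairs, best first."""
    ranked = sorted(
        range(len(probabilities)),
        key=lambda i: probabilities[i],
        reverse=True,
    )
    return [(index_to_label[i], probabilities[i]) for i in ranked[:k]]


results = top_k_labels([0.1, 0.7, 0.2], ["dog", "cat", "bird"], k=2)
```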

summary

train

Class Variables

ALLOWED_EXPORT_FORMAT (<ExportFormat.LABEL: 'LABEL'>, <ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>)
DEFAULT_EXPORT_FORMAT <ExportFormat.TFLITE: 'TFLITE'>