tflite_model_maker.question_answer.BertQaSpec

A specification of a BERT model for question answering.

Args
uri TF-Hub path or URL of the BERT module.
model_dir The location of the model checkpoint files.
seq_len Length of the sequence to feed into the model.
query_len Length of the query to feed into the model.
doc_stride The stride of the sliding window used to split long documents into chunks.
dropout_rate The rate for dropout.
initializer_range The standard deviation of the truncated_normal_initializer used to initialize all weight matrices.
learning_rate The initial learning rate for Adam.
distribution_strategy A string specifying which distribution strategy to use. Accepted values are 'off', 'one_device', 'mirrored', 'parameter_server', 'multi_worker_mirrored', and 'tpu' (case insensitive). 'off' means not to use a distribution strategy; 'tpu' means to use TPUStrategy with the TPU address given by tpu.
num_gpus Number of GPUs to use on each worker with the DistributionStrategies API. The default is -1, which means all available GPUs are used.
tpu TPU address to connect to.
trainable boolean, whether the pretrained layer is trainable.
predict_batch_size Batch size for prediction.
do_lower_case boolean, whether to lower case the input text. Should be True for uncased models and False for cased models.
is_tf2 boolean, whether the hub module is in TensorFlow 2.x format.
tflite_input_name Dict, input names for the TFLite model.
tflite_output_name Dict, output names for the TFLite model.
init_from_squad_model boolean, whether to initialize from a model that has already been retrained on SQuAD 1.1.
default_batch_size Default batch size for training.
name Name of the object.
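
How seq_len, query_len, and doc_stride interact can be sketched as follows. This is a minimal sketch modeled on the sliding-window logic of the original BERT SQuAD preprocessing, not this class's own code; the helper name doc_spans and the reservation of three positions for the [CLS] token and two [SEP] tokens are assumptions of the sketch.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, num_query_tokens: int,
              seq_len: int, query_len: int, doc_stride: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: returns (start, length) windows over the document tokens."""
    # The query is truncated to at most query_len tokens.
    num_query_tokens = min(num_query_tokens, query_len)
    # Assumption: 3 positions are reserved for [CLS], the [SEP] after the query,
    # and the final [SEP]; the rest of the sequence holds document tokens.
    max_tokens_for_doc = seq_len - num_query_tokens - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("seq_len is too small to hold the query and any document tokens")
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # this window already reaches the end of the document
        # Advance by doc_stride, but never jump past the end of the current window.
        start += min(length, doc_stride)
    return spans


print(doc_spans(num_doc_tokens=10, num_query_tokens=2, seq_len=8, query_len=4, doc_stride=3))
# [(0, 3), (3, 3), (6, 3), (9, 1)]
```

When doc_stride is smaller than the window size, consecutive windows overlap by (window size - doc_stride) tokens, which gives an answer near a window boundary a chance to appear whole in a neighbouring window.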

Methods

build

Builds the class. Used for lazy initialization.

convert_examples_to_features

Converts examples to features and writes them into a TFRecord file.

create_model

Creates the model for the QA task.

evaluate

Evaluates the QA model.

Args
model The Keras model to be evaluated.
tflite_filepath File path to the TFLite model.
dataset tf.data.Dataset used for evaluation.
num_steps Number of steps to evaluate the model.
eval_examples List of squad_lib.SquadExample for evaluation data.
eval_features List of squad_lib.InputFeatures for evaluation data.
predict_file The input prediction file.
version_2_with_negative Whether the input prediction file is in SQuAD 2.0 format.
max_answer_length The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.
null_score_diff_threshold If null_score - best_non_null is greater than the threshold, predict null. This is only used for SQuAD v2.
verbose_logging If true, all of the warnings related to data processing will be printed. A number of warnings are expected for a normal SQuAD evaluation.
output_dir The directory in which to save the output JSON files: predictions.json, nbest_predictions.json, and null_odds.json. If None, the JSON files are not saved.
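
The roles of max_answer_length and null_score_diff_threshold can be illustrated with a simplified span picker. This is a minimal sketch, not the library's implementation: it assumes index 0 holds the [CLS] token whose logits give the "no answer" score, and it skips steps a real pipeline performs, such as keeping only the n-best start/end candidates, excluding query tokens, and mapping token spans back to the original text. The helper name best_span is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def best_span(start_logits: Sequence[float], end_logits: Sequence[float],
              max_answer_length: int,
              null_score_diff_threshold: Optional[float] = None) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: returns the best (start, end) token span, or None for "no answer"."""
    n = len(start_logits)
    best, best_score = None, float("-inf")
    for s in range(1, n):  # index 0 is assumed to be [CLS]
        # Start and end are predicted independently, so the constraints are enforced here:
        # end >= start, and the span holds at most max_answer_length tokens.
        for e in range(s, min(n, s + max_answer_length)):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    if null_score_diff_threshold is not None:  # SQuAD v2 only
        null_score = start_logits[0] + end_logits[0]
        # Predict null if null_score - best_non_null exceeds the threshold.
        if best is None or null_score - best_score > null_score_diff_threshold:
            return None
    return best


print(best_span([0.0, 5.0, 1.0, 0.0], [0.0, 0.0, 1.0, 4.0], max_answer_length=2))  # (1, 2)
```

Raising null_score_diff_threshold makes the model less willing to answer "no answer"; it is typically tuned on a development set.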

Returns
A dict containing two metrics: the exact match rate and the F1 score.
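
The two metrics can be sketched in the way the official SQuAD v1.1 evaluation script defines them: answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), exact match compares normalized strings, and F1 compares bags of tokens. Each prediction takes its best score over all ground-truth answers, and scores are averaged and reported as percentages. That evaluate computes its numbers exactly this way, and the dict keys used below, are assumptions of this sketch.

```python
import re
import string
from collections import Counter
from typing import List

_PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lowercase, strip punctuation and the articles a/an/the, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)  # multiset intersection
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def score(predictions: List[str], answers: List[List[str]]) -> dict:
    """Averages each prediction's best score over its ground-truth answers, as a percentage."""
    em = sum(max(exact_match(p, g) for g in golds) for p, golds in zip(predictions, answers))
    f = sum(max(f1(p, g) for g in golds) for p, golds in zip(predictions, answers))
    n = len(predictions)
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f / n}


print(score(["Paris", "London"], [["paris", "Paris, France"], ["Berlin"]]))
# {'exact_match': 50.0, 'f1': 50.0}
```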

get_config

Gets the configuration.

get_default_quantization_config

Gets the default quantization configuration.

get_name_to_features

Gets the dictionary describing the features.

predict

Runs prediction on the dataset with the given model.

predict_tflite

Runs prediction on the dataset with the TFLite model at tflite_filepath.

reorder_input_details

Reorders the TFLite input details to match the input order of the Keras model.

reorder_output_details

Reorders the TFLite output details to match the output order of the Keras model.
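
The idea behind both reorder methods can be sketched with plain dicts: look each TFLite tensor up by name and emit the details in the order the Keras model expects. This is a minimal sketch, not the library's code. It assumes the details are dicts with a "name" key, like the entries returned by tf.lite.Interpreter.get_input_details() and get_output_details(), and that a name map such as tflite_input_name links Keras-side keys to TFLite tensor names; the helper reorder_details and all tensor names below are hypothetical.

```python
from typing import Dict, List, Sequence


def reorder_details(details: List[dict], name_map: Dict[str, str],
                    keras_order: Sequence[str]) -> List[dict]:
    """Hypothetical helper: sorts TFLite tensor details into the Keras model's order.

    details: dicts with at least a "name" key.
    name_map: Keras-side key -> TFLite tensor name (assumed structure).
    keras_order: Keras-side keys in the order the Keras model uses.
    """
    by_name = {d["name"]: d for d in details}
    missing = [name_map[k] for k in keras_order if name_map[k] not in by_name]
    if missing:
        raise KeyError(f"TFLite model has no tensors named {missing}")
    return [by_name[name_map[k]] for k in keras_order]


# Hypothetical tensor names, listed in an order that differs from the Keras model's.
details = [
    {"name": "serving_default_input_mask:0", "index": 0},
    {"name": "serving_default_input_word_ids:0", "index": 1},
    {"name": "serving_default_input_type_ids:0", "index": 2},
]
name_map = {
    "ids": "serving_default_input_word_ids:0",
    "mask": "serving_default_input_mask:0",
    "segment_ids": "serving_default_input_type_ids:0",
}
ordered = reorder_details(details, name_map, ["ids", "mask", "segment_ids"])
print([d["index"] for d in ordered])  # [1, 0, 2]
```

Matching by name rather than by position matters because a converted TFLite model does not necessarily list its tensors in the same order as the Keras model.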

save_vocab

Prints the file path to the vocabulary.

select_data_from_record

Splits records into features and labels.
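
A split of this kind can be sketched as below. This is a minimal sketch assuming the record keys and renames used by the TF Model Garden SQuAD input pipeline (labels start_positions and end_positions; input_ids renamed to input_word_ids and segment_ids to input_type_ids); whether BertQaSpec uses exactly these keys is not stated on this page.

```python
from typing import Any, Dict, Tuple

# Assumption: key names follow the TF Model Garden SQuAD input pipeline.
LABEL_KEYS = ("start_positions", "end_positions")
RENAMES = {"input_ids": "input_word_ids", "segment_ids": "input_type_ids"}


def select_data_from_record(record: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Splits one parsed record into a (features, labels) pair of dicts."""
    x, y = {}, {}
    for name, value in record.items():
        if name in LABEL_KEYS:
            y[name] = value
        else:
            # Rename feature keys to the names the Keras model's inputs use.
            x[RENAMES.get(name, name)] = value
    return x, y


record = {"input_ids": [101, 7592], "input_mask": [1, 1], "segment_ids": [0, 0],
          "start_positions": 3, "end_positions": 5}
print(select_data_from_record(record))
```

Because the function only reorganizes dict entries, the same logic would also work when passed to tf.data.Dataset.map on records already parsed into dicts of tensors.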

train

Runs BERT QA training.

Args
train_ds tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
epochs Integer, training epochs.
steps_per_epoch Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If steps_per_epoch is None, the epoch will run until the input dataset is exhausted.
**kwargs Other parameters used in the tf.keras.Model.fit().

Returns
tf.keras.Model, the keras model that's already trained.
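
The steps_per_epoch semantics can be made concrete with a little arithmetic on a finite, batched dataset. This is a minimal sketch of how tf.keras.Model.fit() counts batches, not code from this class; the helper batches_per_epoch and its drop_remainder flag (mirroring the argument of tf.data.Dataset.batch) are assumptions for illustration.

```python
import math
from typing import Optional


def batches_per_epoch(num_examples: int, batch_size: int,
                      steps_per_epoch: Optional[int] = None,
                      drop_remainder: bool = False) -> int:
    """Hypothetical helper: how many batches one epoch of fit() processes."""
    if steps_per_epoch is not None:
        # An explicit value ends the epoch after exactly this many batches
        # (assuming the dataset yields at least that many).
        return steps_per_epoch
    # With steps_per_epoch=None the epoch runs until the dataset is exhausted;
    # a final partial batch counts unless batching used drop_remainder=True.
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


epochs = 2
print(epochs * batches_per_epoch(1000, 32))  # 64 training steps in total
```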

Class Variables
compat_tf_versions [2]
convert_from_saved_model_tf2 True
need_gen_vocab False