A specification of the BERT model for text classification.
tflite_model_maker.text_classifier.BertClassifierSpec(
uri='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
model_dir=None,
seq_len=128,
dropout_rate=0.1,
initializer_range=0.02,
learning_rate=3e-05,
distribution_strategy='mirrored',
num_gpus=-1,
tpu='',
trainable=True,
do_lower_case=True,
is_tf2=True,
name='Bert',
tflite_input_name=None,
default_batch_size=32,
index_to_label=None
)
Args | |
---|---|
`uri` | TF-Hub path or URL of the BERT module. |
`model_dir` | The location of the model checkpoint files. |
`seq_len` | Length of the sequence fed into the model. |
`dropout_rate` | The dropout rate. |
`initializer_range` | The standard deviation of the `truncated_normal_initializer` used to initialize all weight matrices. |
`learning_rate` | The initial learning rate for Adam. |
`distribution_strategy` | A string specifying which distribution strategy to use. Accepted values are `'off'`, `'one_device'`, `'mirrored'`, `'parameter_server'`, `'multi_worker_mirrored'`, and `'tpu'` (case-insensitive). `'off'` means no Distribution Strategy is used; `'tpu'` means `TPUStrategy` is used with `tpu_address`. |
`num_gpus` | How many GPUs to use on each worker with the DistributionStrategies API. The default, -1, uses all available GPUs. |
`tpu` | TPU address to connect to. |
`trainable` | Boolean; whether the pretrained layer is trainable. |
`do_lower_case` | Boolean; whether to lowercase the input text. Should be `True` for uncased models and `False` for cased models. |
`is_tf2` | Boolean; whether the hub module is in TensorFlow 2.x format. |
`name` | The name of the object. |
`tflite_input_name` | Dict; input names for the TFLite model. |
`default_batch_size` | Default batch size for training. |
`index_to_label` | List of labels in the training data, e.g. `['neg', 'pos']`. |
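The `distribution_strategy` and `num_gpus` arguments are plain configuration values. The following minimal sketch, using only the standard library, shows how such values could be validated and resolved as the table describes: strategy names are matched case-insensitively against the accepted list, and `num_gpus=-1` resolves to every available GPU. The helpers `normalize_strategy` and `resolve_num_gpus` and the `available_gpus` argument are hypothetical; this is not Model Maker's own implementation.

```python
# Accepted names, as listed for `distribution_strategy`.
ACCEPTED_STRATEGIES = (
    "off",
    "one_device",
    "mirrored",
    "parameter_server",
    "multi_worker_mirrored",
    "tpu",
)


def normalize_strategy(name: str) -> str:
    """Hypothetical helper: validate a strategy name case-insensitively."""
    normalized = name.lower()
    if normalized not in ACCEPTED_STRATEGIES:
        raise ValueError(f"Unknown distribution strategy: {name!r}")
    return normalized


def resolve_num_gpus(num_gpus: int, available_gpus: int) -> int:
    """Hypothetical helper: -1 means 'use every available GPU'."""
    if num_gpus == -1:
        return available_gpus
    if num_gpus < 0:
        raise ValueError("num_gpus must be -1 or a non-negative integer")
    return num_gpus


print(normalize_strategy("Mirrored"))          # mirrored
print(resolve_num_gpus(-1, available_gpus=4))  # 4
```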
Methods
build
build()
Builds the class. Used for lazy initialization.
convert_examples_to_features
convert_examples_to_features(
examples, tfrecord_file, label_names
)
Converts examples to features and writes them into a TFRecord file.
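As a rough illustration of what "features" means for a BERT input, the sketch below frames, truncates, and pads one already-tokenized, single-sentence example to `seq_len`, producing the three integer sequences BERT encoders conventionally take: token IDs, an attention mask, and segment IDs. It assumes the text has already been converted to vocabulary IDs; the `[CLS]`/`[SEP]`/`[PAD]` IDs default to 101/102/0 as in the standard uncased BERT vocabulary, and the function name and dictionary keys are illustrative assumptions. Serializing the result to a TFRecord file is omitted.

```python
from typing import Dict, List


def build_bert_features(
    token_ids: List[int],
    seq_len: int = 128,
    cls_id: int = 101,  # [CLS] in the standard uncased BERT vocab (assumption)
    sep_id: int = 102,  # [SEP]
    pad_id: int = 0,    # [PAD]
) -> Dict[str, List[int]]:
    """Hypothetical helper: frame, truncate, and pad one single-sentence example."""
    if seq_len < 2:
        raise ValueError("seq_len must leave room for [CLS] and [SEP]")
    # Keep at most seq_len - 2 tokens so [CLS] and [SEP] always fit.
    body = token_ids[: seq_len - 2]
    input_ids = [cls_id] + body + [sep_id]
    # 1 marks real tokens, 0 marks padding.
    input_mask = [1] * len(input_ids)
    padding = seq_len - len(input_ids)
    input_ids += [pad_id] * padding
    input_mask += [0] * padding
    # A single-sentence example uses segment 0 everywhere.
    segment_ids = [0] * seq_len
    return {"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}


# Arbitrary example token IDs.
features = build_bert_features([2023, 3185, 2001, 2307], seq_len=8)
print(features["input_ids"])   # [101, 2023, 3185, 2001, 2307, 102, 0, 0]
print(features["input_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```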
create_model
create_model(
num_classes, optimizer='adam', with_loss_and_metrics=True
)
Creates the Keras model.
get_config
get_config()
Gets the configuration.
get_default_quantization_config
get_default_quantization_config()
Gets the default quantization configuration.
get_name_to_features
get_name_to_features()
Gets the dictionary describing the features.
reorder_input_details
reorder_input_details(
tflite_input_details
)
Reorders the TFLite input details to match the input order of the Keras model.
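One way to picture this reordering is to match each Keras input name against the tensor names that the TFLite interpreter reports. The sketch below assumes each input detail is a dict with `"name"` and `"index"` keys (the shape of entries returned by `tf.lite.Interpreter.get_input_details()`) and that each Keras input name appears as a substring of exactly one TFLite tensor name. The helper `reorder_by_keras_names` and the example tensor names are hypothetical and do not describe this method's actual logic.

```python
from typing import Dict, List, Sequence


def reorder_by_keras_names(
    tflite_input_details: Sequence[Dict],
    keras_input_names: Sequence[str],
) -> List[Dict]:
    """Hypothetical helper: return TFLite input details in Keras input order."""
    ordered = []
    for keras_name in keras_input_names:
        # Find the single TFLite input whose tensor name contains this Keras name.
        matches = [d for d in tflite_input_details if keras_name in d["name"]]
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one TFLite input matching {keras_name!r}, "
                f"found {len(matches)}"
            )
        ordered.append(matches[0])
    return ordered


# Hypothetical tensor names, deliberately listed out of Keras order.
details = [
    {"name": "serving_default_input_mask:0", "index": 0},
    {"name": "serving_default_input_type_ids:0", "index": 1},
    {"name": "serving_default_input_word_ids:0", "index": 2},
]
keras_order = ["input_word_ids", "input_mask", "input_type_ids"]
print([d["index"] for d in reorder_by_keras_names(details, keras_order)])  # [2, 0, 1]
```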
run_classifier
run_classifier(
train_ds, validation_ds, epochs, steps_per_epoch, num_classes, **kwargs
)
Creates the classifier and runs classifier training.
Args | |
---|---|
`train_ds` | `tf.data.Dataset`; training data fed to `tf.keras.Model.fit()`. |
`validation_ds` | `tf.data.Dataset`; validation data fed to `tf.keras.Model.fit()`. |
`epochs` | Integer; number of training epochs. |
`steps_per_epoch` | Integer or `None`. Total number of steps (batches of samples) before one epoch is declared finished and the next epoch starts. If `steps_per_epoch` is `None`, the epoch runs until the input dataset is exhausted. |
`num_classes` | Integer; number of classes. |
`**kwargs` | Other parameters passed to `tf.keras.Model.fit()`. |

Returns | |
---|---|
`tf.keras.Model` | The trained Keras model. |
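When `steps_per_epoch` is passed explicitly, it is usually derived from the dataset size and the batch size. A minimal sketch of that arithmetic, assuming the default batch size of 32 and a hypothetical training set of 1,000 examples: counting only full batches gives 1000 // 32 = 31 steps, while also counting the final partial batch gives ceil(1000 / 32) = 32. The helper name is illustrative only.

```python
import math


def steps_for_epoch(num_examples: int, batch_size: int = 32, drop_remainder: bool = True) -> int:
    """Hypothetical helper: number of batches that make up one epoch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        # Only full batches count as steps.
        return num_examples // batch_size
    # The last, smaller batch also counts as a step.
    return math.ceil(num_examples / batch_size)


print(steps_for_epoch(1000))                        # 31
print(steps_for_epoch(1000, drop_remainder=False))  # 32
```

Which variant applies depends on whether the input pipeline drops the last partial batch; passing `None` sidesteps the question by running until the dataset is exhausted.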
save_vocab
save_vocab(
vocab_filename
)
Prints the file path to the vocabulary.
select_data_from_record
select_data_from_record(
record
)
Dispatches records to features and labels.
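In Keras terms, dispatching a record means splitting one parsed example into the `(features, label)` pair that `tf.keras.Model.fit()` expects. The sketch below uses plain Python dicts in place of tensors. The record keys (`input_ids`, `input_mask`, `segment_ids`, `label_ids`) and the model input names (`input_word_ids`, `input_mask`, `input_type_ids`) are assumptions based on common BERT conventions, not a statement of this class's exact key names, and `split_record` is a hypothetical helper.

```python
from typing import Any, Dict, Tuple


def split_record(record: Dict[str, Any]) -> Tuple[Dict[str, Any], Any]:
    """Hypothetical helper: map a parsed record to (features, label)."""
    # Rename record fields to the (assumed) input names of the BERT encoder.
    features = {
        "input_word_ids": record["input_ids"],
        "input_mask": record["input_mask"],
        "input_type_ids": record["segment_ids"],
    }
    label = record["label_ids"]
    return features, label


record = {
    "input_ids": [101, 2023, 102, 0],
    "input_mask": [1, 1, 1, 0],
    "segment_ids": [0, 0, 0, 0],
    "label_ids": 1,
}
x, y = split_record(record)
print(sorted(x), y)  # ['input_mask', 'input_type_ids', 'input_word_ids'] 1
```

Because the function only indexes and rebuilds dicts, the same shape of function could be applied to a dataset of parsed records with `tf.data.Dataset.map`.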
Class Variables | |
---|---|
`compat_tf_versions` | `[2]` |
`convert_from_saved_model_tf2` | `True` |
`need_gen_vocab` | `False` |