Creates a MobileBert model spec for the text classification task. See also: tflite_model_maker.text_classifier.BertClassifierSpec.
tflite_model_maker.text_classifier.MobileBertClassifierSpec(
*,
uri='https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
model_dir=None,
seq_len=128,
dropout_rate=0.1,
initializer_range=0.02,
learning_rate=3e-05,
distribution_strategy='off',
num_gpus=-1,
tpu='',
trainable=True,
do_lower_case=True,
is_tf2=False,
name='MobileBert',
tflite_input_name=None,
default_batch_size=48,
index_to_label=None
)
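
The leading `*` in the signature makes every argument keyword-only. A minimal sketch of constructing the spec follows, assuming the `tflite_model_maker` package is installed. The `make_spec` helper, the `/tmp/mobilebert_classifier` path, and the non-default `seq_len` are illustrative assumptions, not part of the library.

```python
# Keyword arguments mirroring the documented signature. Values that differ
# from the documented defaults are illustrative assumptions.
spec_kwargs = dict(
    seq_len=256,                  # assumption: longer inputs than the default 128
    dropout_rate=0.1,             # documented default
    learning_rate=3e-5,           # documented default
    do_lower_case=True,           # the default URI points at an uncased model
    model_dir="/tmp/mobilebert_classifier",  # hypothetical checkpoint directory
    index_to_label=["neg", "pos"],
)


def make_spec(**kwargs):
    """Hypothetical helper: build the spec, or return None if the package is missing."""
    try:
        from tflite_model_maker import text_classifier
    except ImportError:
        return None
    # All arguments must be passed by keyword because of the `*` in the signature.
    return text_classifier.MobileBertClassifierSpec(**kwargs)
```

Calling `make_spec(**spec_kwargs)` should return a spec object when the package is available; arguments left out fall back to the defaults shown in the signature.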
Args | |
---|---|
uri | TF-Hub path or URL of the BERT module. |
model_dir | Location of the model checkpoint files. |
seq_len | Length of the sequence fed into the model. |
dropout_rate | Dropout rate. |
initializer_range | Standard deviation of the truncated_normal_initializer used to initialize all weight matrices. |
learning_rate | Initial learning rate for Adam. |
distribution_strategy | A string specifying which distribution strategy to use. Accepted values (case-insensitive) are 'off', 'one_device', 'mirrored', 'parameter_server', 'multi_worker_mirrored', and 'tpu'. 'off' disables Distribution Strategy; 'tpu' uses TPUStrategy with tpu_address. |
num_gpus | Number of GPUs to use at each worker with the DistributionStrategies API. The default, -1, uses all available GPUs. |
tpu | TPU address to connect to. |
trainable | Boolean, whether the pretrained layer is trainable. |
do_lower_case | Boolean, whether to lowercase the input text. Should be True for uncased models and False for cased models. |
is_tf2 | Boolean, whether the hub module is in TensorFlow 2.x format. |
name | The name of the object. |
tflite_input_name | Dict, input names for the TFLite model. |
default_batch_size | Default batch size for training. |
index_to_label | List of labels in the training data, e.g. ['neg', 'pos']. |
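
Two of the arguments above take values that are easy to get wrong, so a small sketch may help. The helpers `normalize_strategy` and `label_for` are hypothetical and not part of `tflite_model_maker`. The first checks a `distribution_strategy` string against the accepted values listed above, ignoring case. The second assumes, as the parameter name suggests, that position `i` in `index_to_label` corresponds to class index `i`.

```python
# Accepted values for distribution_strategy, as listed in the Args table.
ACCEPTED_STRATEGIES = (
    "off",
    "one_device",
    "mirrored",
    "parameter_server",
    "multi_worker_mirrored",
    "tpu",
)


def normalize_strategy(name):
    """Hypothetical helper: lowercase a strategy name and check it is accepted."""
    key = name.lower()
    if key not in ACCEPTED_STRATEGIES:
        raise ValueError(
            f"Unknown distribution strategy {name!r}; "
            f"expected one of {ACCEPTED_STRATEGIES}"
        )
    return key


def label_for(index_to_label, scores):
    """Hypothetical helper: return the label whose class index has the top score.

    Assumes scores[i] is the model's score for the class named index_to_label[i].
    """
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return index_to_label[best_index]


strategy = normalize_strategy("Mirrored")
predicted = label_for(["neg", "pos"], [0.2, 0.8])
```

Validating the strategy string before building the spec surfaces typos early, rather than at training time.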