Span labeling network head for XLNet on SQuAD2.0.
```python
tfm.nlp.networks.XLNetSpanLabeling(
    input_width,
    start_n_top=5,
    end_n_top=5,
    activation='tanh',
    dropout_rate=0.0,
    initializer='glorot_uniform',
    **kwargs
)
```
This network implements a span labeler based on dense layers and question possibility classification (predicting whether the question is answerable). It is the complex version found in the original XLNet implementation.
It applies a dense layer to the input sequence data to predict the start positions, and then uses either the true start positions (during training) or beam search (during inference) to predict the end positions.
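The start-then-end decoding described above can be illustrated with a simplified NumPy sketch. It is not the layer's implementation: `span_beam_search`, `top_k`, and `log_softmax` are hypothetical helpers, and the precomputed `end_logits` table of shape `[b, l, l]` (end scores for every possible start) is an assumption made for brevity, whereas the real head computes end logits only for the selected starts. The sketch keeps the `start_n_top` best starts, then the `end_n_top` best ends for each of them, and scores each pair by the sum of the two log-probabilities.

```python
import numpy as np


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def top_k(x, k):
    """Return (values, indices) of the k largest entries along the last axis."""
    idx = np.argsort(-x, axis=-1, kind="stable")[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx


def span_beam_search(start_logits, end_logits, start_n_top=5, end_n_top=5):
    """Hypothetical two-stage beam over (start, end) pairs.

    start_logits: [b, l] scores for each start position.
    end_logits:   [b, l, l]; end_logits[b, s, e] scores end e given start s
                  (an assumed layout used only for this sketch).
    Returns joint log-probs, start indices and end indices, each shaped
    [b, start_n_top, end_n_top].
    """
    # Stage 1: keep the best start positions.
    start_lp, start_idx = top_k(log_softmax(start_logits), start_n_top)  # [b, k_s]
    # Gather the end-score rows belonging to the selected starts: [b, k_s, l].
    rows = np.take_along_axis(end_logits, start_idx[:, :, None], axis=1)
    # Stage 2: keep the best end positions for each selected start.
    end_lp, end_idx = top_k(log_softmax(rows), end_n_top)  # [b, k_s, k_e]
    # Score every (start, end) candidate by its summed log-probability.
    joint = start_lp[:, :, None] + end_lp
    start_grid = np.broadcast_to(start_idx[:, :, None], joint.shape)
    return joint, start_grid, end_idx
```

Keeping only the top starts bounds the search at `start_n_top * end_n_top` candidate spans per example instead of scoring all `l * l` pairs.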
Methods
call
```python
call(
    sequence_data,
    class_index,
    paragraph_mask=None,
    start_positions=None,
    training=False
)
```
Implements call().
Einsum glossary:
- b: the batch size.
- l: the sequence length.
- h: the hidden size, or input width.
- k: the start/end top n.
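The glossary letters map directly onto `einsum` subscripts. As an illustration (not necessarily the exact operations inside the layer), the NumPy sketch below uses one-hot start indices to gather start-token features: `blh,bl->bh` picks one start per example, as with the true start positions in training, and `blh,bkl->bkh` picks `k` candidate starts per example, as in beam search. The shapes and index values are made up.

```python
import numpy as np

b, l, h = 2, 6, 4  # batch size, sequence length, hidden size
rng = np.random.default_rng(0)
sequence_data = rng.standard_normal((b, l, h))  # [b, l, h]

# Training-style gather: one known start position per example -> one-hot [b, l].
start_positions = np.array([1, 4])
start_one_hot = np.eye(l)[start_positions]  # [b, l]
start_features = np.einsum("blh,bl->bh", sequence_data, start_one_hot)  # [b, h]

# Beam-style gather: k = 3 candidate starts per example -> one-hot [b, k, l].
start_top_index = np.array([[1, 2, 5], [0, 4, 3]])
top_one_hot = np.eye(l)[start_top_index]  # [b, k, l]
top_features = np.einsum("blh,bkl->bkh", sequence_data, top_one_hot)  # [b, k, h]
```

Multiplying by a one-hot vector and summing over `l` is equivalent to indexing, but it keeps the whole computation expressed as tensor contractions.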
Args | |
---|---|
`sequence_data` | The input sequence data of shape `(batch_size, seq_length, input_width)`. |
`class_index` | The class indices of the inputs, of shape `(batch_size,)`. |
`paragraph_mask` | Invalid-position mask marking tokens that cannot be part of the answer, such as the query and special symbols (e.g. PAD, SEP, CLS), of shape `(batch_size, seq_length)`. |
`start_positions` | The start positions of each example, of shape `(batch_size,)`. |
`training` | Whether or not this is the training phase. |
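To show how a position mask like `paragraph_mask` can keep invalid tokens out of the span distribution, here is a hypothetical `masked_log_softmax` in NumPy. It assumes a per-token mask of shape `[b, l]` in which 1.0 marks a position that cannot be part of the answer, following the "invalid position mask" description; the layer's actual masking convention and numerics may differ.

```python
import numpy as np


def masked_log_softmax(logits, invalid_mask):
    """Log-softmax over positions that excludes masked-out positions.

    logits:       [b, l] start (or end) scores.
    invalid_mask: [b, l], 1.0 where a position cannot be part of the answer
                  (assumed convention), 0.0 otherwise.
    """
    # Push invalid positions to a huge negative score so exp() underflows to 0.
    masked = np.where(invalid_mask > 0, -1e30, logits)
    # Standard stable log-softmax over the sequence axis.
    shifted = masked - masked.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
```

Masking before the softmax, rather than zeroing probabilities afterwards, keeps the remaining probabilities normalized over the valid positions only.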
Returns | |
---|---|
A dictionary with the keys `start_predictions`, `end_predictions`, `start_logits`, and `end_logits`. If inference, then | |
end_logits
```python
end_logits(
    inputs
)
```
Computes the end logits.
The input shapes fed into the inner, layer normalization, and output layers should match.
During training, the inputs should have shape [batch_size, seq_length, input_width].
During inference, the inputs should have shape [batch_size, seq_length, start_n_top, input_width].
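Because dense layers and layer normalization act on the last axis only, one set of weights can serve both the training shape and the inference shape listed above. The NumPy sketch below assumes an inner dense layer followed by the `tanh` activation (the constructor default), a layer normalization without learned scale or offset, and a one-unit output dense layer; the weights are random stand-ins, and `end_logits_head` is a hypothetical name, not the library method.

```python
import numpy as np


def layer_norm(x, eps=1e-12):
    """Normalize over the last axis (no learned scale/offset, for brevity)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def end_logits_head(inputs, w_inner, b_inner, w_out, b_out):
    """Hypothetical inner dense -> tanh -> layer norm -> output dense.

    Every op acts on the last axis only, so the same weights accept both
    [b, l, h] (training) and [b, l, k, h] (inference) inputs.
    """
    hidden = np.tanh(inputs @ w_inner + b_inner)  # [..., h]
    hidden = layer_norm(hidden)
    return (hidden @ w_out + b_out)[..., 0]  # drop the size-1 output axis


rng = np.random.default_rng(0)
h = 8
w_inner, b_inner = rng.standard_normal((h, h)), np.zeros(h)
w_out, b_out = rng.standard_normal((h, 1)), np.zeros(1)
train_in = rng.standard_normal((2, 5, h))     # [b, l, h]
infer_in = rng.standard_normal((2, 5, 3, h))  # [b, l, k, h]
train_logits = end_logits_head(train_in, w_inner, b_inner, w_out, b_out)  # [2, 5]
infer_logits = end_logits_head(infer_in, w_inner, b_inner, w_out, b_out)  # [2, 5, 3]
```

Slicing one beam out of the rank-4 input and running it through the head gives the same logits as the corresponding slice of the rank-4 output, which is why the two shapes can share weights.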
Args | |
---|---|
`inputs` | The input for end logits. |

Returns | |
---|---|
Calculated end logits. | |