
tfnlp.networks.XLNetSpanLabeling

Span labeling network head for XLNet on SQuAD2.0.

This network implements a span labeler based on dense layers and a classifier that predicts whether the question is answerable. It is the more complex variant used in the original XLNet implementation.

This applies a dense layer to the input sequence data to predict the start positions, and then uses either the true start positions (if training) or beam search to predict the end positions.
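The start step described above can be sketched in NumPy. This is a minimal illustration, not the library's implementation: the function name `select_start_positions` and the single weight vector `w_start` (standing in for the start dense layer, with no bias or activation) are hypothetical. During training it keeps the true start positions; during inference it keeps the `start_n_top` highest-scoring positions as the start beam.

```python
import numpy as np

def select_start_positions(sequence_data, w_start, start_positions=None, start_n_top=2):
    """Hypothetical sketch of the start-position step of an XLNet-style span head.

    sequence_data: [b, l, h] float array.
    w_start: [h] weight vector standing in for the start dense layer.
    start_positions: [b] int array of true starts (training) or None (inference).
    Returns (start_logits [b, l], start_index [b, k]) with k = 1 during training.
    """
    # Dense projection to one logit per position: 'blh,h->bl'.
    start_logits = np.einsum("blh,h->bl", sequence_data, w_start)
    if start_positions is not None:
        # Training: use the true start positions directly (beam of size 1).
        return start_logits, start_positions[:, None]
    # Inference: softmax is monotonic, so the start_n_top largest logits are
    # also the start_n_top most probable start positions.
    top_index = np.argsort(-start_logits, axis=-1)[:, :start_n_top]  # [b, k]
    return start_logits, top_index
```

Because only the ranking matters for choosing the beam, the sketch skips the softmax; the probabilities themselves are still needed if the beam scores are reported.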

Args
input_width The innermost dimension of the input tensor to this network.
start_n_top Beam size for span start.
end_n_top Beam size for span end.
activation The activation, if any, for the dense layer in this network.
dropout_rate The dropout rate used for answer classification.
initializer The initializer for the dense layer in this network. Defaults to a Glorot uniform initializer.

Methods

call


Implements call(): runs the span-labeling head on a batch of sequence data.

Einsum glossary:

  • b: the batch size.
  • l: the sequence length.
  • h: the hidden size, or input width.
  • k: the beam size (start_n_top for start positions, end_n_top for end positions).
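Using the glossary letters, the hidden states at the selected start positions can be gathered with one-hot einsums. The sketch below assumes one-hot selection (the approach used in the original XLNet code); the helper name `gather_start_features` is hypothetical. A 1-D `start_index` models training, and a 2-D `start_index` models a beam of k start candidates.

```python
import numpy as np

def gather_start_features(sequence_data, start_index):
    """Hypothetical helper: gather hidden states at start positions via einsum.

    sequence_data: [b, l, h]; start_index: [b] (training) or [b, k] (beam).
    Returns [b, h] or [b, k, h].
    """
    seq_length = sequence_data.shape[1]
    # One-hot over the sequence axis: [b, l] or [b, k, l].
    one_hot = np.eye(seq_length)[start_index]
    if start_index.ndim == 1:
        # Training: one start per example, 'blh,bl->bh'.
        return np.einsum("blh,bl->bh", sequence_data, one_hot)
    # Inference: k candidate starts per example, 'blh,bkl->bkh'.
    return np.einsum("blh,bkl->bkh", sequence_data, one_hot)
```

The result equals plain fancy indexing (`sequence_data[np.arange(b), start_index]`); the one-hot form is shown because it maps directly onto the einsum notation above.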

Args
sequence_data The input sequence data of shape (batch_size, seq_length, input_width).
class_index The class indices of the inputs of shape (batch_size,).
paragraph_mask Mask of positions that cannot be part of the answer, such as query tokens and special symbols (e.g. PAD, SEP, CLS), of shape (batch_size, seq_length).
start_positions The start positions of each example of shape (batch_size,).
training Whether or not this is the training phase.
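A sketch of how a per-position `paragraph_mask` can exclude invalid positions before the softmax. It assumes 1.0 marks an invalid position and 0.0 a valid one, and uses the `logits * (1 - mask) - 1e30 * mask` form from the original XLNet code; whether this class uses exactly the same convention is an assumption. The helper names are hypothetical.

```python
import numpy as np

def apply_paragraph_mask(logits, paragraph_mask):
    """Hypothetical helper: push masked positions to a huge negative logit.

    logits: [b, l]; paragraph_mask: [b, l], assumed 1.0 at invalid positions
    (query tokens, PAD, SEP, CLS) and 0.0 at valid context positions.
    """
    return logits * (1.0 - paragraph_mask) - 1e30 * paragraph_mask

def log_softmax(x, axis=-1):
    """Numerically stable log-softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
```

After masking, invalid positions receive zero probability, so neither the start beam nor the end beam can select them.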

Returns
A dictionary with the keys 'start_predictions', 'end_predictions', 'start_logits', 'end_logits'.

During inference, the dictionary also includes 'start_top_predictions', 'start_top_index', 'end_top_predictions', and 'end_top_index'.
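The inference-only end outputs come from a second beam step: for each of the `start_n_top` start candidates, keep the `end_n_top` most probable end positions. The NumPy sketch below is illustrative only; the function name and the `[b, start_n_top, end_n_top]` output layout are assumptions of the sketch, and the library may order or flatten these axes differently.

```python
import numpy as np

def end_beam_search(end_log_probs, end_n_top):
    """Hypothetical sketch: top end positions for each start candidate.

    end_log_probs: [b, l, k_start], log-probabilities over the sequence axis
    for each start candidate.
    Returns (end_top_log_probs, end_top_index), both [b, k_start, k_end].
    """
    # Move the sequence axis last: [b, k_start, l].
    scores = np.transpose(end_log_probs, (0, 2, 1))
    # Indices of the end_n_top best end positions per start candidate.
    end_top_index = np.argsort(-scores, axis=-1)[..., :end_n_top]
    # Look up the matching log-probabilities.
    end_top_log_probs = np.take_along_axis(scores, end_top_index, axis=-1)
    return end_top_log_probs, end_top_index
```

A span's overall score can then be formed by adding its start log-probability to its end log-probability, which is one common way to rank the start_n_top × end_n_top candidates.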

end_logits


Computes the end logits.

The inputs to the inner dense layer, the layer normalization, and the output dense layer must have matching shapes.

During training, the input shape should be [batch_size, seq_length, input_width].

During inference, the input shape should be [batch_size, seq_length, start_n_top, input_width].
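Because each layer acts on the last axis only, the same inner dense → layer norm → output dense stack works for both the rank-3 training input and the rank-4 inference input. The NumPy sketch below shows this. It is a hypothetical stand-in, not the library's layers: the weights `w_inner`, `b_inner`, `w_out` are assumed, `tanh` is a placeholder for the configurable `activation`, and the layer norm omits learned scale and offset.

```python
import numpy as np

def end_logits_head(inputs, w_inner, b_inner, w_out, eps=1e-12):
    """Hypothetical sketch of inner dense -> layer norm -> output dense.

    inputs: [b, l, d] (training) or [b, l, k, d] (inference).
    w_inner: [d, h]; b_inner: [h]; w_out: [h].
    Returns [b, l] or [b, l, k].
    """
    # Inner dense layer; matmul broadcasts over all leading axes.
    hidden = np.tanh(inputs @ w_inner + b_inner)
    # Layer normalization over the last (hidden) axis.
    mean = hidden.mean(axis=-1, keepdims=True)
    var = hidden.var(axis=-1, keepdims=True)
    hidden = (hidden - mean) / np.sqrt(var + eps)
    # Output dense layer to one logit per position (and per beam, if present).
    return hidden @ w_out
```

Each beam slice of the rank-4 input yields the same logits as running the head on that slice alone, which is why the layers only need consistent last-axis sizes.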

Args
inputs The input for end logits.

Returns
Calculated end logits.