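Extracts next sentence labels from sentences: for each sentence, returns either the sentence that follows it or a randomly injected sentence, together with a label indicating which one was chosen.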
tfm.nlp.ops.get_next_sentence_labels(
sentences, random_threshold=0.5, random_fn=tf.random.uniform
)
Args
  sentences: A RaggedTensor of strings with shape [batch, (num_sentences)].
  random_threshold: (optional, default 0.5) A float threshold between 0 and 1 that determines whether to extract a random sentence or the immediately following sentence. Higher values favor the next sentence.
  random_fn: (optional, default tf.random.uniform) An op used to generate random float values.
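The interaction between random_threshold and random_fn can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: the helper name pick_next_sentence is hypothetical, Python's random.random stands in for tf.random.uniform, and the comparison u < random_threshold is an assumption chosen so that, for a uniform draw in [0, 1), the probability of keeping the next sentence equals the threshold ("higher values favor the next sentence").

```python
import random
from typing import Callable


def pick_next_sentence(random_threshold: float = 0.5,
                       random_fn: Callable[[], float] = random.random) -> bool:
    """Hypothetical helper: True means "keep the real next sentence".

    Assumption (not taken from the library): one draw u from random_fn is
    compared as u < random_threshold, so with a uniform [0, 1) draw the
    probability of returning True equals random_threshold.
    """
    if not 0.0 <= random_threshold <= 1.0:
        raise ValueError("random_threshold must be between 0 and 1")
    # random_fn is injectable, so tests can pass a fixed value.
    return random_fn() < random_threshold


if __name__ == "__main__":
    # Estimate how often the next sentence is kept for a few thresholds.
    rng = random.Random(0)
    for threshold in (0.2, 0.5, 0.8):
        kept = sum(pick_next_sentence(threshold, rng.random) for _ in range(10000))
        print(f"threshold={threshold}: kept next sentence {kept / 10000:.3f} of the time")
```

Passing a fixed random_fn such as `lambda: 0.3` makes the decision deterministic, which is one plausible reason for exposing the random op as a parameter.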
Returns
  A tuple of (next_sentence_or_random, is_next_sentence), where:
  next_sentence_or_random: A Tensor with shape [num_sentences] that contains, for each input sentence (segment_a), either the sentence that follows it or a randomly injected sentence.
  is_next_sentence: A bool Tensor with shape [num_sentences] that indicates whether next_sentence_or_random is truly the subsequent sentence.
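The overall behavior can be sketched end to end in pure Python, under several assumptions that are not confirmed by this page: nested lists stand in for the RaggedTensor of shape [batch, (num_sentences)]; outputs are flat, with one entry per input sentence; the last sentence of each document has no successor and always receives a random sentence; random sentences are drawn uniformly from the whole batch; and the label is computed by comparing positions, so a random draw that happens to land on the real next sentence is labelled True. The function name next_sentence_labels_sketch is hypothetical.

```python
import random
from typing import Callable, List, Optional, Tuple

Position = Tuple[int, int]  # (document index, sentence index)


def next_sentence_labels_sketch(
    sentences: List[List[str]],
    random_threshold: float = 0.5,
    random_fn: Callable[[], float] = random.random,
) -> Tuple[List[str], List[bool]]:
    """Hypothetical pure-Python illustration of next sentence labelling.

    Returns (next_or_random, is_next), both flat with one entry per input
    sentence, in the same order as the flattened input.
    """
    # Work with positions rather than strings so duplicate sentences are
    # never mistaken for each other.
    positions: List[Position] = [
        (d, i) for d, doc in enumerate(sentences) for i in range(len(doc))
    ]
    next_or_random: List[str] = []
    is_next: List[bool] = []
    for d, i in positions:
        # Assumption: the last sentence of a document has no true successor.
        true_next: Optional[Position] = (d, i + 1) if i + 1 < len(sentences[d]) else None
        if true_next is not None and random_fn() < random_threshold:
            chosen = true_next
        else:
            # Draw a position uniformly from the whole batch.
            k = min(int(random_fn() * len(positions)), len(positions) - 1)
            chosen = positions[k]
        cd, ci = chosen
        next_or_random.append(sentences[cd][ci])
        # Label by position: True only if the chosen sentence really follows.
        is_next.append(chosen == true_next)
    return next_or_random, is_next


if __name__ == "__main__":
    docs = [["a1", "a2", "a3"], ["b1", "b2"]]
    rng = random.Random(42)
    out, labels = next_sentence_labels_sketch(docs, 0.5, rng.random)
    flat = [s for doc in docs for s in doc]
    for segment_a, paired, label in zip(flat, out, labels):
        print(f"{segment_a} -> {paired} (is_next={label})")
```

Deriving the label from positions keeps is_next_sentence honest even when the random branch picks the true successor; whether the library does the same is not stated here.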