tfm.nlp.ops.get_next_sentence_labels

Extracts next-sentence labels from sentences.

sentences: A RaggedTensor of strings with shape [batch, (num_sentences)].
random_threshold: (optional) A float threshold between 0 and 1 that determines whether to extract a random sentence or the immediate next sentence. A higher value favors the next sentence.
random_fn: (optional) An op used to generate random float values.

Returns a tuple (next_sentence_or_random, is_next_sentence), where:
next_sentence_or_random: A Tensor of strings with shape [num_sentences] that contains, for each sentence (segment_a), either its subsequent sentence or a randomly injected sentence.
is_next_sentence: A bool Tensor with shape [num_sentences] indicating whether each element of next_sentence_or_random is truly the subsequent sentence.
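To make the labeling semantics concrete, here is a minimal pure-Python sketch. It is not the library's implementation, and several details are assumptions: `next_sentence_labels_sketch` is a hypothetical helper, nested lists stand in for the RaggedTensor, `random_fn` is a zero-argument callable returning a float in [0, 1) (the real op's `random_fn` is a TensorFlow op), the last sentence of each document is always paired with a random sentence, and replacements are sampled from the whole batch.

```python
import random
from typing import Callable, List, Sequence, Tuple


def next_sentence_labels_sketch(
    sentences: Sequence[Sequence[str]],
    random_threshold: float = 0.5,
    random_fn: Callable[[], float] = random.random,
) -> Tuple[List[str], List[bool]]:
    """Hypothetical pure-Python illustration of next-sentence labeling.

    `sentences` mimics a ragged [batch, (num_sentences)] input: a list of
    documents, each a list of sentence strings. Both outputs are flattened
    to length num_sentences, matching the documented output shape.
    """
    # Flatten the ragged batch, remembering (document, position) for each sentence.
    flat: List[str] = []
    positions: List[Tuple[int, int]] = []
    for d, doc in enumerate(sentences):
        for s, sent in enumerate(doc):
            flat.append(sent)
            positions.append((d, s))

    next_or_random: List[str] = []
    is_next: List[bool] = []
    for i, (d, s) in enumerate(positions):
        has_next = s + 1 < len(sentences[d])
        # A draw below the threshold keeps the true next sentence, so a higher
        # random_threshold favors next sentences.
        if has_next and random_fn() < random_threshold:
            chosen = i + 1  # flat index of the immediate next sentence
        else:
            # Assumption: the last sentence of a document always takes this branch,
            # and the replacement is drawn uniformly from the whole flattened batch.
            chosen = min(int(random_fn() * len(flat)), len(flat) - 1)
        next_or_random.append(flat[chosen])
        # Label from the chosen index: True only if it really is the next sentence.
        is_next.append(has_next and chosen == i + 1)
    return next_or_random, is_next
```

For example, `next_sentence_labels_sketch([["a1", "a2", "a3"], ["b1", "b2"]], random_threshold=1.0, random_fn=lambda: 0.0)` returns `(["a2", "a3", "a1", "b2", "a1"], [True, True, False, True, False])`: every sentence that has a successor keeps it, and each document's last sentence falls back to a random one. Because the sketch derives each label from the chosen index, a random draw that happens to land on the true next sentence is still labeled True; that is a design choice of this sketch, not a documented property of the op.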