text.combine_segments

Combine one or more input segments for a model's input sequence.

combine_segments combines the tokens of one or more input segments into a single sequence of token values and generates matching segment ids. combine_segments can follow a Trimmer, which limits segment lengths and emits RaggedTensor outputs, and can be followed by ModelInputPacker.
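combine_segments adds one start-of-sequence token plus one end-of-segment token per segment, so each combined row has length 1 + (sum of the segment lengths) + (number of segments). A trimmer placed in front of it therefore has to leave room for those special tokens. The sketch below is plain Python and not the TF Text API: the helper `round_robin_trim` is hypothetical and only illustrates one way to share a token budget across segments (TF Text provides its own trimmers, such as `RoundRobinTrimmer`, whose exact behavior may differ).

```python
from typing import List


def special_token_count(num_segments: int) -> int:
    # One start-of-sequence token, plus one end-of-segment token per segment.
    return 1 + num_segments


def round_robin_trim(segments: List[List[int]], max_seq_length: int) -> List[List[int]]:
    """Hypothetical trimmer for a single example.

    Keeps tokens one at a time from each segment in turn, so that the
    trimmed segments plus the special tokens fit in max_seq_length.
    """
    budget = max_seq_length - special_token_count(len(segments))
    if budget < 0:
        raise ValueError("max_seq_length is too small for the special tokens")
    kept = [0] * len(segments)
    while budget > 0:
        progressed = False
        for i, seg in enumerate(segments):
            if budget > 0 and kept[i] < len(seg):
                kept[i] += 1
                budget -= 1
                progressed = True
        if not progressed:
            break  # every segment is already kept in full
    return [seg[:n] for seg, n in zip(segments, kept)]


# One example with two segments (5 + 2 tokens) and a total budget of 8,
# of which 3 positions are reserved for the special tokens.
trimmed = round_robin_trim([[5, 6, 7, 8, 9], [70, 80]], max_seq_length=8)
combined_length = sum(len(s) for s in trimmed) + special_token_count(len(trimmed))
```

Here the shorter segment is kept whole and the longer one absorbs the truncation, so the combined row fits exactly in 8 positions.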

See Detailed Experimental Setup in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (https://arxiv.org/pdf/1810.04805.pdf) for more examples of combined segments.

combine_segments first concatenates a list of one or more segments (RaggedTensors of rank n) along axis 1, inserting the special tokens, to produce a final RaggedTensor of rank n.

Finally, combine_segments generates a second RaggedTensor, with the same shape as the combined RaggedTensor, that contains an int id for each token identifying the segment it belongs to.
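The combining and id-assignment steps can be sketched in plain Python for rank-2 inputs (a batch of token-id lists per segment). This is an illustration of the described behavior under that assumption, not TF Text's implementation, which operates on RaggedTensors of any rank; the function name `combine_segments_py` is hypothetical.

```python
from typing import List, Tuple

Batch = List[List[int]]  # one list of token ids per example


def combine_segments_py(segments: List[Batch], start_of_sequence_id: int,
                        end_of_segment_id: int) -> Tuple[Batch, Batch]:
    """Hypothetical pure-Python sketch of combine_segments for rank-2 inputs."""
    if not segments:
        raise ValueError("segments must be a nonempty list")
    combined, segment_ids = [], []
    for j in range(len(segments[0])):
        # The start-of-sequence token is counted as part of segment 0.
        row, ids = [start_of_sequence_id], [0]
        for i, seg in enumerate(segments):
            # Tokens of segment i for example j, then its end-of-segment
            # token; all of them get segment id i.
            row += seg[j] + [end_of_segment_id]
            ids += [i] * (len(seg[j]) + 1)
        combined.append(row)
        segment_ids.append(ids)
    return combined, segment_ids


combined, ids = combine_segments_py(
    [[[1, 2], [3, 4]], [[10, 20], [30, 40, 50, 60]]],
    start_of_sequence_id=101, end_of_segment_id=102)
```

Each output row is built independently, which is why the result stays ragged: rows only share a length when their input segments do.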

Example usage:

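import tensorflow as tf
import tensorflow_text as text

# Each segment is a batch of 3 examples with varying lengths.
segment_a = tf.ragged.constant([[1, 2],
                                [3, 4],
                                [5, 6, 7, 8, 9]])

segment_b = tf.ragged.constant([[10, 20],
                                [30, 40, 50, 60],
                                [70, 80]])

# 101 and 102 are the [CLS] and [SEP] ids in the standard BERT vocabulary.
combined, segment_ids = text.combine_segments(
    [segment_a, segment_b],
    start_of_sequence_id=101,
    end_of_segment_id=102)

# segment_a and segment_b have been combined, with a start-of-sequence token
# prepended and an end-of-segment token appended after each segment.
expected_combined = [
    [101, 1, 2, 102, 10, 20, 102],
    [101, 3, 4, 102, 30, 40, 50, 60, 102],
    [101, 5, 6, 7, 8, 9, 102, 70, 80, 102],
]

# Ids describing which segment each token belongs to.
expected_ids = [
    [0, 0, 0, 0, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]

assert combined.to_list() == expected_combined
assert segment_ids.to_list() == expected_ids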
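The outputs above are ragged, while a model usually expects fixed-length inputs. The plain-Python sketch below pads each row of the combined tokens and segment ids to a fixed length and builds a mask marking real tokens. The helper `pad_rows`, the length 12, the pad value 0, and the output names are assumptions for illustration; this is not ModelInputPacker's implementation (in TensorFlow itself, `RaggedTensor.to_tensor()` performs a similar dense conversion).

```python
from typing import List, Tuple


def pad_rows(rows: List[List[int]], max_seq_length: int,
             pad_value: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: truncate or pad each row to max_seq_length.

    Returns the padded rows and a mask with 1 for real tokens, 0 for padding.
    """
    padded, mask = [], []
    for row in rows:
        row = row[:max_seq_length]
        n_pad = max_seq_length - len(row)
        padded.append(row + [pad_value] * n_pad)
        mask.append([1] * len(row) + [0] * n_pad)
    return padded, mask


# Ragged values copied from the example above.
combined = [
    [101, 1, 2, 102, 10, 20, 102],
    [101, 3, 4, 102, 30, 40, 50, 60, 102],
    [101, 5, 6, 7, 8, 9, 102, 70, 80, 102],
]
segment_ids = [
    [0, 0, 0, 0, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]

# Hypothetical BERT-style input names.
input_word_ids, input_mask = pad_rows(combined, max_seq_length=12)
input_type_ids, _ = pad_rows(segment_ids, max_seq_length=12)
```

Padded positions receive segment id 0 here; since the mask zeroes them out, their segment id should not matter to a model that respects the mask.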

Args:
segments: A list of RaggedTensors with the tokens of the input segments. All elements must have the same dtype (int32 or int64), the same rank, and the same dimension 0 (namely, the batch size). The slice segments[i][j, ...] contains the tokens of the i-th input segment for the j-th example in the batch.
start_of_sequence_id: A Python int or scalar Tensor containing the id used to denote the start of a sequence (e.g. the [CLS] token in BERT terminology).
end_of_segment_id: A Python int or scalar Tensor containing the id used to denote the end of a segment (e.g. the [SEP] token in BERT terminology).
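The batch-size and indexing constraints on segments can be illustrated with a hypothetical pre-check for rank-2 inputs given as nested Python lists. The function `check_segments` is an assumption for illustration, not part of TF Text, and it does not model the int32/int64 dtype rule beyond requiring integer ids.

```python
from typing import Sequence


def check_segments(segments: Sequence[Sequence[Sequence[int]]]) -> int:
    """Hypothetical pre-check mirroring the documented constraints.

    Returns the shared batch size, or raises ValueError.
    """
    if len(segments) == 0:
        raise ValueError("segments must contain at least one segment")
    batch_size = len(segments[0])
    for i, seg in enumerate(segments):
        # Dimension 0 (the batch size) must match across all segments.
        if len(seg) != batch_size:
            raise ValueError(
                f"segment {i} has batch size {len(seg)}, expected {batch_size}")
        # segments[i][j] holds the tokens of segment i for example j.
        for j, tokens in enumerate(seg):
            if not all(isinstance(t, int) for t in tokens):
                raise ValueError(f"segments[{i}][{j}] contains non-integer ids")
    return batch_size


batch_size = check_segments([[[1, 2], [3, 4]], [[10, 20], [30, 40, 50, 60]]])

error_message = ""
try:
    # Second segment has only one example, so the batch sizes disagree.
    check_segments([[[1, 2], [3, 4]], [[10, 20]]])
except ValueError as err:
    error_message = str(err)
```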

Returns:
A tuple of (combined_segments, segment_ids), where:
combined_segments: A RaggedTensor with the segments combined and the special tokens inserted.
segment_ids: A RaggedTensor with the same shape as combined_segments, containing an int id for each item that indicates the segment it belongs to.
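Because segment_ids has the same shape as combined_segments, the two can be zipped element by element. The plain-Python sketch below (the helper `split_by_segment` is hypothetical) regroups one row from the example by segment id, which makes visible how special tokens are attributed: the start-of-sequence token and the first end-of-segment token carry id 0, and the trailing end-of-segment token carries id 1.

```python
from collections import defaultdict
from typing import Dict, List


def split_by_segment(tokens: List[int], ids: List[int]) -> Dict[int, List[int]]:
    """Hypothetical helper: group one combined row by its segment ids."""
    if len(tokens) != len(ids):
        raise ValueError("tokens and ids must have the same shape")
    groups: Dict[int, List[int]] = defaultdict(list)
    for token, seg_id in zip(tokens, ids):
        groups[seg_id].append(token)
    return dict(groups)


# First row of the example above.
parts = split_by_segment([101, 1, 2, 102, 10, 20, 102],
                         [0, 0, 0, 0, 1, 1, 1])
```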