Generates the segmentation corresponding to a RaggedTensor row_splits.
tf.ragged.row_splits_to_segment_ids(
splits, name=None, out_type=None
)
Returns an integer vector segment_ids, where segment_ids[i] == j if splits[j] <= i < splits[j+1]. Example:
print(tf.ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]))
tf.Tensor([0 0 0 2 2 3 4 4 4], shape=(9,), dtype=int64)
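To make the definition concrete, here is a minimal pure-Python sketch that finds, for each position i, the largest j with splits[j] <= i. That is exactly the j satisfying splits[j] <= i < splits[j+1]. The helper name segment_ids_by_search is hypothetical and this is not TensorFlow's implementation. It assumes splits is a plain Python list that already meets the requirements listed under Args.

```python
import bisect


def segment_ids_by_search(splits):
    """Hypothetical reference: segment_ids[i] == j iff splits[j] <= i < splits[j+1]."""
    # splits[-1] is the total number of values, i.e. the output length.
    total = splits[-1]
    # bisect_right returns the index of the first split strictly greater than i,
    # so subtracting 1 gives the last row whose start is <= i. Empty rows
    # (repeated split values, like the two 3s below) are skipped automatically.
    return [bisect.bisect_right(splits, i) - 1 for i in range(total)]


print(segment_ids_by_search([0, 3, 3, 5, 6, 9]))
# [0, 0, 0, 2, 2, 3, 4, 4, 4]
```

Row 1 is empty (splits[1] == splits[2] == 3), so the id 1 never appears, matching the TensorFlow output above.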
Args:
  splits: A sorted 1-D integer Tensor. splits[0] must be zero.
  name: A name prefix for the returned tensor (optional).
  out_type: The dtype for the return value. Defaults to splits.dtype, or tf.int64 if splits does not have a dtype.
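The requirements on splits can be spelled out as a small check. The function below is a hypothetical illustration of those documented requirements (non-empty, integer entries, starting at zero, non-decreasing) applied to a Python list. It does not describe which checks TensorFlow itself performs, or when.

```python
def is_valid_row_splits(splits):
    """Hypothetical check of the documented requirements on `splits`."""
    # At least one entry is needed: n rows are described by n + 1 splits.
    if len(splits) == 0:
        return False
    # Every entry must be an integer (bool is excluded even though it subclasses int).
    if any(isinstance(s, bool) or not isinstance(s, int) for s in splits):
        return False
    # splits[0] must be zero.
    if splits[0] != 0:
        return False
    # "Sorted" means non-decreasing; equal neighbours encode empty rows.
    return all(lo <= hi for lo, hi in zip(splits, splits[1:]))


print(is_valid_row_splits([0, 3, 3, 5, 6, 9]))  # True
print(is_valid_row_splits([1, 3, 5]))           # False: does not start at 0
print(is_valid_row_splits([0, 5, 3]))           # False: not sorted
```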
Returns:
  A sorted 1-D integer Tensor with shape=[splits[-1]].
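The output length follows from the row lengths: row j holds splits[j+1] - splits[j] values, and those differences telescope to splits[-1] - splits[0] = splits[-1]. A minimal pure-Python sketch of this repeat-each-row-index view is below. The helper name is hypothetical, and it is offered as an equivalent formulation, not as TensorFlow's source.

```python
def segment_ids_by_repeat(splits):
    """Hypothetical: repeat each row index j once per value in row j."""
    # Row j spans [splits[j], splits[j+1]), so its length is the difference.
    row_lengths = [hi - lo for lo, hi in zip(splits, splits[1:])]
    # Emit j exactly row_lengths[j] times; empty rows contribute nothing.
    return [j for j, n in enumerate(row_lengths) for _ in range(n)]


splits = [0, 3, 3, 5, 6, 9]
ids = segment_ids_by_repeat(splits)
print(ids)                     # [0, 0, 0, 2, 2, 3, 4, 4, 4]
print(len(ids) == splits[-1])  # True
```

Because row indices are emitted in increasing order, the result is also sorted, as the Returns entry states.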
Raises:
  ValueError: If splits is invalid.