A Trimmer that truncates the longest segment.
Inherits From: Trimmer
text.ShrinkLongestTrimmer(
max_seq_length, axis=-1
)
A Trimmer that allocates a length budget to segments by repeatedly dropping the last element of whichever segment is currently the longest, until the total length of the segments is no larger than the allocated budget.
See generate_mask() for more details.
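The allocation rule can be illustrated in plain Python for a single batch row. This is a minimal sketch that works on segment lengths only, not the library's TensorFlow implementation; the helper name `shrink_longest_lengths` is hypothetical, and it assumes a non-negative integer budget.

```python
from typing import List


def shrink_longest_lengths(lengths: List[int], budget: int) -> List[int]:
    """Hypothetical helper: fit one batch row's segment lengths into `budget`.

    Repeatedly shortens the currently longest segment by one item (on ties,
    the earliest segment is shortened) until the total is no larger than
    `budget`.
    """
    if budget < 0:
        raise ValueError("budget must be non-negative")
    allocated = list(lengths)
    while sum(allocated) > budget:
        # list.index returns the first match, so on ties the earliest
        # longest segment is the one that loses an item.
        allocated[allocated.index(max(allocated))] -= 1
    return allocated


print(shrink_longest_lengths([3, 4, 4], 7))  # [2, 2, 3]
print(shrink_longest_lengths([2, 5], 10))    # [2, 5]: already within budget
```

Segments that already fit are returned unchanged, so the loop only runs when truncation is actually needed.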
Methods
generate_mask
generate_mask(
segments
)
Calculates a truncation mask given a per-batch budget.
Calculate a truncation mask given a budget of the maximum number of items for each batch row. The budget is allocated using a 'shrink the longest segment' algorithm: it identifies the currently longest segment (on ties, picking whichever segment occurs first) and reduces its length by 1 by dropping its last element, repeating until the total length of the segments is no larger than _max_seq_length.
For example, if the budget is [7] and the segments have sizes [3, 4, 4], the budget is allocated as [2, 2, 3], going through these truncation steps:
# Truncate the second segment.
[3, 3, 4]
# Truncate the last segment.
[3, 3, 3]
# Truncate the first segment.
[2, 3, 3]
# Truncate the second segment.
[2, 2, 3]
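The worked example can be checked mechanically by recording each intermediate state of the loop. This is an illustrative sketch; `shrink_longest_trace` is a hypothetical helper, not part of the library.

```python
from typing import List


def shrink_longest_trace(lengths: List[int], budget: int) -> List[List[int]]:
    """Hypothetical helper: return every intermediate state of the
    shrink-longest loop, one entry per dropped element."""
    current = list(lengths)
    steps = []
    while sum(current) > budget:
        i = current.index(max(current))  # first of the longest segments
        current[i] -= 1                  # drop its last element
        steps.append(list(current))
    return steps


for state in shrink_longest_trace([3, 4, 4], 7):
    print(state)
# [3, 3, 4]
# [3, 3, 3]
# [2, 3, 3]
# [2, 2, 3]
```

Four elements are dropped in total, matching the difference between the combined length (11) and the budget (7).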
Args | |
---|---|
segments | A list of RaggedTensors, each with shape [num_batch, (num_items)]. |
Returns | |
---|---|
A list of len(segments) RaggedTensors; see the superclass for details. |
trim
trim(
segments
)
Truncate the list of segments using the truncation strategy defined by generate_mask.
Args | |
---|---|
segments | A list of RaggedTensors, each with shape [num_batch, (num_items)]. |
Returns | |
---|---|
A list of len(segments) RaggedTensors, where each item has the same shape as its counterpart in segments, with unwanted values dropped according to the truncation strategy defined by generate_mask. |
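The effect of trim can be sketched the same way on nested lists. Since the shrink-longest mask always keeps a prefix of each row, applying it is equivalent to slicing each row to its allocated length. This is an illustrative emulation with hypothetical helper names, not the TensorFlow implementation; with the real class you would call trim on RaggedTensor inputs instead.

```python
from typing import List, Sequence

# Stand-in for a RaggedTensor: a list of rows of varying length.
Ragged = List[list]


def _allocate(lengths: List[int], budget: int) -> List[int]:
    # Shrink the first longest segment by one until the row fits the budget.
    allocated = list(lengths)
    while sum(allocated) > budget:
        allocated[allocated.index(max(allocated))] -= 1
    return allocated


def shrink_longest_trim(segments: Sequence[Ragged], max_seq_length: int) -> List[Ragged]:
    """Hypothetical sketch of trim(): keep each row's leading items, drop the rest."""
    trimmed: List[Ragged] = [[] for _ in segments]
    num_batch = len(segments[0]) if segments else 0
    for row in range(num_batch):
        lengths = [len(seg[row]) for seg in segments]
        for s, keep in enumerate(_allocate(lengths, max_seq_length)):
            # Slicing to `keep` equals applying a prefix-shaped boolean mask.
            trimmed[s].append(list(segments[s][row][:keep]))
    return trimmed


segments = [[[1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
print(shrink_longest_trim(segments, 7))  # [[[1, 2]], [[4, 5]], [[8, 9, 10]]]
```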