text.ShrinkLongestTrimmer

A Trimmer that truncates the longest segment.

Inherits From: Trimmer

A Trimmer that allocates a length budget to segments by repeatedly shrinking whichever segment is currently the longest, dropping items from its end, until the total length of all segments is no larger than the allocated budget. See generate_mask() for more details.
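The core idea for a single batch row can be sketched in plain Python. This is a minimal illustration, not the library's implementation. It assumes the segment lengths of one row are given as a list of ints and that the budget is a single int. The helper name `allocate_lengths` is hypothetical.

```python
from typing import List


def allocate_lengths(lengths: List[int], budget: int) -> List[int]:
    """Hypothetical helper: shrink the longest segment one item at a time.

    `lengths` holds the segment lengths of one batch row (assumption);
    `budget` is the maximum total number of items allowed in that row.
    """
    budget = max(budget, 0)  # guard so a negative budget cannot loop forever
    allocated = list(lengths)
    while sum(allocated) > budget:
        # list.index returns the first position of the maximum, so ties
        # are resolved in favour of the segment that occurs first.
        longest = allocated.index(max(allocated))
        allocated[longest] -= 1  # drop that segment's last element
    return allocated


print(allocate_lengths([3, 4, 4], 7))  # [2, 2, 3]
```

Because only one item is removed per round, the loop runs exactly `sum(lengths) - budget` times when the row is over budget, and leaves the row untouched otherwise.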

Methods

generate_mask


Calculates a truncation mask given a per-batch budget.

Calculates a truncation mask given a budget for the maximum number of items in each batch row. The budget is allocated with a "shrink the longest segment" algorithm: it identifies the currently longest segment (on a tie, the one that occurs first), reduces its length by 1 by dropping its last element, and repeats until the total length of the segments is no larger than _max_seq_length.

For example, if the budget is [7] and the segments have sizes [3, 4, 4], the budget is allocated as [2, 2, 3] through the following truncation steps:

  # Truncate the second segment.
  [3, 3, 4]
  # Truncate the last segment.
  [3, 3, 3]
  # Truncate the first segment.
  [2, 3, 3]
  # Truncate the second segment.
  [2, 2, 3]
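The worked example can be reproduced with a short trace. This sketch assumes one batch row represented as a list of segment lengths; the function name `shrink_trace` is hypothetical and only records the intermediate lengths after each step of the described algorithm.

```python
from typing import List


def shrink_trace(lengths: List[int], budget: int) -> List[List[int]]:
    """Hypothetical helper that records segment lengths after every truncation step."""
    current = list(lengths)
    steps: List[List[int]] = []
    while sum(current) > max(budget, 0):
        longest = current.index(max(current))  # first segment wins ties
        current[longest] -= 1                  # drop that segment's last element
        steps.append(list(current))
    return steps


for step in shrink_trace([3, 4, 4], 7):
    print(step)
# [3, 3, 4]
# [3, 3, 3]
# [2, 3, 3]
# [2, 2, 3]
```

The first step truncates the second segment rather than the third because both have length 4 and the tie goes to the earlier segment; the same rule picks the first segment at the third step.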

Args
segments: A list of RaggedTensors, each with shape [num_batch, (num_items)].

Returns
A list of len(segments) RaggedTensors; see the superclass for details.
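To make the shape of the result concrete, here is a pure-Python analogue of the mask computation. It is a sketch under stated assumptions, not the library's code: each RaggedTensor of shape [num_batch, (num_items)] is stood in for by a nested Python list, a single int `max_seq_length` is applied to every row, and the name `generate_mask_sketch` is hypothetical. True marks an item that is kept.

```python
from typing import List

# Stand-in types (assumption): a segment is a list of batch rows, mirroring a
# RaggedTensor of shape [num_batch, (num_items)].
Row = List[str]
Segment = List[Row]
Mask = List[List[bool]]


def generate_mask_sketch(segments: List[Segment], max_seq_length: int) -> List[Mask]:
    """Hypothetical analogue of generate_mask: one boolean mask per segment."""
    num_batch = len(segments[0]) if segments else 0
    masks: List[Mask] = [[] for _ in segments]
    for b in range(num_batch):
        kept = [len(seg[b]) for seg in segments]
        # Shrink the longest segment (first one on ties) until the row fits.
        while sum(kept) > max(max_seq_length, 0):
            longest = kept.index(max(kept))
            kept[longest] -= 1
        for s, seg in enumerate(segments):
            # Items before position kept[s] survive; later items are dropped.
            masks[s].append([i < kept[s] for i in range(len(seg[b]))])
    return masks


segments = [
    [["a", "b", "c"], ["x"]],
    [["d", "e", "f", "g"], []],
    [["h", "i", "j", "k"], ["y", "z"]],
]
masks = generate_mask_sketch(segments, 7)
```

The first row has sizes [3, 4, 4] and is cut to [2, 2, 3], so its masks end in False values; the second row has only 3 items, fits within the budget, and keeps everything. Each returned mask has the same ragged shape as its segment.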

trim


Truncate the list of segments.

Truncate the list of segments using the truncation strategy defined by generate_mask.

Args
segments: A list of RaggedTensors with shape [num_batch, (num_items)].

Returns
A list of len(segments) RaggedTensors, where each item has the same shape as its counterpart in segments but with unwanted values dropped. Values are dropped according to the truncation strategy defined by generate_mask.
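The effect of trim can be sketched the same way. This is an illustration under assumptions, not the library's implementation: nested Python lists stand in for RaggedTensors, one int `max_seq_length` applies to every row, and the name `trim_sketch` is hypothetical. Since the algorithm only ever drops a segment's last elements, slicing each row to its allocated length gives the same result as applying the boolean mask.

```python
from typing import List

Segment = List[List[int]]  # stand-in for a RaggedTensor [num_batch, (num_items)]


def trim_sketch(segments: List[Segment], max_seq_length: int) -> List[Segment]:
    """Hypothetical analogue of trim: drop items using the shrink-longest rule."""
    num_batch = len(segments[0]) if segments else 0
    trimmed: List[Segment] = [[] for _ in segments]
    for b in range(num_batch):
        kept = [len(seg[b]) for seg in segments]
        while sum(kept) > max(max_seq_length, 0):
            longest = kept.index(max(kept))  # first segment wins ties
            kept[longest] -= 1
        for s, seg in enumerate(segments):
            # Dropped items are always at the end of a segment, so a prefix
            # slice is equivalent to applying the boolean mask.
            trimmed[s].append(seg[b][:kept[s]])
    return trimmed


result = trim_sketch(
    [[[1, 2, 3], [10]], [[4, 5, 6, 7], [11, 12]], [[8, 9, 0, 1], [13]]],
    7,
)
print(result)  # [[[1, 2], [10]], [[4, 5], [11, 12]], [[8, 9, 0], [13]]]
```

The first row (sizes [3, 4, 4]) is trimmed to [2, 2, 3]; the second row (sizes [1, 2, 1]) already fits and is returned unchanged, so each output keeps the batch dimension of its input segment.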