View source on GitHub |
A Trimmer
that allocates a length budget to segments in order.
Inherits From: Trimmer
text.WaterfallTrimmer(
max_seq_length, axis=-1
)
A Trimmer
that allocates a length budget to segments in order. It selects
elements to drop, according to a max sequence length budget, and then applies
this mask to actually drop the elements. See generate_mask()
for more
details.
Example:
a = tf.ragged.constant([['a', 'b', 'c'], [], ['d']])
b = tf.ragged.constant([['1', '2', '3'], [], ['4', '5', '6', '7']])
trimmer = tf_text.WaterfallTrimmer(4)
trimmer.trim([a, b])
[<tf.RaggedTensor [[b'a', b'b', b'c'], [], [b'd']]>,
<tf.RaggedTensor [[b'1'], [], [b'4', b'5', b'6']]>]
Here, for the first pair of elements, ['a', 'b', 'c']
and ['1', '2', '3']
,
the '2'
and '3'
are dropped to fit the sequence within the max sequence
length budget.
Methods
generate_mask
generate_mask(
segments
)
Calculates a truncation mask given a per-batch budget.
Calculate a truncation mask given a budget of the max number of items for each or all batch row. The allocation of the budget is done using a 'waterfall' algorithm. This algorithm allocates quota in a left-to-right manner and fill up the buckets until we run out of budget.
For example if the budget of [5] and we have segments of size [3, 4, 2], the truncate budget will be allocated as [3, 2, 0].
The budget can be a scalar, in which case the same budget is broadcasted
and applied to all batch rows. It can also be a 1D Tensor
of size
batch_size
, in which each batch row i will have a budget corresponding to
per_batch_quota[i]
.
Example:
a = tf.ragged.constant([['a', 'b', 'c'], [], ['d']])
b = tf.ragged.constant([['1', '2', '3'], [], ['4', '5', '6', '7']])
trimmer = tf_text.WaterfallTrimmer(4)
trimmer.generate_mask([a, b])
[<tf.RaggedTensor [[True, True, True], [], [True]]>,
<tf.RaggedTensor [[True, False, False], [], [True, True, True, False]]>]
Args | |
---|---|
segments
|
A list of RaggedTensor each w/ a shape of [num_batch,
(num_items)].
|
Returns | |
---|---|
a list with len(segments) of RaggedTensor s, see superclass for details.
|
trim
trim(
segments
)
Truncate the list of segments
.
Truncate the list of segments
using the truncation strategy defined by
generate_mask
.
Args | |
---|---|
segments
|
A list of RaggedTensor s w/ shape [num_batch, (num_items)].
|
Returns | |
---|---|
a list of RaggedTensor s with len(segments) number of items and where
each item has the same shape as its counterpart in segments and
with unwanted values dropped. The values are dropped according to the
TruncationStrategy defined.
|