Returns the min k values and their indices of the input operand in an approximate manner.
tf.math.approx_min_k(
operand,
k,
reduction_dimension=-1,
recall_target=0.95,
reduction_input_size_override=-1,
aggregate_to_topk=True,
name=None
)
See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is currently only optimized on TPU.
Returns
Tuple of two arrays: the least k values and the corresponding indices along the reduction_dimension of the input operand. The arrays' dimensions are the same as those of the input operand except for the reduction_dimension: when aggregate_to_topk is true, the reduction dimension is k; otherwise, it is greater than or equal to k, with the exact size being implementation-defined.
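For example, with the default aggregate_to_topk=True, the reduction_dimension of both outputs becomes k. Below is a minimal sketch of the basic call; the input shape and the smallest3 helper are illustrative rather than part of the API, and it assumes the non-TPU fallback path is available where the op is not optimized:
import tensorflow as tf

@tf.function(jit_compile=True)
def smallest3(x):
  # Approximately select the 3 smallest values per row and their column indices.
  return tf.math.approx_min_k(x, k=3)

x = tf.random.uniform((4, 8))
values, indices = smallest3(x)
print(values.shape, indices.shape)  # (4, 3) (4, 3): the reduction dimension becomes k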
We encourage users to wrap approx_min_k with jit. See the following example for nearest-neighbor search over the squared L2 distance:
import tensorflow as tf

@tf.function(jit_compile=True)
def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
  # Rank database points by half_db_norms - <qy, db>, which orders them the
  # same way as the full squared L2 distance (see the note below).
  dists = half_db_norms - tf.einsum('ik,jk->ij', qy, db)
  return tf.math.approx_min_k(dists, k=k, recall_target=recall_target)

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
half_db_norms = tf.norm(db, axis=1) ** 2 / 2  # half of the squared L2 norms
dists, neighbors = l2_ann(qy, db, half_db_norms)
In the example above, we compute half_db_norms - dot(qy, db^T), with half_db_norms = db^2 / 2, instead of the full squared distance qy^2 - 2 dot(qy, db^T) + db^2, for performance reasons. Dropping the qy^2 term, which is constant per query, and halving the remainder uses less arithmetic and produces the same set of neighbors.
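To see how the approximation compares with an exact search, one option (not part of the op's API; it reuses qy, db, half_db_norms, and neighbors from the example above, and the recall_at_k helper is illustrative) is to run tf.math.top_k on the negated scores and measure the recall of the exact neighbor set:
import tensorflow as tf

def recall_at_k(approx_idx, exact_idx):
  # Fraction of exact neighbors that also appear in the approximate result,
  # averaged over all queries.
  matches = tf.reduce_any(
      tf.equal(approx_idx[:, :, None], exact_idx[:, None, :]), axis=1)
  return tf.reduce_mean(tf.cast(matches, tf.float32))

# Exact baseline: top_k on the negated scores yields the k smallest distances.
scores = half_db_norms - tf.einsum('ik,jk->ij', qy, db)
_, exact_neighbors = tf.math.top_k(-scores, k=10)

print(recall_at_k(neighbors, exact_neighbors))  # compare with the recall_target used above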