# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for approx_max_k and approx_min_k."""
import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.eager import test
from tensorflow.python.eager.def_function import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
class ApproxTopkTest(test.TestCase, parameterized.TestCase):
def setUp(self):
test.TestCase.setUp(self)
self._rng = np.random.default_rng(42)
def compute_recall(self, result_neighbors, ground_truth_neighbors):
"""Computes the recall of an approximate nearest neighbor search.
Args:
      result_neighbors: int32 numpy array of shape [num_queries,
        neighbors_per_query] where the values are indices into the dataset.
      ground_truth_neighbors: int32 numpy array of shape [num_queries,
        ground_truth_neighbors_per_query] where the values are indices into
        the dataset.
Returns:
The recall.
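
    For example (illustrative values): if result_neighbors is [[0, 2]] and
    ground_truth_neighbors is [[0, 1]], only index 0 is a hit, so the recall
    is 1 / 2 = 0.5.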
"""
self.assertLen(result_neighbors.shape, 2)
self.assertLen(ground_truth_neighbors.shape, 2)
self.assertEqual(result_neighbors.shape[0], ground_truth_neighbors.shape[0])
gt_sets = [set(np.asarray(x)) for x in ground_truth_neighbors]
    def hits_per_q(q, nn_per_q):
      return sum(x.item() in gt_sets[q] for x in nn_per_q)
hits = sum(
hits_per_q(q, nn_per_q) for q, nn_per_q in enumerate(result_neighbors))
return hits / ground_truth_neighbors.size
@parameterized.parameters(
itertools.product(
[1, 10], # k
[100, 500], # row_size
[1, 10, 128], # num_rows
[True, False], # aggregate_to_topk
))
def test_non_fused_max_k(self, k, row_size, num_rows, aggregate_to_topk):
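    # Each row of db is a random permutation of [0, row_size), so the exact
    # top-k of every row is known and recall can be measured against the
    # full np.argsort.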
row = np.arange(row_size, dtype=np.float32)
db = np.stack(list(self._rng.permutation(row) for _ in range(num_rows)))
@function(jit_compile=True)
def ann(db, k):
return nn_ops.approx_max_k(db, k, aggregate_to_topk=aggregate_to_topk)
with ops.device('/device:TPU:0'):
db_op = variables.Variable(db)
result = ann(db_op, k)[1]
gt = np.argsort(-db)[:, :k]
ann_recall = self.compute_recall(result.numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
@parameterized.parameters(
itertools.product(
[1, 10], # k
[100, 500], # row_size
[1, 10, 128], # num_rows
[True, False], # aggregate_to_topk
))
def test_non_fused_min_k(self, k, row_size, num_rows, aggregate_to_topk):
    # self._rng (created in setUp) uses the new numpy random Generator API.
row = np.arange(row_size, dtype=np.float32)
db = np.stack(list(self._rng.permutation(row) for _ in range(num_rows)))
@function(jit_compile=True)
def ann(db, k=10):
return nn_ops.approx_min_k(db, k, aggregate_to_topk=aggregate_to_topk)
with ops.device('/device:TPU:0'):
db_op = variables.Variable(db)
result = ann(db_op, k)[1]
gt = np.argsort(db)[:, :k]
ann_recall = self.compute_recall(result.numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
@parameterized.parameters(
itertools.product(
[1, 10], # k
[100, 500], # db_size
[1, 10, 128], # qy_size
[2, 32], # feature dim
))
# MIPS = Maximal Inner Product Search
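  # For each query, the scores against all db rows are the inner products, so
  # approx_max_k over a row of scores returns the (approximate) top-k inner
  # products and their db indices.  The exact ground truth below negates the
  # score matrix so that np.argsort (ascending) lists db indices from the
  # highest to the lowest inner product.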
def test_mips(self, k, db_size, qy_size, feature_dim):
qy = self._rng.random([qy_size, feature_dim], dtype=np.float32)
db = self._rng.random([db_size, feature_dim], dtype=np.float32)
@function(jit_compile=True)
def ann(qy, db, k):
scores = math_ops.matmul(qy, db, transpose_b=True)
return nn_ops.approx_max_k(scores, k)
with ops.device('/device:TPU:0'):
qy_op = variables.Variable(qy)
db_op = variables.Variable(db)
result = ann(qy_op, db_op, k)[1]
scores = -math_ops.matmul(qy_op, db_op, transpose_b=True)
gt = np.argsort(scores.numpy())[:, :k]
ann_recall = self.compute_recall(result.numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
@parameterized.parameters(
itertools.product(
[1, 10], # k
[100, 500], # db_size
[10, 128], # qy_size
[2, 8], # feature dim
))
# L2ANN = Approximate Nearest Neighbor search in the L2 metric space
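  # Relies on the half-norm identity: for a fixed query q,
  # argmin_d ||q - d||^2 == argmin_d (||d||^2 / 2 - <q, d>), because the
  # ||q||^2 / 2 term is constant per query.  approx_min_k over these scores
  # therefore finds the (approximate) L2 nearest neighbors.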
def test_l2ann(self, k, db_size, qy_size, feature_dim):
qy = self._rng.random([qy_size, feature_dim], dtype=np.float32)
db = self._rng.random([db_size, feature_dim], dtype=np.float32)
db_half_norm_sq = np.linalg.norm(db, axis=1)**2 / 2
@function(jit_compile=True)
def ann(qy, db, db_half_norm_sq, k):
scores = db_half_norm_sq - math_ops.matmul(qy, db, transpose_b=True)
return nn_ops.approx_min_k(scores, k)
with ops.device('/device:TPU:0'):
qy_op = variables.Variable(qy)
db_op = variables.Variable(db)
db_half_norm_sq_op = variables.Variable(db_half_norm_sq)
result = ann(qy_op, db_op, db_half_norm_sq_op, k)[1]
scores = db_half_norm_sq_op - math_ops.matmul(
qy_op, db_op, transpose_b=True)
gt = np.argsort(scores.numpy())[:, :k]
ann_recall = self.compute_recall(result.numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
def test_highdim(self):
db = self._rng.random([2, 10, 200, 3], dtype=np.float32)
k = 5
@function(jit_compile=True)
def ann(db, k):
return nn_ops.approx_min_k(db, k=k, reduction_dimension=2)
with ops.device('/device:TPU:0'):
db_op = variables.Variable(db)
result = ann(db_op, k)[1]
gt = np.argsort(db, axis=2)[:, :, :k, :]
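      # Both result and gt have shape [2, 10, k, 3], with the top-k indices
      # along dimension 2.  compute_recall expects [num_queries, k], so move
      # the k dimension last and flatten the remaining dimensions into
      # 2 * 10 * 3 independent queries.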
flat_idx = np.reshape(
np.transpose(result.numpy(), [0, 1, 3, 2]), [2 * 10 * 3, k])
flat_gt = np.reshape(np.transpose(gt, [0, 1, 3, 2]), [2 * 10 * 3, k])
ann_recall = self.compute_recall(flat_idx, flat_gt)
self.assertGreaterEqual(ann_recall, 0.95)
@parameterized.parameters(
itertools.product(
[dtypes.bfloat16, dtypes.float16, dtypes.float32],
[1, 10], # k
[100, 500], # row_size
[1, 10, 128], # num_rows
))
def test_gradients(self, dtype, k, row_size, num_rows):
row = np.arange(row_size, dtype=np.float32)
db = np.stack(list(self._rng.permutation(row) for _ in range(num_rows)))
out_grads = self._rng.random([num_rows, k])
@function(jit_compile=True)
def ann_with_grads(db, out_grads):
with backprop.GradientTape() as tape:
val, idx = nn_ops.approx_max_k(db, k)
result_in_grads = tape.gradient(val, db, out_grads)
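      # approx_max_k's gradient scatters each incoming gradient back to the
      # (row, column) position of the element it selected; every other input
      # position gets zero.  Build that expected gradient explicitly with
      # scatter_nd from the returned indices and compare.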
lifted_k_idx = array_ops.reshape(idx, [num_rows, k, 1])
iota_idx = array_ops.broadcast_to(
array_ops.reshape(math_ops.range(num_rows), [num_rows, 1, 1]),
[num_rows, k, 1])
lifted_idx = array_ops.concat([iota_idx, lifted_k_idx], axis=2)
k_idx_s = array_ops.reshape(lifted_idx, [num_rows * k, 2])
k_gra_s = array_ops.reshape(out_grads, [num_rows * k])
expected_in_grads = array_ops.scatter_nd(k_idx_s, k_gra_s,
[num_rows, row_size])
return [expected_in_grads, result_in_grads]
with ops.device('/device:TPU:0'):
db_op = variables.Variable(db, dtype=dtype)
out_grads_op = variables.Variable(out_grads, dtype=dtype)
expected_in_grads, result_in_grads = ann_with_grads(db_op, out_grads_op)
self.assertAllClose(expected_in_grads, result_in_grads)
# Tests that multiple ops are supported and the comparison functions are
# renamed properly to avoid conflict while using the MLIR bridge.
def test_multiple_ops(self):
k = 1
row_size = 100
num_rows = 10
row = np.arange(row_size, dtype=np.float32)
db1 = np.stack(list(self._rng.permutation(row) for _ in range(num_rows)))
db2 = np.stack(list(self._rng.permutation(row) for _ in range(num_rows)))
@function(jit_compile=True)
def ann(db1, db2):
result1 = nn_ops.approx_max_k(db1, k, aggregate_to_topk=True)
result2 = nn_ops.approx_max_k(db2, k, aggregate_to_topk=True)
return (result1, result2)
with ops.device('/device:TPU:0'):
db1_op = variables.Variable(db1)
db2_op = variables.Variable(db2)
result1, result2 = ann(db1_op, db2_op)
gt = np.argsort(-db1)[:, :k]
ann_recall = self.compute_recall(result1[1].numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
gt = np.argsort(-db2)[:, :k]
ann_recall = self.compute_recall(result2[1].numpy(), gt)
self.assertGreaterEqual(ann_recall, 0.95)
if __name__ == '__main__':
test.main()