# blob: bcfc63294f278e0b37820d853ee9d15794088fcc [file] [log] [blame]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lsh."""
from absl.testing import absltest
import numpy as np
from clustering import lsh
class LshTest(absltest.TestCase):
  """Unit tests for `lsh.SimHash` and its `group_by_next_hash` method."""

  def setUp(self):
    """Seeds NumPy's global RNG before each test."""
    super().setUp()
    # Fix the global NumPy seed so the random datapoints drawn in the tests
    # below are identical on every run, making any failure reproducible.
    np.random.seed(0)

  def test_projection_vectors_shape(self):
    """Checks that `projection_vectors` has shape (max_hash_len, dim)."""
    dim, max_hash_len = 10, 6
    sh = lsh.SimHash(dim, max_hash_len)
    self.assertEqual(sh.projection_vectors.shape, (max_hash_len, dim))

  def test_value_errors(self):
    """Checks that over-long hash prefixes are rejected with ValueError."""
    dim, max_hash_len = 10, 6
    num_points = 50
    sh = lsh.SimHash(dim, max_hash_len)
    datapoints = np.random.normal(size=(num_points, dim))
    # A prefix whose length already equals max_hash_len must be rejected
    # (presumably because no further hash bit is left to split on).
    with self.assertRaises(ValueError):
      sh.group_by_next_hash(datapoints, hash_prefix="010010")
    # A prefix longer than max_hash_len must be rejected as well.
    with self.assertRaises(ValueError):
      sh.group_by_next_hash(datapoints, hash_prefix="0101011")

  def test_group_by_next_hash_shape(self):
    """Checks that the "0" and "1" groups together hold every datapoint."""
    dim, max_hash_len = 10, 6
    num_points = 50
    sh = lsh.SimHash(dim, max_hash_len)
    datapoints = np.random.normal(size=(num_points, dim))
    children = sh.group_by_next_hash(datapoints)
    self.assertEqual(children["0"].shape[0] + children["1"].shape[0],
                     num_points)

  def test_group_by_next_hash(self):
    """Checks the exact grouping for hand-picked projection vectors."""
    dim, max_hash_len = 5, 2
    hash_prefix = "0"
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    datapoints = np.array([[1.5, 0, 1, 0.5, 0],
                           [1, 1, 0, 0.7, 0.1],
                           [0.8, 0.1, 1, 0.2, 0.4],
                           [0.1, 0.5, 0.3, 0.7, 0.8],
                           [-0.5, 0.1, -0.3, -0.4, 0.2]])
    children = sh.group_by_next_hash(datapoints, hash_prefix)
    # Dot products with projection_vectors[1] are 0.5, 1.0, -0.2, -0.2, -0.2,
    # so the expected split matches the sign of that projection (presumably
    # the vector SimHash uses after a one-bit prefix -- confirm against lsh).
    # assert_array_equal also fails on a shape mismatch and prints the
    # differing elements, unlike assertTrue((a == b).all()).
    np.testing.assert_array_equal(children["0"], datapoints[[0, 1]])
    np.testing.assert_array_equal(children["1"], datapoints[[2, 3, 4]])
if __name__ == "__main__":
  # Hand off to absl's test runner when this file is executed as a script.
  absltest.main()