# blob: 1fc7142fc58c9e1616961006b1fa4d1703941e34 [file] [log] [blame]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lsh_tree."""
import typing
from unittest import mock
from absl.testing import absltest
import numpy as np
from scipy import stats
from clustering import clustering_params
from clustering import lsh
from clustering import lsh_tree
from clustering import test_utils
def get_test_origin_points(nonprivate_count=5, dim=10):
  """Builds an all-zero point matrix for tests that ignore coordinates.

  Args:
    nonprivate_count: number of points (rows) to generate.
    dim: dimension of every point (columns).

  Returns:
    A numpy array of zeros with shape (nonprivate_count, dim).
  """
  shape = (nonprivate_count, dim)
  return np.zeros(shape)
def get_test_sim_hash(dim=10, max_hash_len=1):
  """Creates an lsh.SimHash for tests that do not care about its settings.

  Args:
    dim: value forwarded as the SimHash `dim` argument.
    max_hash_len: value forwarded as the SimHash `max_hash_len` argument.

  Returns:
    The newly constructed lsh.SimHash.
  """
  sim_hash = lsh.SimHash(dim=dim, max_hash_len=max_hash_len)
  return sim_hash
class TestLshTreeNode(lsh_tree.LshTreeNode):
  """Test node for testing without the real children hashing logic.

  This test implementation always returns a fraction of the points with the
  next hash_prefix containing 0, and the rest with 1, setting the private
  count equal to the real count + 1. Equality is checked by just checking the
  hash_prefix and private_count.

  Because __eq__ is overridden without __hash__, Python makes instances of
  this class unhashable; they are meant to be compared, not used as dict keys.
  """

  # Fraction of the nonprivate_points to add a 0 to the hash_prefix.
  frac_zero: float

  def __init__(self, *args, frac_zero: float = 0.5, **kwargs):
    """Initializes test node.

    Args:
      *args: extra arguments for the parent class
      frac_zero: fraction of the nonprivate_points to add a 0 to the
        hash_prefix
      **kwargs: extra arguments for the parent class
    """
    super().__init__(*args, **kwargs)
    self.frac_zero = frac_zero

  def get_private_count(self) -> int:
    """Returns a fake private count.

    The fake count is len(nonprivate_points) + 1. It is also written to
    self.private_count, replacing whatever value was stored there.
    """
    self.private_count = len(self.nonprivate_points) + 1
    return self.private_count

  def children(self) -> typing.List[lsh_tree.LshTreeNode]:
    """Returns fake children for this node.

    The first int(len(nonprivate_points) * frac_zero) points go to the child
    whose prefix is extended with '0'; the remaining points go to the child
    extended with '1'. Both children reuse this node's coreset_param,
    sim_hash and frac_zero.

    Returns:
      A two-element list: the '0' child followed by the '1' child.
    """
    num_points = len(self.nonprivate_points)
    # int() truncates, so the '0' child never gets more than the fraction.
    cutoff = int(num_points * self.frac_zero)
    return [
        TestLshTreeNode(
            self.hash_prefix + '0',
            self.nonprivate_points[:cutoff],
            self.coreset_param,
            self.sim_hash,
            private_count=cutoff + 1,
            frac_zero=self.frac_zero),
        TestLshTreeNode(
            self.hash_prefix + '1',
            self.nonprivate_points[cutoff:],
            self.coreset_param,
            self.sim_hash,
            private_count=num_points - cutoff + 1,
            frac_zero=self.frac_zero),
    ]

  def __eq__(self, other):
    """Returns whether hash_prefix and private_count are the same for tests.

    Args:
      other: object to compare against.

    Returns:
      True/False when other is a TestLshTreeNode; NotImplemented otherwise,
      so that Python can try the reflected comparison and finally fall back
      to its default (which yields False for distinct objects).
    """
    if not isinstance(other, TestLshTreeNode):
      return NotImplemented
    return (self.hash_prefix == other.hash_prefix and
            self.private_count == other.private_count)
class LshTreeTest(absltest.TestCase):
  """Tests for LshTreeNode, LshTree and root_node from lsh_tree."""

  @staticmethod
  def _fake_node(hash_prefix, nonprivate_count, coreset_param, sim_hash,
                 **kwargs):
    """Builds a TestLshTreeNode holding nonprivate_count all-zero points.

    Extra keyword arguments (e.g. private_count, frac_zero) are forwarded to
    the TestLshTreeNode constructor.
    """
    return TestLshTreeNode(hash_prefix,
                           get_test_origin_points(nonprivate_count),
                           coreset_param, sim_hash, **kwargs)

  @staticmethod
  def _node_with_fixed_projections(hash_prefix, datapoints):
    """Builds a real LshTreeNode whose SimHash uses fixed projections.

    Uses dim=5 and max_hash_len=2 (also passed as max_depth of the coreset
    param).

    Returns:
      Tuple of (node, coreset_param, sim_hash).
    """
    dim, max_hash_len = 5, 2
    coreset_param = test_utils.get_test_coreset_param(max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sim_hash = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, coreset_param,
                                sim_hash)
    return node, coreset_param, sim_hash

  @staticmethod
  def _tree_fixture(frac_zero, min_num_points_in_node,
                    min_num_points_in_branching_node, max_depth):
    """Builds a 64-point fake root plus the params it was created with.

    Returns:
      Tuple of (root, coreset_param, sim_hash).
    """
    nonprivate_count = 64
    sim_hash = get_test_sim_hash()
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_node=min_num_points_in_node,
        min_num_points_in_branching_node=min_num_points_in_branching_node,
        max_depth=max_depth)
    root = TestLshTreeNode(
        '',
        get_test_origin_points(nonprivate_count),
        coreset_param,
        sim_hash,
        frac_zero=frac_zero)
    return root, coreset_param, sim_hash

  def _check_root_basics(self, root, nonprivate_points, coreset_param):
    """Asserts the fields shared by both root_node tests."""
    self.assertEqual(root.hash_prefix, '')
    self.assertSequenceEqual(root.nonprivate_points, nonprivate_points)
    self.assertEqual(root.coreset_param, coreset_param)
    self.assertEqual(root.sim_hash.dim, 3)
    self.assertEqual(root.sim_hash.max_hash_len, 20)

  @mock.patch.object(stats.dlaplace, 'rvs', return_value=-5, autospec=True)
  def test_get_private_count_basic(self, mock_dlaplace_fn):
    """30 points with mocked noise of -5 give a private count of 25."""
    points = get_test_origin_points(nonprivate_count=30)
    param = test_utils.get_test_coreset_param(
        epsilon=5, frac_sum=0.2, frac_group_count=0.8, max_depth=9)
    hasher = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=points,
        coreset_param=param,
        sim_hash=hasher)

    self.assertEqual(node.get_private_count(), 25)
    mock_dlaplace_fn.assert_called_once_with(0.4)

  def test_get_private_count_cache(self):
    """Two consecutive get_private_count calls return the same value."""
    points = get_test_origin_points(nonprivate_count=30)
    param = test_utils.get_test_coreset_param(epsilon=0.01)
    hasher = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=points,
        coreset_param=param,
        sim_hash=hasher)

    first = node.get_private_count()
    second = node.get_private_count()
    self.assertEqual(first, second)

  def test_get_children(self):
    """children() splits the points between the '00' and '01' prefixes."""
    datapoints = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
        [0.8, 0.1, 1, 0.2, 0.4],
        [0.1, 0.5, 0.3, 0.7, 0.8],
        [-0.5, 0.1, -0.3, -0.4, 0.2],
    ])
    node, coreset_param, sim_hash = self._node_with_fixed_projections(
        '0', datapoints)

    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    children = node.children()

    prefixes = [child.hash_prefix for child in children]
    self.assertSameElements(prefixes, ['00', '01'])
    # Rows of `datapoints` expected under each child prefix.
    expected_rows = {'00': [0, 1], '01': [2, 3, 4]}
    for child in children:
      self.assertEqual(child.coreset_param, coreset_param)
      self.assertEqual(child.sim_hash, sim_hash)
      rows = expected_rows.get(child.hash_prefix)
      if rows is not None:
        self.assertTrue((child.nonprivate_points == datapoints[rows]).all())

  def test_get_children_one_empty(self):
    """children() still returns both prefixes when one child has no points."""
    datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
    node, coreset_param, sim_hash = self._node_with_fixed_projections(
        '0', datapoints)

    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    children = node.children()

    prefixes = [child.hash_prefix for child in children]
    self.assertSameElements(prefixes, ['00', '01'])
    for child in children:
      self.assertEqual(child.coreset_param, coreset_param)
      self.assertEqual(child.sim_hash, sim_hash)
      if child.hash_prefix == '00':
        self.assertTrue(
            (child.nonprivate_points == datapoints[[0, 1]]).all())
      elif child.hash_prefix == '01':
        self.assertEmpty(child.nonprivate_points)

  def test_get_children_error(self):
    """children() raises ValueError for prefix '00' with max_hash_len 2."""
    datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
    node, _, _ = self._node_with_fixed_projections('00', datapoints)

    self.assertRaises(ValueError, node.children)

  def test_filter_branching_nodes_too_few_points(self):
    """A node with private_count 1 (threshold 10) is filtered out."""
    sim_hash = get_test_sim_hash()
    points = get_test_origin_points(nonprivate_count=15)
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10)
    # private_count, not the nonprivate_count, should be used for the check.
    node = lsh_tree.LshTreeNode(
        '0', points, coreset_param, sim_hash, private_count=1)
    level: lsh_tree.LshTreeLevel = [node]

    self.assertEmpty(lsh_tree.LshTree.filter_branching_nodes(level))

  def test_filter_branching_nodes_enough_points(self):
    """A node with private_count 20 (threshold 10) is kept."""
    sim_hash = get_test_sim_hash()
    points = get_test_origin_points(nonprivate_count=15)
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10)
    node = lsh_tree.LshTreeNode(
        '0', points, coreset_param, sim_hash, private_count=20)
    level: lsh_tree.LshTreeLevel = [node]

    branching = lsh_tree.LshTree.filter_branching_nodes(level)
    self.assertSequenceEqual(branching, level)

  def test_get_next_level_empty_list(self):
    """get_next_level on an empty level returns an empty result."""
    next_level = lsh_tree.LshTree.get_next_level([])
    self.assertEmpty(next_level)

  def test_get_next_level(self):
    """Both fake children (private count 9) appear in the next level."""
    sim_hash = get_test_sim_hash()
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10, min_num_points_in_node=5)
    level: lsh_tree.LshTreeLevel = [
        self._fake_node('0', 16, coreset_param, sim_hash, private_count=20),
    ]
    expected_next_level = [
        self._fake_node('00', 8, coreset_param, sim_hash, private_count=9),
        self._fake_node('01', 8, coreset_param, sim_hash, private_count=9),
    ]

    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    next_level = lsh_tree.LshTree.get_next_level(branching_nodes)
    self.assertSequenceEqual(next_level, expected_next_level)

  def test_get_next_level_filters_children_node(self):
    """Children below min_num_points_in_node are dropped from the level."""
    sim_hash = get_test_sim_hash()
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10, min_num_points_in_node=9)
    level: lsh_tree.LshTreeLevel = [
        # The children test nodes have a private count of 6, which is less
        # than min_num_points_in_node.
        self._fake_node('0', 10, coreset_param, sim_hash, private_count=11),
        # The children test nodes have a private count of 3 and 9, only the
        # node with 9 should be in the result.
        self._fake_node(
            '1', 10, coreset_param, sim_hash, private_count=11,
            frac_zero=0.2),
    ]
    expected_next_level = [
        self._fake_node('11', 8, coreset_param, sim_hash, private_count=9),
    ]

    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    next_level = lsh_tree.LshTree.get_next_level(branching_nodes)
    self.assertSequenceEqual(next_level, expected_next_level)

  def test_root_node(self):
    """root_node builds the root and leaves it with a private count set."""
    nonprivate_points = [[1, 2, 1], [0.4, 0.2, 0.8], [3, 0, 3]]
    data = clustering_params.Data(nonprivate_points, radius=4.3)
    coreset_param = test_utils.get_test_coreset_param(
        radius=4.3, max_depth=20)

    root = lsh_tree.root_node(data, coreset_param)

    self._check_root_basics(root, nonprivate_points, coreset_param)
    self.assertIsNotNone(root.private_count)

  def test_root_node_provide_private_count(self):
    """root_node keeps a private_count that the caller passes in."""
    nonprivate_points = [[1, 2, 1], [0.4, 0.2, 0.8], [3, 0, 3]]
    data = clustering_params.Data(nonprivate_points, radius=4.3)
    coreset_param = test_utils.get_test_coreset_param(
        radius=4.3, max_depth=20)

    root = lsh_tree.root_node(data, coreset_param, private_count=10)

    self._check_root_basics(root, nonprivate_points, coreset_param)
    self.assertEqual(root.private_count, 10)

  def test_lsh_tree_empty_root_errors(self):
    """LshTree raises ValueError for a root with private_count 0."""
    test_root = lsh_tree.LshTreeNode(
        '0',
        get_test_origin_points(nonprivate_count=15),
        test_utils.get_test_coreset_param(),
        get_test_sim_hash(),
        private_count=0)

    self.assertRaises(ValueError, lsh_tree.LshTree, test_root)

  def test_lsh_tree_negative_count_root_errors(self):
    """LshTree raises ValueError for a root with private_count -10."""
    test_root = lsh_tree.LshTreeNode(
        '0',
        get_test_origin_points(nonprivate_count=15),
        test_utils.get_test_coreset_param(),
        get_test_sim_hash(),
        private_count=-10)

    self.assertRaises(ValueError, lsh_tree.LshTree, test_root)

  def test_lsh_tree(self):
    """LshTree.tree maps each depth to the nodes that survive filtering."""
    # Test tree:
    # Nodes are nonprivate count + 1.
    # Branches to the left are 0, to the right are 1.
    # Nodes in parentheses are filtered out.
    #             64+1
    #            /    \
    #         8+1      56+1
    #        /   \    /    \
    #    (1+1)  7+1  7+1   49+1
    #                     /    \
    #                 (6+1)    43+1
    test_root, cp, sh = self._tree_fixture(
        frac_zero=0.125,
        min_num_points_in_node=8,
        min_num_points_in_branching_node=9,
        max_depth=3)
    expected_tree = {
        0: [self._fake_node('', 64, cp, sh)],
        1: [
            self._fake_node('0', 8, cp, sh),
            self._fake_node('1', 56, cp, sh),
        ],
        2: [
            self._fake_node('01', 7, cp, sh),
            self._fake_node('10', 7, cp, sh),
            self._fake_node('11', 49, cp, sh),
        ],
        3: [self._fake_node('111', 43, cp, sh)],
    }

    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.tree, expected_tree)

  def test_lsh_tree_branching_node_becomes_leaf(self):
    """The tree stops at depth 1 when every grandchild is filtered out."""
    # Test tree:
    # Nodes are nonprivate count + 1.
    # Branches to the left are 0, to the right are 1.
    # Nodes in parentheses are filtered out.
    #               64+1
    #             /      \
    #        32+1          32+1
    #       /    \        /    \
    #  (16+1)  (16+1)  (16+1)  (16+1)
    test_root, cp, sh = self._tree_fixture(
        frac_zero=0.5,
        min_num_points_in_node=20,
        min_num_points_in_branching_node=30,
        max_depth=5)
    expected_tree = {
        0: [self._fake_node('', 64, cp, sh)],
        1: [
            self._fake_node('0', 32, cp, sh),
            self._fake_node('1', 32, cp, sh),
        ],
    }

    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.tree, expected_tree)

  def test_lsh_tree_leaves(self):
    """LshTree.leaves lists the surviving nodes that have no children."""
    # Test tree:
    # Nodes are nonprivate count + 1.
    # Branches to the left are 0, to the right are 1.
    # Nodes in parentheses are filtered out.
    #             64+1
    #            /    \
    #         8+1      56+1
    #        /   \    /    \
    #    (1+1)  7+1  7+1   49+1
    #                     /    \
    #                 (6+1)    43+1
    test_root, cp, sh = self._tree_fixture(
        frac_zero=0.125,
        min_num_points_in_node=8,
        min_num_points_in_branching_node=9,
        max_depth=3)
    expected_leaves = [
        self._fake_node('01', 7, cp, sh),
        self._fake_node('10', 7, cp, sh),
        self._fake_node('111', 43, cp, sh),
    ]

    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.leaves, expected_leaves)

  def test_lsh_tree_leaves_branching_node_becomes_leaf(self):
    """Depth-1 nodes are the leaves when all their children are filtered."""
    # Test tree:
    # Nodes are nonprivate count + 1.
    # Branches to the left are 0, to the right are 1.
    # Nodes in parentheses are filtered out.
    #               64+1
    #             /      \
    #        32+1          32+1
    #       /    \        /    \
    #  (16+1)  (16+1)  (16+1)  (16+1)
    test_root, cp, sh = self._tree_fixture(
        frac_zero=0.5,
        min_num_points_in_node=20,
        min_num_points_in_branching_node=30,
        max_depth=5)
    expected_leaves = [
        self._fake_node('0', 32, cp, sh),
        self._fake_node('1', 32, cp, sh),
    ]

    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.leaves, expected_leaves)
# Run the absltest test runner only when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()